""" Session management for the AI Voice Bot server. This module handles session lifecycle, persistence, and cleanup operations. """ from __future__ import annotations import json import os import time import threading import secrets import asyncio from typing import Optional, List, Dict, Any from fastapi import WebSocket from pydantic import ValidationError # Import shared models try: # Try relative import first (when running as part of the package) from ...shared.models import ( SessionSaved, LobbySaved, SessionsPayload, NamePasswordRecord, ) except ImportError: try: # Try absolute import (when running directly) import sys import os sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ) from shared.models import ( SessionSaved, LobbySaved, SessionsPayload, NamePasswordRecord, ) except ImportError: raise ImportError( f"Failed to import shared models: {e}. Ensure shared/models.py is accessible and PYTHONPATH is correctly set." ) from core.lobby_manager import Lobby from shared.logger import logger # Import WebRTC signaling for peer management from websocket.webrtc_signaling import WebRTCSignalingHandlers # Use try/except for importing events to handle both relative and absolute imports try: from ..models.events import ( event_bus, SessionDisconnected, UserNameChanged, SessionJoinedLobby, SessionLeftLobby, ) except ImportError: try: from models.events import ( event_bus, SessionDisconnected, UserNameChanged, SessionJoinedLobby, SessionLeftLobby, ) except ImportError: # Create dummy event system for standalone testing class DummyEventBus: async def publish(self, event): pass event_bus = DummyEventBus() class SessionDisconnected: pass class UserNameChanged: pass class SessionJoinedLobby: pass class SessionLeftLobby: pass class SessionConfig: """Configuration class for session management""" ANONYMOUS_SESSION_TIMEOUT = int( os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") ) # 1 minute DISPLACED_SESSION_TIMEOUT = int( os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") ) # 3 hours CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes MAX_SESSIONS_PER_CLEANUP = int( os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") ) # Circuit breaker SESSION_VALIDATION_INTERVAL = int( os.getenv("SESSION_VALIDATION_INTERVAL", "1800") ) # 30 minutes class Session: """Individual session representing a user or bot connection""" def __init__( self, id: str, bot_instance_id: Optional[str] = None, has_media: bool = True, from_disk: bool = False, ): logger.info( f"Instantiating new session {id} (bot: {True if bot_instance_id else False}, media: {has_media})" ) self.id = id self.short = id[:8] self.name = "" self.lobbies: List[Any] = [] # List of lobby objects this session is in self.lobby_peers: Dict[ str, List[str] ] = {} # lobby ID -> list of peer session IDs self.ws: Optional[WebSocket] = None self.created_at = time.time() self.last_used = time.time() self.displaced_at: Optional[float] = None # When name was taken over self.has_media = has_media # Whether this session provides audio/video streams self.bot_run_id: Optional[str] = None # Bot run ID for tracking self.bot_provider_id: Optional[str] = None # Bot provider ID self.bot_instance_id = ( bot_instance_id # Bot instance ID if this is a bot session ) self.bot_instance_id: Optional[str] = None # Bot instance ID for tracking self.session_lock = threading.RLock() # Instance-level lock def is_bot(self) -> bool: """Check if this session represents a bot""" return bool(self.bot_run_id or self.bot_provider_id or self.bot_instance_id) def getName(self) -> str: with self.session_lock: return f"{self.short}:{self.name if self.name else '[ ---- ]'}" def setName(self, name: str): with self.session_lock: old_name = self.name self.name = name self.update_last_used() # Get lobby IDs for event lobby_ids = [lobby.id for lobby in self.lobbies] # Publish name change event (don't await here to avoid blocking) asyncio.create_task( event_bus.publish( UserNameChanged( session_id=self.id, old_name=old_name, new_name=name, lobby_ids=lobby_ids, ) ) ) def update_last_used(self): """Update the last_used timestamp""" with self.session_lock: self.last_used = time.time() def mark_displaced(self): """Mark this session as having its name taken over""" with self.session_lock: self.displaced_at = time.time() async def join_lobby(self, lobby: Lobby): """Join a lobby and establish WebRTC peer connections""" with self.session_lock: if lobby not in self.lobbies: self.lobbies.append(lobby) # Initialize lobby_peers for this lobby if not exists if lobby.id not in self.lobby_peers: self.lobby_peers[lobby.id] = [] # Add to lobby first await lobby.addSession(self) # Get existing peer sessions in this lobby for WebRTC setup peer_sessions: list[Session] = [] for session in lobby.sessions.values(): if ( session.id != self.id and session.ws ): # Don't include self and only connected sessions peer_sessions.append(session) # Establish WebRTC peer connections with existing sessions using signaling handlers for peer_session in peer_sessions: await WebRTCSignalingHandlers.handle_add_peer(self, peer_session, lobby) # Publish join event await event_bus.publish( SessionJoinedLobby( session_id=self.id, lobby_id=lobby.id, session_name=self.name or self.short, ) ) async def leave_lobby(self, lobby: Lobby): """Leave a lobby and clean up WebRTC peer connections""" # Get peer sessions before removing from lobby peer_sessions: list[Session] = [] if lobby.id in self.lobby_peers: for peer_id in self.lobby_peers[lobby.id]: peer_session = None # Find peer session in lobby for session in lobby.sessions.values(): if session.id == peer_id: peer_session = session break if peer_session and peer_session.ws: peer_sessions.append(peer_session) # Handle WebRTC peer disconnections using signaling handlers for peer_session in peer_sessions: await WebRTCSignalingHandlers.handle_remove_peer(self, peer_session, lobby) # Clean up our lobby_peers and lobbies with self.session_lock: if lobby in self.lobbies: self.lobbies.remove(lobby) if lobby.id in self.lobby_peers: del self.lobby_peers[lobby.id] # Remove from lobby await lobby.removeSession(self) # Publish leave event await event_bus.publish( SessionLeftLobby( session_id=self.id, lobby_id=lobby.id, session_name=self.name or self.short, ) ) def model_dump(self) -> Dict[str, Any]: """Convert session to dictionary format for API responses""" with self.session_lock: data: Dict[str, Any] = { "id": self.id, "name": self.name or "", "bot_instance_id": self.bot_instance_id, "has_media": self.has_media, "created_at": self.created_at, "last_used": self.last_used, } return data def to_saved(self) -> SessionSaved: """Convert session to saved format for persistence""" with self.session_lock: lobbies_list: List[LobbySaved] = [ LobbySaved(id=lobby.id, name=lobby.name, private=lobby.private) for lobby in self.lobbies ] return SessionSaved( id=self.id, name=self.name or "", lobbies=lobbies_list, created_at=self.created_at, last_used=self.last_used, displaced_at=self.displaced_at, has_media=self.has_media, bot_run_id=self.bot_run_id, bot_provider_id=self.bot_provider_id, bot_instance_id=self.bot_instance_id, ) class SessionManager: """Manages all sessions and their lifecycle""" def __init__(self, save_file: str = "sessions.json"): self._instances: List[Session] = [] self._save_file = save_file self._loaded = False self.lock = threading.RLock() # Thread safety for class-level operations # Background task management self.cleanup_task_running = False self.cleanup_task: Optional[asyncio.Task[None]] = None self.validation_task_running = False self.validation_task: Optional[asyncio.Task[None]] = None def create_session( self, session_id: Optional[str] = None, bot_instance_id: Optional[str] = None, has_media: bool = True, ) -> Session: """Create a new session with given or generated ID""" if not session_id: session_id = secrets.token_hex(16) with self.lock: # Check if session already exists (now inside the lock for atomicity) existing_session = self.get_session(session_id) if existing_session: logger.debug( f"Session {session_id[:8]} already exists, returning existing session" ) return existing_session # Create new session session = Session( session_id, bot_instance_id=bot_instance_id, has_media=has_media ) self._instances.append(session) self.save() return session def get_or_create_session( self, session_id: Optional[str] = None, bot_instance_id: Optional[str] = None, has_media: bool = True, ) -> Session: """Get existing session or create a new one""" if session_id: existing_session = self.get_session(session_id) if existing_session: return existing_session return self.create_session( session_id, bot_instance_id=bot_instance_id, has_media=has_media ) def get_session(self, session_id: str) -> Optional[Session]: """Get session by ID""" with self.lock: if not self._loaded: self.load() logger.info(f"Loaded {len(self._instances)} sessions from disk...") for s in self._instances: if s.id == session_id: return s return None def get_session_by_name(self, name: str) -> Optional[Session]: """Get session by name""" if not name: return None lname = name.lower() with self.lock: for s in self._instances: with s.session_lock: if s.name and s.name.lower() == lname: return s return None def is_unique_name(self, name: str) -> bool: """Check if a name is unique across all sessions""" if not name: return False with self.lock: for s in self._instances: with s.session_lock: if s.name.lower() == name.lower(): return False return True def remove_session(self, session: Session): """Remove a session from the manager""" with self.lock: if session in self._instances: self._instances.remove(session) # Publish disconnect event lobby_ids = [lobby.id for lobby in session.lobbies] asyncio.create_task( event_bus.publish( SessionDisconnected( session_id=session.id, session_name=session.name or session.short, lobby_ids=lobby_ids, ) ) ) def save(self): """Save all sessions to disk""" try: with self.lock: sessions_list: List[SessionSaved] = [] for s in self._instances: # Skip bot sessions - they should not be persisted # Bot sessions are managed by the voicebot service lifecycle if s.bot_instance_id is not None or s.bot_run_id is not None or s.bot_provider_id is not None: continue sessions_list.append(s.to_saved()) # Note: We'll need to handle name_passwords separately or inject it # For now, create empty dict - this will be handled by AuthManager saved_pw: Dict[str, NamePasswordRecord] = {} payload_model = SessionsPayload( sessions=sessions_list, name_passwords=saved_pw ) payload = payload_model.model_dump() # Atomic write using temp file temp_file = self._save_file + ".tmp" with open(temp_file, "w") as f: json.dump(payload, f, indent=2) # Atomic rename os.rename(temp_file, self._save_file) logger.info(f"Saved {len(sessions_list)} sessions to {self._save_file}") except Exception as e: logger.error(f"Failed to save sessions: {e}") # Clean up temp file if it exists try: if os.path.exists(self._save_file + ".tmp"): os.remove(self._save_file + ".tmp") except Exception: pass def load(self): """Load sessions from disk""" if not os.path.exists(self._save_file): logger.info(f"No session save file found: {self._save_file}") return try: with open(self._save_file, "r") as f: raw = json.load(f) except Exception as e: logger.error(f"Failed to read session save file: {e}") return try: payload = SessionsPayload.model_validate(raw) except ValidationError as e: logger.exception(f"Failed to validate sessions payload: {e}") return current_time = time.time() sessions_loaded = 0 sessions_expired = 0 with self.lock: for s_saved in payload.sessions: # Check if this session should be expired during loading created_at = getattr(s_saved, "created_at", time.time()) last_used = getattr(s_saved, "last_used", time.time()) displaced_at = getattr(s_saved, "displaced_at", None) name = s_saved.name or "" # Apply same removal criteria as cleanup_old_sessions should_expire = self._should_remove_session_static( name, None, created_at, last_used, displaced_at, current_time ) if should_expire: sessions_expired += 1 logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load") continue # Check if session already exists in _instances (deduplication) existing_session = None for existing in self._instances: if existing.id == s_saved.id: existing_session = existing break if existing_session: logger.debug( f"Session {s_saved.id[:8]} already loaded, skipping duplicate" ) continue session = Session( s_saved.id, bot_instance_id=s_saved.bot_instance_id, has_media=s_saved.has_media, from_disk=True, ) session.name = name session.created_at = created_at session.last_used = last_used session.displaced_at = displaced_at session.bot_run_id = s_saved.bot_run_id session.bot_provider_id = s_saved.bot_provider_id # Note: Lobby restoration will be handled by LobbyManager self._instances.append(session) sessions_loaded += 1 logger.info(f"Loaded {sessions_loaded} sessions from {self._save_file}") if sessions_expired > 0: logger.info(f"Expired {sessions_expired} old sessions during load") self.save() # Mark as loaded to prevent duplicate loads self._loaded = True @staticmethod def _should_remove_session_static( name: str, ws: Optional[WebSocket], created_at: float, last_used: float, displaced_at: Optional[float], current_time: float, ) -> bool: """Static method to determine if a session should be removed""" # Rule 1: Delete sessions with no active connection and no name that are older than threshold if ( not ws and not name and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT ): return True # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently if ( not ws and displaced_at is not None and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT ): return True return False def cleanup_old_sessions(self) -> int: """Clean up old/stale sessions and return count of removed sessions""" current_time = time.time() removed_count = 0 with self.lock: sessions_to_remove: list[Session] = [] for session in self._instances: with session.session_lock: if self._should_remove_session_static( session.name, session.ws, session.created_at, session.last_used, session.displaced_at, current_time, ): sessions_to_remove.append(session) if ( len(sessions_to_remove) >= SessionConfig.MAX_SESSIONS_PER_CLEANUP ): break # Remove sessions for session in sessions_to_remove: try: # Clean up websocket if open if session.ws: asyncio.create_task(session.ws.close()) # Remove from lobbies (will be handled by lobby manager events) for lobby in session.lobbies[:]: asyncio.create_task(session.leave_lobby(lobby)) self._instances.remove(session) removed_count += 1 logger.info(f"Cleaned up session {session.getName()}") except Exception as e: logger.warning( f"Error cleaning up session {session.getName()}: {e}" ) if removed_count > 0: self.save() return removed_count async def start_background_tasks(self): """Start background cleanup and validation tasks""" logger.info("Starting session background tasks...") self.cleanup_task_running = True self.validation_task_running = True self.cleanup_task = asyncio.create_task(self._periodic_cleanup()) self.validation_task = asyncio.create_task(self._periodic_validation()) logger.info("Session background tasks started") async def stop_background_tasks(self): """Stop background tasks gracefully""" logger.info("Shutting down session background tasks...") self.cleanup_task_running = False self.validation_task_running = False # Cancel tasks for task in [self.cleanup_task, self.validation_task]: if task: task.cancel() try: await task except asyncio.CancelledError: pass # Clean up all sessions gracefully await self._cleanup_all_sessions() logger.info("Session background tasks stopped") async def _periodic_cleanup(self): """Background task to periodically clean up old sessions""" cleanup_errors = 0 max_consecutive_errors = 5 while self.cleanup_task_running: try: removed_count = self.cleanup_old_sessions() if removed_count > 0: logger.info( f"Periodic cleanup removed {removed_count} old sessions" ) cleanup_errors = 0 # Reset error counter on success # Run cleanup at configured interval await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL) except Exception as e: cleanup_errors += 1 logger.error( f"Error in session cleanup task (attempt {cleanup_errors}): {e}" ) if cleanup_errors >= max_consecutive_errors: logger.error( f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task" ) break # Exponential backoff on errors await asyncio.sleep(min(60 * cleanup_errors, 300)) async def _periodic_validation(self): """Background task to periodically validate session integrity""" while self.validation_task_running: try: issues = self.validate_session_integrity() if issues: logger.warning( f"Session integrity issues found: {len(issues)} issues" ) for issue in issues[:10]: # Log first 10 issues logger.warning(f"Integrity issue: {issue}") await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL) except Exception as e: logger.error(f"Error in session validation task: {e}") await asyncio.sleep(300) # Wait 5 minutes before retrying on error def validate_session_integrity(self) -> List[str]: """Validate session integrity and return list of issues""" issues = [] with self.lock: for session in self._instances: with session.session_lock: # Check for sessions with invalid state if not session.id: issues.append(f"Session with empty ID: {session}") if session.created_at > time.time(): issues.append( f"Session {session.getName()} has future creation time" ) if session.last_used > time.time(): issues.append( f"Session {session.getName()} has future last_used time" ) # Check for duplicate names if session.name: count = sum( 1 for s in self._instances if s.name and s.name.lower() == session.name.lower() ) if count > 1: issues.append( f"Duplicate name '{session.name}' found in {count} sessions" ) return issues async def _cleanup_all_sessions(self): """Clean up all sessions during shutdown""" with self.lock: for session in self._instances[:]: try: if session.ws: await session.ws.close() except Exception as e: logger.warning( f"Error closing WebSocket for {session.getName()}: {e}" ) logger.info("All sessions cleaned up") def get_all_sessions(self) -> List[Session]: """Get all sessions (for admin/debugging purposes)""" with self.lock: return self._instances.copy() def get_session_count(self) -> int: """Get total session count""" with self.lock: return len(self._instances)