""" Session management for the AI Voice Bot server. This module handles session lifecycle, persistence, and cleanup operations. """ from __future__ import annotations import json import os import time import threading import secrets import asyncio from typing import Optional, List, Dict, Any from fastapi import WebSocket from pydantic import ValidationError # Import shared models try: # Try relative import first (when running as part of the package) from ...shared.models import SessionSaved, LobbySaved, SessionsPayload, NamePasswordRecord except ImportError: try: # Try absolute import (when running directly) import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from shared.models import SessionSaved, LobbySaved, SessionsPayload, NamePasswordRecord except ImportError: # Fallback: create minimal models for testing from pydantic import BaseModel class SessionSaved(BaseModel): id: str name: str = "" protected: bool = False is_bot: bool = False has_media: bool = True bot_run_id: Optional[str] = None lobbies: List[str] = [] class LobbySaved(BaseModel): name: str private: bool = False class SessionsPayload(BaseModel): sessions: List[SessionSaved] class NamePasswordRecord(BaseModel): name: str password: str from logger import logger # Import WebRTC signaling for peer management from websocket.webrtc_signaling import WebRTCSignalingHandlers # Use try/except for importing events to handle both relative and absolute imports try: from ..models.events import event_bus, SessionDisconnected, UserNameChanged, SessionJoinedLobby, SessionLeftLobby except ImportError: try: from models.events import event_bus, SessionDisconnected, UserNameChanged, SessionJoinedLobby, SessionLeftLobby except ImportError: # Create dummy event system for standalone testing class DummyEventBus: async def publish(self, event): pass event_bus = DummyEventBus() class SessionDisconnected: pass class UserNameChanged: pass class SessionJoinedLobby: pass class SessionLeftLobby: pass class SessionConfig: """Configuration class for session management""" ANONYMOUS_SESSION_TIMEOUT = int( os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60") ) # 1 minute DISPLACED_SESSION_TIMEOUT = int( os.getenv("DISPLACED_SESSION_TIMEOUT", "10800") ) # 3 hours CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300")) # 5 minutes MAX_SESSIONS_PER_CLEANUP = int( os.getenv("MAX_SESSIONS_PER_CLEANUP", "100") ) # Circuit breaker SESSION_VALIDATION_INTERVAL = int( os.getenv("SESSION_VALIDATION_INTERVAL", "1800") ) # 30 minutes class Session: """Individual session representing a user or bot connection""" def __init__(self, id: str, is_bot: bool = False, has_media: bool = True): logger.info( f"Instantiating new session {id} (bot: {is_bot}, media: {has_media})" ) self.id = id self.short = id[:8] self.name = "" self.lobbies: List[Any] = [] # List of lobby objects this session is in self.lobby_peers: Dict[str, List[str]] = {} # lobby ID -> list of peer session IDs self.ws: Optional[WebSocket] = None self.created_at = time.time() self.last_used = time.time() self.displaced_at: Optional[float] = None # When name was taken over self.is_bot = is_bot # Whether this session represents a bot self.has_media = has_media # Whether this session provides audio/video streams self.bot_run_id: Optional[str] = None # Bot run ID for tracking self.bot_provider_id: Optional[str] = None # Bot provider ID self.session_lock = threading.RLock() # Instance-level lock def getName(self) -> str: with self.session_lock: return 
f"{self.short}:{self.name if self.name else '[ ---- ]'}" def setName(self, name: str): with self.session_lock: old_name = self.name self.name = name self.update_last_used() # Get lobby IDs for event lobby_ids = [lobby.id for lobby in self.lobbies] # Publish name change event (don't await here to avoid blocking) asyncio.create_task(event_bus.publish(UserNameChanged( session_id=self.id, old_name=old_name, new_name=name, lobby_ids=lobby_ids ))) def update_last_used(self): """Update the last_used timestamp""" with self.session_lock: self.last_used = time.time() def mark_displaced(self): """Mark this session as having its name taken over""" with self.session_lock: self.displaced_at = time.time() async def join_lobby(self, lobby): """Join a lobby and establish WebRTC peer connections""" with self.session_lock: if lobby not in self.lobbies: self.lobbies.append(lobby) # Initialize lobby_peers for this lobby if not exists if lobby.id not in self.lobby_peers: self.lobby_peers[lobby.id] = [] # Add to lobby first await lobby.addSession(self) # Get existing peer sessions in this lobby for WebRTC setup peer_sessions = [] for session in lobby.sessions.values(): if ( session.id != self.id and session.ws ): # Don't include self and only connected sessions peer_sessions.append(session) # Establish WebRTC peer connections with existing sessions using signaling handlers for peer_session in peer_sessions: await WebRTCSignalingHandlers.handle_add_peer(self, peer_session, lobby) # Publish join event await event_bus.publish(SessionJoinedLobby( session_id=self.id, lobby_id=lobby.id, session_name=self.name or self.short )) async def leave_lobby(self, lobby): """Leave a lobby and clean up WebRTC peer connections""" # Get peer sessions before removing from lobby peer_sessions = [] if lobby.id in self.lobby_peers: for peer_id in self.lobby_peers[lobby.id]: peer_session = None # Find peer session in lobby for session in lobby.sessions.values(): if session.id == peer_id: peer_session = session break if peer_session and peer_session.ws: peer_sessions.append(peer_session) # Handle WebRTC peer disconnections using signaling handlers for peer_session in peer_sessions: await WebRTCSignalingHandlers.handle_remove_peer(self, peer_session, lobby) # Clean up our lobby_peers and lobbies with self.session_lock: if lobby in self.lobbies: self.lobbies.remove(lobby) if lobby.id in self.lobby_peers: del self.lobby_peers[lobby.id] # Remove from lobby await lobby.removeSession(self) # Publish leave event await event_bus.publish(SessionLeftLobby( session_id=self.id, lobby_id=lobby.id, session_name=self.name or self.short )) def to_saved(self) -> SessionSaved: """Convert session to saved format for persistence""" with self.session_lock: lobbies_list: List[LobbySaved] = [ LobbySaved( id=lobby.id, name=lobby.name, private=lobby.private ) for lobby in self.lobbies ] return SessionSaved( id=self.id, name=self.name or "", lobbies=lobbies_list, created_at=self.created_at, last_used=self.last_used, displaced_at=self.displaced_at, is_bot=self.is_bot, has_media=self.has_media, bot_run_id=self.bot_run_id, bot_provider_id=self.bot_provider_id, ) class SessionManager: """Manages all sessions and their lifecycle""" def __init__(self, save_file: str = "sessions.json"): self._instances: List[Session] = [] self._save_file = save_file self._loaded = False self.lock = threading.RLock() # Thread safety for class-level operations # Background task management self.cleanup_task_running = False self.cleanup_task: Optional[asyncio.Task] = None 
        self.validation_task_running = False
        self.validation_task: Optional[asyncio.Task] = None

    def create_session(self, session_id: Optional[str] = None, is_bot: bool = False, has_media: bool = True) -> Session:
        """Create a new session with given or generated ID"""
        if not session_id:
            session_id = secrets.token_hex(16)

        with self.lock:
            # Check if session already exists (now inside the lock for atomicity)
            existing_session = self.get_session(session_id)
            if existing_session:
                logger.debug(
                    f"Session {session_id[:8]} already exists, returning existing session"
                )
                return existing_session

            # Create new session
            session = Session(session_id, is_bot=is_bot, has_media=has_media)
            self._instances.append(session)
            self.save()
            return session

    def get_or_create_session(self, session_id: Optional[str] = None, is_bot: bool = False, has_media: bool = True) -> Session:
        """Get existing session or create a new one"""
        if session_id:
            existing_session = self.get_session(session_id)
            if existing_session:
                return existing_session
        return self.create_session(session_id, is_bot=is_bot, has_media=has_media)

    def get_session(self, session_id: str) -> Optional[Session]:
        """Get session by ID"""
        with self.lock:
            if not self._loaded:
                self.load()
                logger.info(f"Loaded {len(self._instances)} sessions from disk...")
            for s in self._instances:
                if s.id == session_id:
                    return s
            return None

    def get_session_by_name(self, name: str) -> Optional[Session]:
        """Get session by name"""
        if not name:
            return None
        lname = name.lower()
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name and s.name.lower() == lname:
                        return s
            return None

    def is_unique_name(self, name: str) -> bool:
        """Check if a name is unique across all sessions"""
        if not name:
            return False
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name.lower() == name.lower():
                        return False
            return True

    def remove_session(self, session: Session):
        """Remove a session from the manager"""
        with self.lock:
            if session in self._instances:
                self._instances.remove(session)

                # Publish disconnect event
                lobby_ids = [lobby.id for lobby in session.lobbies]
                asyncio.create_task(event_bus.publish(SessionDisconnected(
                    session_id=session.id,
                    session_name=session.name or session.short,
                    lobby_ids=lobby_ids
                )))

    def save(self):
        """Save all sessions to disk"""
        try:
            with self.lock:
                sessions_list: List[SessionSaved] = []
                for s in self._instances:
                    sessions_list.append(s.to_saved())

                # Note: We'll need to handle name_passwords separately or inject it
                # For now, create empty dict - this will be handled by AuthManager
                saved_pw: Dict[str, NamePasswordRecord] = {}

                payload_model = SessionsPayload(
                    sessions=sessions_list,
                    name_passwords=saved_pw
                )
                payload = payload_model.model_dump()

                # Atomic write using temp file
                temp_file = self._save_file + ".tmp"
                with open(temp_file, "w") as f:
                    json.dump(payload, f, indent=2)

                # Atomic rename
                os.rename(temp_file, self._save_file)

                logger.info(
                    f"Saved {len(sessions_list)} sessions to {self._save_file}"
                )
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            # Clean up temp file if it exists
            try:
                if os.path.exists(self._save_file + ".tmp"):
                    os.remove(self._save_file + ".tmp")
            except Exception:
                pass
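
    # For orientation, the file written by save() has roughly the shape sketched
    # below. This is an illustrative example only: the concrete values are made up,
    # and the exact field set depends on the shared.models definitions that were
    # actually imported above.
    #
    #   {
    #     "sessions": [
    #       {
    #         "id": "<hex session id>",
    #         "name": "alice",
    #         "lobbies": [{"id": "...", "name": "...", "private": false}],
    #         "created_at": 1700000000.0,
    #         "last_used": 1700000123.0,
    #         "displaced_at": null,
    #         "is_bot": false,
    #         "has_media": true,
    #         "bot_run_id": null,
    #         "bot_provider_id": null
    #       }
    #     ],
    #     "name_passwords": {}
    #   }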
    def load(self):
        """Load sessions from disk"""
        if not os.path.exists(self._save_file):
            logger.info(f"No session save file found: {self._save_file}")
            return

        try:
            with open(self._save_file, "r") as f:
                raw = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read session save file: {e}")
            return

        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        current_time = time.time()
        sessions_loaded = 0
        sessions_expired = 0

        with self.lock:
            for s_saved in payload.sessions:
                # Check if this session should be expired during loading
                created_at = getattr(s_saved, "created_at", time.time())
                last_used = getattr(s_saved, "last_used", time.time())
                displaced_at = getattr(s_saved, "displaced_at", None)
                name = s_saved.name or ""

                # Apply same removal criteria as cleanup_old_sessions
                should_expire = self._should_remove_session_static(
                    name, None, created_at, last_used, displaced_at, current_time
                )
                if should_expire:
                    sessions_expired += 1
                    logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load")
                    continue

                # Check if session already exists in _instances (deduplication)
                existing_session = None
                for existing in self._instances:
                    if existing.id == s_saved.id:
                        existing_session = existing
                        break
                if existing_session:
                    logger.debug(
                        f"Session {s_saved.id[:8]} already loaded, skipping duplicate"
                    )
                    continue

                session = Session(
                    s_saved.id,
                    is_bot=getattr(s_saved, "is_bot", False),
                    has_media=getattr(s_saved, "has_media", True),
                )
                session.name = name
                session.created_at = created_at
                session.last_used = last_used
                session.displaced_at = displaced_at
                session.is_bot = getattr(s_saved, "is_bot", False)
                session.has_media = getattr(s_saved, "has_media", True)
                session.bot_run_id = getattr(s_saved, "bot_run_id", None)
                session.bot_provider_id = getattr(s_saved, "bot_provider_id", None)
                # Note: Lobby restoration will be handled by LobbyManager

                self._instances.append(session)
                sessions_loaded += 1

        logger.info(f"Loaded {sessions_loaded} sessions from {self._save_file}")
        if sessions_expired > 0:
            logger.info(f"Expired {sessions_expired} old sessions during load")
            self.save()

        # Mark as loaded to prevent duplicate loads
        self._loaded = True

    @staticmethod
    def _should_remove_session_static(
        name: str,
        ws: Optional[WebSocket],
        created_at: float,
        last_used: float,
        displaced_at: Optional[float],
        current_time: float,
    ) -> bool:
        """Static method to determine if a session should be removed"""
        # Rule 1: Delete sessions with no active connection and no name that are older than threshold
        if (
            not ws
            and not name
            and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT
        ):
            return True

        # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently
        if (
            not ws
            and displaced_at is not None
            and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT
        ):
            return True

        return False

    def cleanup_old_sessions(self) -> int:
        """Clean up old/stale sessions and return count of removed sessions"""
        current_time = time.time()
        removed_count = 0

        with self.lock:
            sessions_to_remove = []
            for session in self._instances:
                with session.session_lock:
                    if self._should_remove_session_static(
                        session.name,
                        session.ws,
                        session.created_at,
                        session.last_used,
                        session.displaced_at,
                        current_time,
                    ):
                        sessions_to_remove.append(session)
                        if len(sessions_to_remove) >= SessionConfig.MAX_SESSIONS_PER_CLEANUP:
                            break

            # Remove sessions
            for session in sessions_to_remove:
                try:
                    # Clean up websocket if open
                    if session.ws:
                        asyncio.create_task(session.ws.close())

                    # Remove from lobbies (will be handled by lobby manager events)
                    for lobby in session.lobbies[:]:
                        asyncio.create_task(session.leave_lobby(lobby))

                    self._instances.remove(session)
                    removed_count += 1
                    logger.info(f"Cleaned up session {session.getName()}")
                except Exception as e:
                    logger.warning(f"Error cleaning up session {session.getName()}: {e}")

        if removed_count > 0:
            self.save()

        return removed_count
    async def start_background_tasks(self):
        """Start background cleanup and validation tasks"""
        logger.info("Starting session background tasks...")
        self.cleanup_task_running = True
        self.validation_task_running = True
        self.cleanup_task = asyncio.create_task(self._periodic_cleanup())
        self.validation_task = asyncio.create_task(self._periodic_validation())
        logger.info("Session background tasks started")

    async def stop_background_tasks(self):
        """Stop background tasks gracefully"""
        logger.info("Shutting down session background tasks...")
        self.cleanup_task_running = False
        self.validation_task_running = False

        # Cancel tasks
        for task in [self.cleanup_task, self.validation_task]:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        # Clean up all sessions gracefully
        await self._cleanup_all_sessions()
        logger.info("Session background tasks stopped")

    async def _periodic_cleanup(self):
        """Background task to periodically clean up old sessions"""
        cleanup_errors = 0
        max_consecutive_errors = 5

        while self.cleanup_task_running:
            try:
                removed_count = self.cleanup_old_sessions()
                if removed_count > 0:
                    logger.info(f"Periodic cleanup removed {removed_count} old sessions")
                cleanup_errors = 0  # Reset error counter on success

                # Run cleanup at configured interval
                await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL)
            except Exception as e:
                cleanup_errors += 1
                logger.error(
                    f"Error in session cleanup task (attempt {cleanup_errors}): {e}"
                )
                if cleanup_errors >= max_consecutive_errors:
                    logger.error(
                        f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task"
                    )
                    break
                # Exponential backoff on errors
                await asyncio.sleep(min(60 * cleanup_errors, 300))

    async def _periodic_validation(self):
        """Background task to periodically validate session integrity"""
        while self.validation_task_running:
            try:
                issues = self.validate_session_integrity()
                if issues:
                    logger.warning(f"Session integrity issues found: {len(issues)} issues")
                    for issue in issues[:10]:  # Log first 10 issues
                        logger.warning(f"Integrity issue: {issue}")

                await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL)
            except Exception as e:
                logger.error(f"Error in session validation task: {e}")
                await asyncio.sleep(300)  # Wait 5 minutes before retrying on error

    def validate_session_integrity(self) -> List[str]:
        """Validate session integrity and return list of issues"""
        issues = []

        with self.lock:
            for session in self._instances:
                with session.session_lock:
                    # Check for sessions with invalid state
                    if not session.id:
                        issues.append(f"Session with empty ID: {session}")
                    if session.created_at > time.time():
                        issues.append(f"Session {session.getName()} has future creation time")
                    if session.last_used > time.time():
                        issues.append(f"Session {session.getName()} has future last_used time")

                    # Check for duplicate names
                    if session.name:
                        count = sum(
                            1
                            for s in self._instances
                            if s.name and s.name.lower() == session.name.lower()
                        )
                        if count > 1:
                            issues.append(f"Duplicate name '{session.name}' found in {count} sessions")

        return issues

    async def _cleanup_all_sessions(self):
        """Clean up all sessions during shutdown"""
        with self.lock:
            for session in self._instances[:]:
                try:
                    if session.ws:
                        await session.ws.close()
                except Exception as e:
                    logger.warning(f"Error closing WebSocket for {session.getName()}: {e}")

        logger.info("All sessions cleaned up")

    def get_all_sessions(self) -> List[Session]:
        """Get all sessions (for admin/debugging purposes)"""
        with self.lock:
            return self._instances.copy()
    def get_session_count(self) -> int:
        """Get total session count"""
        with self.lock:
            return len(self._instances)
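

# --- Usage sketch ---
# A minimal, illustrative example of driving SessionManager directly: create a
# manager, obtain a session, set a name, and run the background tasks briefly.
# The save file name "sessions_demo.json" and the one-second wait are arbitrary
# choices for this sketch; a long-running server would typically invoke
# start_background_tasks()/stop_background_tasks() from its own startup and
# shutdown hooks instead.
if __name__ == "__main__":

    async def _demo() -> None:
        manager = SessionManager(save_file="sessions_demo.json")

        # Creating a session persists it to disk immediately.
        session = manager.get_or_create_session()
        session.setName("demo-user")

        # Start the periodic cleanup/validation tasks, let them spin briefly,
        # then shut everything down gracefully.
        await manager.start_background_tasks()
        await asyncio.sleep(1)
        await manager.stop_background_tasks()

        logger.info(f"Demo finished with {manager.get_session_count()} session(s)")

    asyncio.run(_demo())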