ai-voicebot/server/core/session_manager.py

"""
Session management for the AI Voice Bot server.
This module handles session lifecycle, persistence, and cleanup operations.
"""
from __future__ import annotations
import json
import os
import time
import threading
import secrets
import asyncio
from typing import Optional, List, Dict, Any
from fastapi import WebSocket
from pydantic import ValidationError
# Import shared models
try:
    # Try relative import first (when running as part of the package)
    from ...shared.models import (
        SessionSaved,
        LobbySaved,
        SessionsPayload,
        NamePasswordRecord,
    )
except ImportError:
    try:
        # Try absolute import (when running directly)
        import sys

        sys.path.append(
            os.path.dirname(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            )
        )
        from shared.models import (
            SessionSaved,
            LobbySaved,
            SessionsPayload,
            NamePasswordRecord,
        )
    except ImportError as e:
        raise ImportError(
            f"Failed to import shared models: {e}. Ensure shared/models.py is "
            "accessible and PYTHONPATH is correctly set."
        ) from e

from core.lobby_manager import Lobby
from logger import logger
# Import WebRTC signaling for peer management
from websocket.webrtc_signaling import WebRTCSignalingHandlers

# Use try/except for importing events to handle both relative and absolute imports
try:
    from ..models.events import (
        event_bus,
        SessionDisconnected,
        UserNameChanged,
        SessionJoinedLobby,
        SessionLeftLobby,
    )
except ImportError:
    try:
        from models.events import (
            event_bus,
            SessionDisconnected,
            UserNameChanged,
            SessionJoinedLobby,
            SessionLeftLobby,
        )
    except ImportError:
        # Create dummy event system for standalone testing
        class DummyEventBus:
            async def publish(self, event):
                pass

        event_bus = DummyEventBus()

        class _DummyEvent:
            # Accepts arbitrary keyword arguments so the event constructors used
            # below (session_id=..., lobby_ids=..., etc.) also work standalone.
            def __init__(self, **kwargs):
                self.__dict__.update(kwargs)

        class SessionDisconnected(_DummyEvent):
            pass

        class UserNameChanged(_DummyEvent):
            pass

        class SessionJoinedLobby(_DummyEvent):
            pass

        class SessionLeftLobby(_DummyEvent):
            pass


class SessionConfig:
    """Configuration class for session management"""

    ANONYMOUS_SESSION_TIMEOUT = int(
        os.getenv("ANONYMOUS_SESSION_TIMEOUT", "60")
    )  # 1 minute
    DISPLACED_SESSION_TIMEOUT = int(
        os.getenv("DISPLACED_SESSION_TIMEOUT", "10800")
    )  # 3 hours
    CLEANUP_INTERVAL = int(os.getenv("CLEANUP_INTERVAL", "300"))  # 5 minutes
    MAX_SESSIONS_PER_CLEANUP = int(
        os.getenv("MAX_SESSIONS_PER_CLEANUP", "100")
    )  # Circuit breaker
    SESSION_VALIDATION_INTERVAL = int(
        os.getenv("SESSION_VALIDATION_INTERVAL", "1800")
    )  # 30 minutes
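

# Deployment tuning (illustrative): each value above can be overridden by setting
# the corresponding environment variable before the server process starts, e.g.
#   export ANONYMOUS_SESSION_TIMEOUT=120
#   export CLEANUP_INTERVAL=600
# The variable names match the os.getenv() keys used above; the values are examples.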


class Session:
    """Individual session representing a user or bot connection"""

    def __init__(
        self, id: str, bot_instance_id: Optional[str] = None, has_media: bool = True
    ):
        logger.info(
            f"Instantiating new session {id} (bot: {True if bot_instance_id else False}, media: {has_media})"
        )
        self.id = id
        self.short = id[:8]
        self.name = ""
        self.lobbies: List[Any] = []  # List of lobby objects this session is in
        self.lobby_peers: Dict[
            str, List[str]
        ] = {}  # lobby ID -> list of peer session IDs
        self.ws: Optional[WebSocket] = None
        self.created_at = time.time()
        self.last_used = time.time()
        self.displaced_at: Optional[float] = None  # When name was taken over
        self.has_media = has_media  # Whether this session provides audio/video streams
        self.bot_run_id: Optional[str] = None  # Bot run ID for tracking
        self.bot_provider_id: Optional[str] = None  # Bot provider ID
        # Bot instance ID if this is a bot session; keep the constructor argument
        # rather than resetting it to None, so bot sessions stay identifiable.
        self.bot_instance_id: Optional[str] = bot_instance_id
        self.session_lock = threading.RLock()  # Instance-level lock

    def getName(self) -> str:
        with self.session_lock:
            return f"{self.short}:{self.name if self.name else '[ ---- ]'}"

    def setName(self, name: str):
        with self.session_lock:
            old_name = self.name
            self.name = name
            self.update_last_used()
            # Get lobby IDs for event
            lobby_ids = [lobby.id for lobby in self.lobbies]
            # Publish name change event (don't await here to avoid blocking)
            asyncio.create_task(
                event_bus.publish(
                    UserNameChanged(
                        session_id=self.id,
                        old_name=old_name,
                        new_name=name,
                        lobby_ids=lobby_ids,
                    )
                )
            )
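        # Note: create_task() assumes a running event loop (as in the server's
        # async handlers); calling setName() from plain synchronous code outside
        # the loop would raise a RuntimeError.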

    def update_last_used(self):
        """Update the last_used timestamp"""
        with self.session_lock:
            self.last_used = time.time()

    def mark_displaced(self):
        """Mark this session as having its name taken over"""
        with self.session_lock:
            self.displaced_at = time.time()

    async def join_lobby(self, lobby: Lobby):
        """Join a lobby and establish WebRTC peer connections"""
        with self.session_lock:
            if lobby not in self.lobbies:
                self.lobbies.append(lobby)
            # Initialize lobby_peers for this lobby if not exists
            if lobby.id not in self.lobby_peers:
                self.lobby_peers[lobby.id] = []

        # Add to lobby first
        await lobby.addSession(self)

        # Get existing peer sessions in this lobby for WebRTC setup
        peer_sessions: list[Session] = []
        for session in lobby.sessions.values():
            if (
                session.id != self.id and session.ws
            ):  # Don't include self and only connected sessions
                peer_sessions.append(session)

        # Establish WebRTC peer connections with existing sessions using signaling handlers
        for peer_session in peer_sessions:
            await WebRTCSignalingHandlers.handle_add_peer(self, peer_session, lobby)

        # Publish join event
        await event_bus.publish(
            SessionJoinedLobby(
                session_id=self.id,
                lobby_id=lobby.id,
                session_name=self.name or self.short,
            )
        )

    async def leave_lobby(self, lobby: Lobby):
        """Leave a lobby and clean up WebRTC peer connections"""
        # Get peer sessions before removing from lobby
        peer_sessions: list[Session] = []
        if lobby.id in self.lobby_peers:
            for peer_id in self.lobby_peers[lobby.id]:
                peer_session = None
                # Find peer session in lobby
                for session in lobby.sessions.values():
                    if session.id == peer_id:
                        peer_session = session
                        break
                if peer_session and peer_session.ws:
                    peer_sessions.append(peer_session)

        # Handle WebRTC peer disconnections using signaling handlers
        for peer_session in peer_sessions:
            await WebRTCSignalingHandlers.handle_remove_peer(self, peer_session, lobby)

        # Clean up our lobby_peers and lobbies
        with self.session_lock:
            if lobby in self.lobbies:
                self.lobbies.remove(lobby)
            if lobby.id in self.lobby_peers:
                del self.lobby_peers[lobby.id]

        # Remove from lobby
        await lobby.removeSession(self)

        # Publish leave event
        await event_bus.publish(
            SessionLeftLobby(
                session_id=self.id,
                lobby_id=lobby.id,
                session_name=self.name or self.short,
            )
        )

    def model_dump(self) -> Dict[str, Any]:
        """Convert session to dictionary format for API responses"""
        with self.session_lock:
            data: Dict[str, Any] = {
                "id": self.id,
                "name": self.name or "",
                "bot_instance_id": self.bot_instance_id,
                "has_media": self.has_media,
                "created_at": self.created_at,
                "last_used": self.last_used,
            }
            return data

    def to_saved(self) -> SessionSaved:
        """Convert session to saved format for persistence"""
        with self.session_lock:
            lobbies_list: List[LobbySaved] = [
                LobbySaved(id=lobby.id, name=lobby.name, private=lobby.private)
                for lobby in self.lobbies
            ]
            return SessionSaved(
                id=self.id,
                name=self.name or "",
                lobbies=lobbies_list,
                created_at=self.created_at,
                last_used=self.last_used,
                displaced_at=self.displaced_at,
                has_media=self.has_media,
                bot_run_id=self.bot_run_id,
                bot_provider_id=self.bot_provider_id,
                bot_instance_id=self.bot_instance_id,
            )
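
    # For reference, a to_saved() record serializes roughly as follows
    # (values are illustrative; the exact schema lives in shared/models.py):
    #   {"id": "9f1c...", "name": "alice",
    #    "lobbies": [{"id": "...", "name": "main", "private": false}],
    #    "created_at": 1700000000.0, "last_used": 1700000100.0,
    #    "displaced_at": null, "has_media": true,
    #    "bot_run_id": null, "bot_provider_id": null, "bot_instance_id": null}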


class SessionManager:
    """Manages all sessions and their lifecycle"""

    def __init__(self, save_file: str = "sessions.json"):
        self._instances: List[Session] = []
        self._save_file = save_file
        self._loaded = False
        self.lock = threading.RLock()  # Thread safety for class-level operations
        # Background task management
        self.cleanup_task_running = False
        self.cleanup_task: Optional[asyncio.Task] = None
        self.validation_task_running = False
        self.validation_task: Optional[asyncio.Task] = None

    def create_session(
        self,
        session_id: Optional[str] = None,
        bot_instance_id: Optional[str] = None,
        has_media: bool = True,
    ) -> Session:
        """Create a new session with given or generated ID"""
        if not session_id:
            session_id = secrets.token_hex(16)

        with self.lock:
            # Check if session already exists (now inside the lock for atomicity)
            existing_session = self.get_session(session_id)
            if existing_session:
                logger.debug(
                    f"Session {session_id[:8]} already exists, returning existing session"
                )
                return existing_session

            # Create new session
            session = Session(
                session_id, bot_instance_id=bot_instance_id, has_media=has_media
            )
            self._instances.append(session)
            self.save()
            return session

    def get_or_create_session(
        self,
        session_id: Optional[str] = None,
        bot_instance_id: Optional[str] = None,
        has_media: bool = True,
    ) -> Session:
        """Get existing session or create a new one"""
        if session_id:
            existing_session = self.get_session(session_id)
            if existing_session:
                return existing_session
        return self.create_session(
            session_id, bot_instance_id=bot_instance_id, has_media=has_media
        )

    def get_session(self, session_id: str) -> Optional[Session]:
        """Get session by ID"""
        with self.lock:
            if not self._loaded:
                self.load()
                logger.info(f"Loaded {len(self._instances)} sessions from disk...")
            for s in self._instances:
                if s.id == session_id:
                    return s
            return None

    def get_session_by_name(self, name: str) -> Optional[Session]:
        """Get session by name"""
        if not name:
            return None
        lname = name.lower()
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name and s.name.lower() == lname:
                        return s
            return None

    def is_unique_name(self, name: str) -> bool:
        """Check if a name is unique across all sessions"""
        if not name:
            return False
        with self.lock:
            for s in self._instances:
                with s.session_lock:
                    if s.name.lower() == name.lower():
                        return False
            return True

    def remove_session(self, session: Session):
        """Remove a session from the manager"""
        with self.lock:
            if session in self._instances:
                self._instances.remove(session)
                # Publish disconnect event
                lobby_ids = [lobby.id for lobby in session.lobbies]
                asyncio.create_task(
                    event_bus.publish(
                        SessionDisconnected(
                            session_id=session.id,
                            session_name=session.name or session.short,
                            lobby_ids=lobby_ids,
                        )
                    )
                )

    def save(self):
        """Save all sessions to disk"""
        try:
            with self.lock:
                sessions_list: List[SessionSaved] = []
                for s in self._instances:
                    sessions_list.append(s.to_saved())

                # Note: We'll need to handle name_passwords separately or inject it
                # For now, create empty dict - this will be handled by AuthManager
                saved_pw: Dict[str, NamePasswordRecord] = {}
                payload_model = SessionsPayload(
                    sessions=sessions_list, name_passwords=saved_pw
                )
                payload = payload_model.model_dump()

                # Atomic write using temp file
                temp_file = self._save_file + ".tmp"
                with open(temp_file, "w") as f:
                    json.dump(payload, f, indent=2)
                # Atomic rename
                os.rename(temp_file, self._save_file)
                logger.info(f"Saved {len(sessions_list)} sessions to {self._save_file}")
        except Exception as e:
            logger.error(f"Failed to save sessions: {e}")
            # Clean up temp file if it exists
            try:
                if os.path.exists(self._save_file + ".tmp"):
                    os.remove(self._save_file + ".tmp")
            except Exception:
                pass
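
    # Portability note: os.rename() replaces an existing destination atomically on
    # POSIX filesystems, but raises FileExistsError on Windows when the target file
    # exists; os.replace() is the cross-platform equivalent if that ever matters.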

    def load(self):
        """Load sessions from disk"""
        if not os.path.exists(self._save_file):
            logger.info(f"No session save file found: {self._save_file}")
            return

        try:
            with open(self._save_file, "r") as f:
                raw = json.load(f)
        except Exception as e:
            logger.error(f"Failed to read session save file: {e}")
            return

        try:
            payload = SessionsPayload.model_validate(raw)
        except ValidationError as e:
            logger.exception(f"Failed to validate sessions payload: {e}")
            return

        current_time = time.time()
        sessions_loaded = 0
        sessions_expired = 0

        with self.lock:
            for s_saved in payload.sessions:
                # Check if this session should be expired during loading
                created_at = getattr(s_saved, "created_at", time.time())
                last_used = getattr(s_saved, "last_used", time.time())
                displaced_at = getattr(s_saved, "displaced_at", None)
                name = s_saved.name or ""

                # Apply same removal criteria as cleanup_old_sessions
                should_expire = self._should_remove_session_static(
                    name, None, created_at, last_used, displaced_at, current_time
                )
                if should_expire:
                    sessions_expired += 1
                    logger.info(f"Expiring session {s_saved.id[:8]}:{name} during load")
                    continue

                # Check if session already exists in _instances (deduplication)
                existing_session = None
                for existing in self._instances:
                    if existing.id == s_saved.id:
                        existing_session = existing
                        break
                if existing_session:
                    logger.debug(
                        f"Session {s_saved.id[:8]} already loaded, skipping duplicate"
                    )
                    continue

                session = Session(
                    s_saved.id,
                    bot_instance_id=s_saved.bot_instance_id,
                    has_media=s_saved.has_media,
                )
                session.name = name
                session.created_at = created_at
                session.last_used = last_used
                session.displaced_at = displaced_at
                session.bot_run_id = s_saved.bot_run_id
                session.bot_provider_id = s_saved.bot_provider_id
                # Note: Lobby restoration will be handled by LobbyManager
                self._instances.append(session)
                sessions_loaded += 1

        logger.info(f"Loaded {sessions_loaded} sessions from {self._save_file}")
        if sessions_expired > 0:
            logger.info(f"Expired {sessions_expired} old sessions during load")
            self.save()

        # Mark as loaded to prevent duplicate loads
        self._loaded = True

    @staticmethod
    def _should_remove_session_static(
        name: str,
        ws: Optional[WebSocket],
        created_at: float,
        last_used: float,
        displaced_at: Optional[float],
        current_time: float,
    ) -> bool:
        """Static method to determine if a session should be removed"""
        # Rule 1: Delete sessions with no active connection and no name that are older than threshold
        if (
            not ws
            and not name
            and current_time - created_at > SessionConfig.ANONYMOUS_SESSION_TIMEOUT
        ):
            return True

        # Rule 2: Delete inactive sessions that had their nick taken over and haven't been used recently
        if (
            not ws
            and displaced_at is not None
            and current_time - last_used > SessionConfig.DISPLACED_SESSION_TIMEOUT
        ):
            return True

        return False
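
    # Worked example with the defaults above: an unnamed session with no WebSocket
    # created 120s ago is removed (120 > ANONYMOUS_SESSION_TIMEOUT of 60s), while a
    # named-but-displaced session last used two hours ago is kept
    # (7200 < DISPLACED_SESSION_TIMEOUT of 10800s).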

    def cleanup_old_sessions(self) -> int:
        """Clean up old/stale sessions and return count of removed sessions"""
        current_time = time.time()
        removed_count = 0

        with self.lock:
            sessions_to_remove: list[Session] = []
            for session in self._instances:
                with session.session_lock:
                    if self._should_remove_session_static(
                        session.name,
                        session.ws,
                        session.created_at,
                        session.last_used,
                        session.displaced_at,
                        current_time,
                    ):
                        sessions_to_remove.append(session)
                        if (
                            len(sessions_to_remove)
                            >= SessionConfig.MAX_SESSIONS_PER_CLEANUP
                        ):
                            break

            # Remove sessions
            for session in sessions_to_remove:
                try:
                    # Clean up websocket if open
                    if session.ws:
                        asyncio.create_task(session.ws.close())
                    # Remove from lobbies (will be handled by lobby manager events)
                    for lobby in session.lobbies[:]:
                        asyncio.create_task(session.leave_lobby(lobby))
                    self._instances.remove(session)
                    removed_count += 1
                    logger.info(f"Cleaned up session {session.getName()}")
                except Exception as e:
                    logger.warning(
                        f"Error cleaning up session {session.getName()}: {e}"
                    )

        if removed_count > 0:
            self.save()
        return removed_count

    async def start_background_tasks(self):
        """Start background cleanup and validation tasks"""
        logger.info("Starting session background tasks...")
        self.cleanup_task_running = True
        self.validation_task_running = True
        self.cleanup_task = asyncio.create_task(self._periodic_cleanup())
        self.validation_task = asyncio.create_task(self._periodic_validation())
        logger.info("Session background tasks started")

    async def stop_background_tasks(self):
        """Stop background tasks gracefully"""
        logger.info("Shutting down session background tasks...")
        self.cleanup_task_running = False
        self.validation_task_running = False

        # Cancel tasks
        for task in [self.cleanup_task, self.validation_task]:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        # Clean up all sessions gracefully
        await self._cleanup_all_sessions()
        logger.info("Session background tasks stopped")

    async def _periodic_cleanup(self):
        """Background task to periodically clean up old sessions"""
        cleanup_errors = 0
        max_consecutive_errors = 5

        while self.cleanup_task_running:
            try:
                removed_count = self.cleanup_old_sessions()
                if removed_count > 0:
                    logger.info(
                        f"Periodic cleanup removed {removed_count} old sessions"
                    )
                cleanup_errors = 0  # Reset error counter on success
                # Run cleanup at configured interval
                await asyncio.sleep(SessionConfig.CLEANUP_INTERVAL)
            except Exception as e:
                cleanup_errors += 1
                logger.error(
                    f"Error in session cleanup task (attempt {cleanup_errors}): {e}"
                )
                if cleanup_errors >= max_consecutive_errors:
                    logger.error(
                        f"Too many consecutive cleanup errors ({cleanup_errors}), stopping cleanup task"
                    )
                    break
                # Back off on errors: 60s per consecutive error, capped at 5 minutes
                await asyncio.sleep(min(60 * cleanup_errors, 300))

    async def _periodic_validation(self):
        """Background task to periodically validate session integrity"""
        while self.validation_task_running:
            try:
                issues = self.validate_session_integrity()
                if issues:
                    logger.warning(
                        f"Session integrity issues found: {len(issues)} issues"
                    )
                    for issue in issues[:10]:  # Log first 10 issues
                        logger.warning(f"Integrity issue: {issue}")
                await asyncio.sleep(SessionConfig.SESSION_VALIDATION_INTERVAL)
            except Exception as e:
                logger.error(f"Error in session validation task: {e}")
                await asyncio.sleep(300)  # Wait 5 minutes before retrying on error

    def validate_session_integrity(self) -> List[str]:
        """Validate session integrity and return list of issues"""
        issues: List[str] = []
        with self.lock:
            for session in self._instances:
                with session.session_lock:
                    # Check for sessions with invalid state
                    if not session.id:
                        issues.append(f"Session with empty ID: {session}")
                    if session.created_at > time.time():
                        issues.append(
                            f"Session {session.getName()} has future creation time"
                        )
                    if session.last_used > time.time():
                        issues.append(
                            f"Session {session.getName()} has future last_used time"
                        )
                    # Check for duplicate names
                    if session.name:
                        count = sum(
                            1
                            for s in self._instances
                            if s.name and s.name.lower() == session.name.lower()
                        )
                        if count > 1:
                            issues.append(
                                f"Duplicate name '{session.name}' found in {count} sessions"
                            )
        return issues

    async def _cleanup_all_sessions(self):
        """Clean up all sessions during shutdown"""
        with self.lock:
            for session in self._instances[:]:
                try:
                    if session.ws:
                        await session.ws.close()
                except Exception as e:
                    logger.warning(
                        f"Error closing WebSocket for {session.getName()}: {e}"
                    )
        logger.info("All sessions cleaned up")

    def get_all_sessions(self) -> List[Session]:
        """Get all sessions (for admin/debugging purposes)"""
        with self.lock:
            return self._instances.copy()

    def get_session_count(self) -> int:
        """Get total session count"""
        with self.lock:
            return len(self._instances)
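
    # Minimal usage sketch (illustrative only; the manager is normally driven by the
    # server's connection handlers, and the save-file name here is hypothetical):
    #
    #   manager = SessionManager(save_file="sessions_demo.json")
    #   session = manager.create_session()                 # random 32-char hex ID
    #   assert manager.get_or_create_session(session.id) is session
    #   removed = manager.cleanup_old_sessions()           # applies the rules above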