Added session cleanup

This commit is contained in:
James Ketr 2025-09-01 20:09:49 -07:00
parent a292b14028
commit 2ad9871ea4
2 changed files with 159 additions and 2 deletions

View File

@ -16,6 +16,8 @@ import json
import hashlib import hashlib
import binascii import binascii
import sys import sys
import asyncio
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
import httpx import httpx
@ -72,7 +74,48 @@ public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.endswith("/"): if not public_url.endswith("/"):
public_url += "/" public_url += "/"
app = FastAPI() # Global variable to control the cleanup task
cleanup_task_running = False
cleanup_task = None
async def periodic_cleanup():
"""Background task to periodically clean up old sessions"""
global cleanup_task_running
while cleanup_task_running:
try:
Session.cleanup_old_sessions()
# Run cleanup every 5 minutes
await asyncio.sleep(300)
except Exception as e:
logger.error(f"Error in session cleanup task: {e}")
await asyncio.sleep(60) # Wait 1 minute before retrying on error
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
global cleanup_task_running, cleanup_task
# Startup
cleanup_task_running = True
cleanup_task = asyncio.create_task(periodic_cleanup())
logger.info("Session cleanup task started")
yield
# Shutdown
cleanup_task_running = False
if cleanup_task:
cleanup_task.cancel()
try:
await cleanup_task
except asyncio.CancelledError:
pass
logger.info("Session cleanup task stopped")
app = FastAPI(lifespan=lifespan)
logger.info(f"Starting server with public URL: {public_url}") logger.info(f"Starting server with public URL: {public_url}")
@ -117,6 +160,18 @@ def admin_clear_password(request: Request, payload: AdminClearPassword = Body(..
return {"status": "not_found", "name": payload.name} return {"status": "not_found", "name": payload.name}
@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
if not _require_admin(request):
return Response(status_code=403)
try:
removed_count = Session.cleanup_old_sessions()
return {"status": "ok", "name": f"Removed {removed_count} sessions"}
except Exception as e:
logger.error(f"Error during manual session cleanup: {e}")
return {"status": "not_found", "name": f"Error: {str(e)}"}
lobbies: dict[str, Lobby] = {} lobbies: dict[str, Lobby] = {}
@ -230,6 +285,7 @@ class Session:
_loaded = False _loaded = False
def __init__(self, id: str): def __init__(self, id: str):
import time
logger.info(f"Instantiating new session {id}") logger.info(f"Instantiating new session {id}")
self._instances.append(self) self._instances.append(self)
self.id = id self.id = id
@ -240,6 +296,9 @@ class Session:
str, list[str] str, list[str]
] = {} # lobby ID -> list of peer session IDs ] = {} # lobby ID -> list of peer session IDs
self.ws: WebSocket | None = None self.ws: WebSocket | None = None
self.created_at = time.time()
self.last_used = time.time()
self.displaced_at: float | None = None # When name was taken over
self.save() self.save()
@classmethod @classmethod
@ -251,7 +310,14 @@ class Session:
for lobby in s.lobbies for lobby in s.lobbies
] ]
sessions_list.append( sessions_list.append(
SessionSaved(id=s.id, name=s.name or "", lobbies=lobbies_list) SessionSaved(
id=s.id,
name=s.name or "",
lobbies=lobbies_list,
created_at=s.created_at,
last_used=s.last_used,
displaced_at=s.displaced_at,
)
) )
# Prepare name password store for persistence (salt+hash). Only structured records are supported. # Prepare name password store for persistence (salt+hash). Only structured records are supported.
saved_pw: dict[str, NamePasswordRecord] = { saved_pw: dict[str, NamePasswordRecord] = {
@ -271,6 +337,7 @@ class Session:
@classmethod @classmethod
def load(cls): def load(cls):
import time
if not os.path.exists(cls._save_file): if not os.path.exists(cls._save_file):
logger.info(f"No session save file found: {cls._save_file}") logger.info(f"No session save file found: {cls._save_file}")
return return
@ -292,6 +359,10 @@ class Session:
for s_saved in payload.sessions: for s_saved in payload.sessions:
session = Session(s_saved.id) session = Session(s_saved.id)
session.name = s_saved.name or "" session.name = s_saved.name or ""
# Load timestamps, with defaults for backward compatibility
session.created_at = getattr(s_saved, "created_at", time.time())
session.last_used = getattr(s_saved, "last_used", time.time())
session.displaced_at = getattr(s_saved, "displaced_at", None)
for lobby_saved in s_saved.lobbies: for lobby_saved in s_saved.lobbies:
session.lobbies.append( session.lobbies.append(
Lobby( Lobby(
@ -347,8 +418,87 @@ class Session:
def setName(self, name: str): def setName(self, name: str):
self.name = name self.name = name
self.update_last_used()
self.save() self.save()
def update_last_used(self):
"""Update the last_used timestamp"""
import time
self.last_used = time.time()
def mark_displaced(self):
"""Mark this session as having its name taken over"""
import time
self.displaced_at = time.time()
@classmethod
def cleanup_old_sessions(cls) -> int:
"""Clean up old sessions based on the specified criteria"""
import time
current_time = time.time()
one_minute = 60.0
twenty_four_hours = 24 * 60 * 60.0
sessions_removed = 0
# Make a copy of the list to avoid modifying it while iterating
sessions_to_remove: list[Session] = []
for session in cls._instances[:]:
# Rule 1: Delete sessions with no active connection and no name that are older than 1 minute
if (
not session.ws
and not session.name
and current_time - session.created_at > one_minute
):
logger.info(
f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
)
sessions_to_remove.append(session)
continue
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 24 hours
if (
not session.ws
and session.displaced_at is not None
and current_time - session.last_used > twenty_four_hours
):
logger.info(
f"Removing session {session.getName()} - displaced and unused for 24+ hours"
)
sessions_to_remove.append(session)
continue
# Remove the sessions
for session in sessions_to_remove:
# Remove from lobbies first
for lobby in session.lobbies[
:
]: # Copy list to avoid modification during iteration
try:
# Use async cleanup if needed, but for cleanup we'll just remove from data structures
if session.id in lobby.sessions:
del lobby.sessions[session.id]
if lobby.id in session.lobby_peers:
del session.lobby_peers[lobby.id]
except Exception as e:
logger.warning(
f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
)
# Remove from instances list
if session in cls._instances:
cls._instances.remove(session)
sessions_removed += 1
if sessions_removed > 0:
cls.save()
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
return sessions_removed
async def join(self, lobby: Lobby): async def join(self, lobby: Lobby):
if not self.ws: if not self.ws:
logger.error( logger.error(
@ -543,6 +693,7 @@ async def session(
session = Session(session_id) session = Session(session_id)
logger.info(f"{session.getName()}: New session created.") logger.info(f"{session.getName()}: New session created.")
else: else:
session.update_last_used() # Update activity on session resumption
logger.info(f"{session.getName()}: Existing session resumed.") logger.info(f"{session.getName()}: Existing session resumed.")
# Part all lobbies for this session that have no active websocket # Part all lobbies for this session that have no active websocket
for lobby_id in list(session.lobby_peers.keys()): for lobby_id in list(session.lobby_peers.keys()):
@ -673,6 +824,7 @@ async def lobby_join(
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})") logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
session.ws = websocket session.ws = websocket
session.update_last_used() # Update activity timestamp
if session.id in lobby.sessions: if session.id in lobby.sessions:
logger.info( logger.info(
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining." f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
@ -702,6 +854,7 @@ async def lobby_join(
try: try:
while True: while True:
packet = await websocket.receive_json() packet = await websocket.receive_json()
session.update_last_used() # Update activity on each message
type = packet.get("type", None) type = packet.get("type", None)
data: dict[str, Any] | None = packet.get("data", None) data: dict[str, Any] | None = packet.get("data", None)
if not type: if not type:
@ -805,6 +958,7 @@ async def lobby_join(
fallback = f"{displaced.name}-{secrets.token_hex(3)}" fallback = f"{displaced.name}-{secrets.token_hex(3)}"
displaced.setName(fallback) displaced.setName(fallback)
displaced.mark_displaced()
logger.info( logger.info(
f"{displaced.getName()} <- displaced by takeover, new name {fallback}" f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
) )

View File

@ -259,6 +259,9 @@ class SessionSaved(BaseModel):
id: str id: str
name: str = "" name: str = ""
lobbies: List[LobbySaved] = [] lobbies: List[LobbySaved] = []
created_at: float = 0.0
last_used: float = 0.0
displaced_at: Optional[float] = None # When name was taken over
class SessionsPayload(BaseModel): class SessionsPayload(BaseModel):