Added session cleanup

This commit is contained in:
James Ketr 2025-09-01 20:09:49 -07:00
parent a292b14028
commit 2ad9871ea4
2 changed files with 159 additions and 2 deletions

View File

@ -16,6 +16,8 @@ import json
import hashlib
import binascii
import sys
import asyncio
from contextlib import asynccontextmanager
from fastapi.staticfiles import StaticFiles
import httpx
@ -72,7 +74,48 @@ public_url = os.getenv("PUBLIC_URL", "/")
if not public_url.endswith("/"):
public_url += "/"
app = FastAPI()
# Global variable to control the cleanup task
cleanup_task_running = False
cleanup_task = None
async def periodic_cleanup():
"""Background task to periodically clean up old sessions"""
global cleanup_task_running
while cleanup_task_running:
try:
Session.cleanup_old_sessions()
# Run cleanup every 5 minutes
await asyncio.sleep(300)
except Exception as e:
logger.error(f"Error in session cleanup task: {e}")
await asyncio.sleep(60) # Wait 1 minute before retrying on error
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events"""
global cleanup_task_running, cleanup_task
# Startup
cleanup_task_running = True
cleanup_task = asyncio.create_task(periodic_cleanup())
logger.info("Session cleanup task started")
yield
# Shutdown
cleanup_task_running = False
if cleanup_task:
cleanup_task.cancel()
try:
await cleanup_task
except asyncio.CancelledError:
pass
logger.info("Session cleanup task stopped")
app = FastAPI(lifespan=lifespan)
logger.info(f"Starting server with public URL: {public_url}")
@ -117,6 +160,18 @@ def admin_clear_password(request: Request, payload: AdminClearPassword = Body(..
return {"status": "not_found", "name": payload.name}
@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
def admin_cleanup_sessions(request: Request):
if not _require_admin(request):
return Response(status_code=403)
try:
removed_count = Session.cleanup_old_sessions()
return {"status": "ok", "name": f"Removed {removed_count} sessions"}
except Exception as e:
logger.error(f"Error during manual session cleanup: {e}")
return {"status": "not_found", "name": f"Error: {str(e)}"}
lobbies: dict[str, Lobby] = {}
@ -230,6 +285,7 @@ class Session:
_loaded = False
def __init__(self, id: str):
import time
logger.info(f"Instantiating new session {id}")
self._instances.append(self)
self.id = id
@ -240,6 +296,9 @@ class Session:
str, list[str]
] = {} # lobby ID -> list of peer session IDs
self.ws: WebSocket | None = None
self.created_at = time.time()
self.last_used = time.time()
self.displaced_at: float | None = None # When name was taken over
self.save()
@classmethod
@ -251,7 +310,14 @@ class Session:
for lobby in s.lobbies
]
sessions_list.append(
SessionSaved(id=s.id, name=s.name or "", lobbies=lobbies_list)
SessionSaved(
id=s.id,
name=s.name or "",
lobbies=lobbies_list,
created_at=s.created_at,
last_used=s.last_used,
displaced_at=s.displaced_at,
)
)
# Prepare name password store for persistence (salt+hash). Only structured records are supported.
saved_pw: dict[str, NamePasswordRecord] = {
@ -271,6 +337,7 @@ class Session:
@classmethod
def load(cls):
import time
if not os.path.exists(cls._save_file):
logger.info(f"No session save file found: {cls._save_file}")
return
@ -292,6 +359,10 @@ class Session:
for s_saved in payload.sessions:
session = Session(s_saved.id)
session.name = s_saved.name or ""
# Load timestamps, with defaults for backward compatibility
session.created_at = getattr(s_saved, "created_at", time.time())
session.last_used = getattr(s_saved, "last_used", time.time())
session.displaced_at = getattr(s_saved, "displaced_at", None)
for lobby_saved in s_saved.lobbies:
session.lobbies.append(
Lobby(
@ -347,8 +418,87 @@ class Session:
def setName(self, name: str):
self.name = name
self.update_last_used()
self.save()
def update_last_used(self):
"""Update the last_used timestamp"""
import time
self.last_used = time.time()
def mark_displaced(self):
"""Mark this session as having its name taken over"""
import time
self.displaced_at = time.time()
@classmethod
def cleanup_old_sessions(cls) -> int:
"""Clean up old sessions based on the specified criteria"""
import time
current_time = time.time()
one_minute = 60.0
twenty_four_hours = 24 * 60 * 60.0
sessions_removed = 0
# Make a copy of the list to avoid modifying it while iterating
sessions_to_remove: list[Session] = []
for session in cls._instances[:]:
# Rule 1: Delete sessions with no active connection and no name that are older than 1 minute
if (
not session.ws
and not session.name
and current_time - session.created_at > one_minute
):
logger.info(
f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
)
sessions_to_remove.append(session)
continue
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 24 hours
if (
not session.ws
and session.displaced_at is not None
and current_time - session.last_used > twenty_four_hours
):
logger.info(
f"Removing session {session.getName()} - displaced and unused for 24+ hours"
)
sessions_to_remove.append(session)
continue
# Remove the sessions
for session in sessions_to_remove:
# Remove from lobbies first
for lobby in session.lobbies[
:
]: # Copy list to avoid modification during iteration
try:
# Use async cleanup if needed, but for cleanup we'll just remove from data structures
if session.id in lobby.sessions:
del lobby.sessions[session.id]
if lobby.id in session.lobby_peers:
del session.lobby_peers[lobby.id]
except Exception as e:
logger.warning(
f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
)
# Remove from instances list
if session in cls._instances:
cls._instances.remove(session)
sessions_removed += 1
if sessions_removed > 0:
cls.save()
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
return sessions_removed
async def join(self, lobby: Lobby):
if not self.ws:
logger.error(
@ -543,6 +693,7 @@ async def session(
session = Session(session_id)
logger.info(f"{session.getName()}: New session created.")
else:
session.update_last_used() # Update activity on session resumption
logger.info(f"{session.getName()}: Existing session resumed.")
# Part all lobbies for this session that have no active websocket
for lobby_id in list(session.lobby_peers.keys()):
@ -673,6 +824,7 @@ async def lobby_join(
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
session.ws = websocket
session.update_last_used() # Update activity timestamp
if session.id in lobby.sessions:
logger.info(
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
@ -702,6 +854,7 @@ async def lobby_join(
try:
while True:
packet = await websocket.receive_json()
session.update_last_used() # Update activity on each message
type = packet.get("type", None)
data: dict[str, Any] | None = packet.get("data", None)
if not type:
@ -805,6 +958,7 @@ async def lobby_join(
fallback = f"{displaced.name}-{secrets.token_hex(3)}"
displaced.setName(fallback)
displaced.mark_displaced()
logger.info(
f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
)

View File

@ -259,6 +259,9 @@ class SessionSaved(BaseModel):
id: str
name: str = ""
lobbies: List[LobbySaved] = []
created_at: float = 0.0
last_used: float = 0.0
displaced_at: Optional[float] = None # When name was taken over
class SessionsPayload(BaseModel):