Added session cleanup
This commit is contained in:
parent
a292b14028
commit
2ad9871ea4
158
server/main.py
158
server/main.py
@ -16,6 +16,8 @@ import json
|
||||
import hashlib
|
||||
import binascii
|
||||
import sys
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
import httpx
|
||||
@ -72,7 +74,48 @@ public_url = os.getenv("PUBLIC_URL", "/")
|
||||
if not public_url.endswith("/"):
|
||||
public_url += "/"
|
||||
|
||||
app = FastAPI()
|
||||
# Global variable to control the cleanup task
|
||||
cleanup_task_running = False
|
||||
cleanup_task = None
|
||||
|
||||
|
||||
async def periodic_cleanup():
|
||||
"""Background task to periodically clean up old sessions"""
|
||||
global cleanup_task_running
|
||||
while cleanup_task_running:
|
||||
try:
|
||||
Session.cleanup_old_sessions()
|
||||
# Run cleanup every 5 minutes
|
||||
await asyncio.sleep(300)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in session cleanup task: {e}")
|
||||
await asyncio.sleep(60) # Wait 1 minute before retrying on error
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
global cleanup_task_running, cleanup_task
|
||||
|
||||
# Startup
|
||||
cleanup_task_running = True
|
||||
cleanup_task = asyncio.create_task(periodic_cleanup())
|
||||
logger.info("Session cleanup task started")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
cleanup_task_running = False
|
||||
if cleanup_task:
|
||||
cleanup_task.cancel()
|
||||
try:
|
||||
await cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Session cleanup task stopped")
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
logger.info(f"Starting server with public URL: {public_url}")
|
||||
|
||||
@ -117,6 +160,18 @@ def admin_clear_password(request: Request, payload: AdminClearPassword = Body(..
|
||||
return {"status": "not_found", "name": payload.name}
|
||||
|
||||
|
||||
@app.post(public_url + "api/admin/cleanup_sessions", response_model=AdminActionResponse)
|
||||
def admin_cleanup_sessions(request: Request):
|
||||
if not _require_admin(request):
|
||||
return Response(status_code=403)
|
||||
try:
|
||||
removed_count = Session.cleanup_old_sessions()
|
||||
return {"status": "ok", "name": f"Removed {removed_count} sessions"}
|
||||
except Exception as e:
|
||||
logger.error(f"Error during manual session cleanup: {e}")
|
||||
return {"status": "not_found", "name": f"Error: {str(e)}"}
|
||||
|
||||
|
||||
lobbies: dict[str, Lobby] = {}
|
||||
|
||||
|
||||
@ -230,6 +285,7 @@ class Session:
|
||||
_loaded = False
|
||||
|
||||
def __init__(self, id: str):
|
||||
import time
|
||||
logger.info(f"Instantiating new session {id}")
|
||||
self._instances.append(self)
|
||||
self.id = id
|
||||
@ -240,6 +296,9 @@ class Session:
|
||||
str, list[str]
|
||||
] = {} # lobby ID -> list of peer session IDs
|
||||
self.ws: WebSocket | None = None
|
||||
self.created_at = time.time()
|
||||
self.last_used = time.time()
|
||||
self.displaced_at: float | None = None # When name was taken over
|
||||
self.save()
|
||||
|
||||
@classmethod
|
||||
@ -251,7 +310,14 @@ class Session:
|
||||
for lobby in s.lobbies
|
||||
]
|
||||
sessions_list.append(
|
||||
SessionSaved(id=s.id, name=s.name or "", lobbies=lobbies_list)
|
||||
SessionSaved(
|
||||
id=s.id,
|
||||
name=s.name or "",
|
||||
lobbies=lobbies_list,
|
||||
created_at=s.created_at,
|
||||
last_used=s.last_used,
|
||||
displaced_at=s.displaced_at,
|
||||
)
|
||||
)
|
||||
# Prepare name password store for persistence (salt+hash). Only structured records are supported.
|
||||
saved_pw: dict[str, NamePasswordRecord] = {
|
||||
@ -271,6 +337,7 @@ class Session:
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
import time
|
||||
if not os.path.exists(cls._save_file):
|
||||
logger.info(f"No session save file found: {cls._save_file}")
|
||||
return
|
||||
@ -292,6 +359,10 @@ class Session:
|
||||
for s_saved in payload.sessions:
|
||||
session = Session(s_saved.id)
|
||||
session.name = s_saved.name or ""
|
||||
# Load timestamps, with defaults for backward compatibility
|
||||
session.created_at = getattr(s_saved, "created_at", time.time())
|
||||
session.last_used = getattr(s_saved, "last_used", time.time())
|
||||
session.displaced_at = getattr(s_saved, "displaced_at", None)
|
||||
for lobby_saved in s_saved.lobbies:
|
||||
session.lobbies.append(
|
||||
Lobby(
|
||||
@ -347,8 +418,87 @@ class Session:
|
||||
|
||||
def setName(self, name: str):
|
||||
self.name = name
|
||||
self.update_last_used()
|
||||
self.save()
|
||||
|
||||
def update_last_used(self):
|
||||
"""Update the last_used timestamp"""
|
||||
import time
|
||||
|
||||
self.last_used = time.time()
|
||||
|
||||
def mark_displaced(self):
|
||||
"""Mark this session as having its name taken over"""
|
||||
import time
|
||||
|
||||
self.displaced_at = time.time()
|
||||
|
||||
@classmethod
|
||||
def cleanup_old_sessions(cls) -> int:
|
||||
"""Clean up old sessions based on the specified criteria"""
|
||||
import time
|
||||
|
||||
current_time = time.time()
|
||||
one_minute = 60.0
|
||||
twenty_four_hours = 24 * 60 * 60.0
|
||||
sessions_removed = 0
|
||||
|
||||
# Make a copy of the list to avoid modifying it while iterating
|
||||
sessions_to_remove: list[Session] = []
|
||||
|
||||
for session in cls._instances[:]:
|
||||
# Rule 1: Delete sessions with no active connection and no name that are older than 1 minute
|
||||
if (
|
||||
not session.ws
|
||||
and not session.name
|
||||
and current_time - session.created_at > one_minute
|
||||
):
|
||||
logger.info(
|
||||
f"Removing session {session.getName()} - no connection, no name, older than 1 minute"
|
||||
)
|
||||
sessions_to_remove.append(session)
|
||||
continue
|
||||
|
||||
# Rule 2: Delete inactive sessions that had their nick taken over and haven't been used in 24 hours
|
||||
if (
|
||||
not session.ws
|
||||
and session.displaced_at is not None
|
||||
and current_time - session.last_used > twenty_four_hours
|
||||
):
|
||||
logger.info(
|
||||
f"Removing session {session.getName()} - displaced and unused for 24+ hours"
|
||||
)
|
||||
sessions_to_remove.append(session)
|
||||
continue
|
||||
|
||||
# Remove the sessions
|
||||
for session in sessions_to_remove:
|
||||
# Remove from lobbies first
|
||||
for lobby in session.lobbies[
|
||||
:
|
||||
]: # Copy list to avoid modification during iteration
|
||||
try:
|
||||
# Use async cleanup if needed, but for cleanup we'll just remove from data structures
|
||||
if session.id in lobby.sessions:
|
||||
del lobby.sessions[session.id]
|
||||
if lobby.id in session.lobby_peers:
|
||||
del session.lobby_peers[lobby.id]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error removing session {session.getName()} from lobby {lobby.getName()}: {e}"
|
||||
)
|
||||
|
||||
# Remove from instances list
|
||||
if session in cls._instances:
|
||||
cls._instances.remove(session)
|
||||
sessions_removed += 1
|
||||
|
||||
if sessions_removed > 0:
|
||||
cls.save()
|
||||
logger.info(f"Session cleanup: removed {sessions_removed} old sessions")
|
||||
|
||||
return sessions_removed
|
||||
|
||||
async def join(self, lobby: Lobby):
|
||||
if not self.ws:
|
||||
logger.error(
|
||||
@ -543,6 +693,7 @@ async def session(
|
||||
session = Session(session_id)
|
||||
logger.info(f"{session.getName()}: New session created.")
|
||||
else:
|
||||
session.update_last_used() # Update activity on session resumption
|
||||
logger.info(f"{session.getName()}: Existing session resumed.")
|
||||
# Part all lobbies for this session that have no active websocket
|
||||
for lobby_id in list(session.lobby_peers.keys()):
|
||||
@ -673,6 +824,7 @@ async def lobby_join(
|
||||
logger.info(f"{session.getName()} <- lobby_joined({lobby.getName()})")
|
||||
|
||||
session.ws = websocket
|
||||
session.update_last_used() # Update activity timestamp
|
||||
if session.id in lobby.sessions:
|
||||
logger.info(
|
||||
f"{session.getName()} - Stale session in lobby {lobby.getName()}. Re-joining."
|
||||
@ -702,6 +854,7 @@ async def lobby_join(
|
||||
try:
|
||||
while True:
|
||||
packet = await websocket.receive_json()
|
||||
session.update_last_used() # Update activity on each message
|
||||
type = packet.get("type", None)
|
||||
data: dict[str, Any] | None = packet.get("data", None)
|
||||
if not type:
|
||||
@ -805,6 +958,7 @@ async def lobby_join(
|
||||
fallback = f"{displaced.name}-{secrets.token_hex(3)}"
|
||||
|
||||
displaced.setName(fallback)
|
||||
displaced.mark_displaced()
|
||||
logger.info(
|
||||
f"{displaced.getName()} <- displaced by takeover, new name {fallback}"
|
||||
)
|
||||
|
@ -259,6 +259,9 @@ class SessionSaved(BaseModel):
|
||||
id: str
|
||||
name: str = ""
|
||||
lobbies: List[LobbySaved] = []
|
||||
created_at: float = 0.0
|
||||
last_used: float = 0.0
|
||||
displaced_at: Optional[float] = None # When name was taken over
|
||||
|
||||
|
||||
class SessionsPayload(BaseModel):
|
||||
|
Loading…
x
Reference in New Issue
Block a user