Race conditions

This commit is contained in:
James Ketr 2025-09-03 17:33:52 -07:00
parent 25b14d7928
commit eb12d95bc5
3 changed files with 90 additions and 17 deletions

View File

@ -1,18 +1,22 @@
import os
from typing import Any
import warnings
import logging
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
class RelativePathFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None, remove_prefix=None):
def __init__(
self, fmt: Any = None, datefmt: Any = None, remove_prefix: str | None = None
):
super().__init__(fmt, datefmt)
self.remove_prefix = remove_prefix or os.getcwd()
# Ensure the prefix ends with a separator
if not self.remove_prefix.endswith(os.sep):
self.remove_prefix += os.sep
def format(self, record):
def format(self, record: logging.LogRecord):
# Make a copy of the record to avoid modifying the original
record = logging.makeLogRecord(record.__dict__)
@ -23,9 +27,11 @@ class RelativePathFormatter(logging.Formatter):
return super().format(record)
def _setup_logging(level=logging_level) -> logging.Logger:
def _setup_logging(level: Any = logging_level) -> logging.Logger:
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
warnings.filterwarnings(
"ignore", message="Overriding a previously registered kernel"
)
warnings.filterwarnings("ignore", message="Warning only once for all operators")
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
@ -38,7 +44,8 @@ def _setup_logging(level=logging_level) -> logging.Logger:
# Create a custom formatter
formatter = RelativePathFormatter(
fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Create a handler (e.g., StreamHandler for console output)
@ -59,6 +66,7 @@ def _setup_logging(level=logging_level) -> logging.Logger:
"uvicorn.access",
"fastapi",
"starlette",
"httpx",
):
logger = logging.getLogger(noisy_logger)
logger.setLevel(logging.WARNING)

View File

@ -478,17 +478,18 @@ class Lobby:
async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
"""Broadcast a chat message to all connected sessions in the lobby"""
failed_sessions: list[Session] = []
for session in self.sessions.values():
if session.ws:
for peer in self.sessions.values():
if peer.ws:
try:
await session.ws.send_json(
logger.info(f"{self.getName()} -> chat_message({peer.getName()})")
await peer.ws.send_json(
{"type": "chat_message", "data": chat_message.model_dump()}
)
except Exception as e:
logger.warning(
f"Failed to send chat message to {session.getName()}: {e}"
f"Failed to send chat message to {peer.getName()}: {e}"
)
failed_sessions.append(session)
failed_sessions.append(peer)
# Clean up failed sessions
for failed_session in failed_sessions:
@ -1260,7 +1261,8 @@ def getSession(session_id: str) -> Session | None:
def getLobby(lobby_id: str) -> Lobby:
lobby = lobbies.get(lobby_id, None)
if not lobby:
logger.error(f"Lobby not found: {lobby_id}")
# A missing lobby may be a stale reference left over after cleanup, so log it and raise
logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)")
raise Exception(f"Lobby not found: {lobby_id}")
return lobby
@ -2162,6 +2164,13 @@ async def lobby_join(
except Exception as e:
logger.warning(f"Failed to relay session description: {e}")
case "status_check":
# Simple status check - just respond with success to keep connection alive
logger.debug(f"{session.getName()} <- status_check")
await websocket.send_json(
{"type": "status_ok", "timestamp": time.time()}
)
case _:
await websocket.send_json(
{

View File

@ -162,7 +162,13 @@ class WebRTCSignalingClient:
logger.info(
f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}"
)
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
self.websocket = await websockets.connect(
ws_url,
ssl=ws_ssl,
ping_interval=30, # Send ping every 30 seconds
ping_timeout=10, # Wait up to 10 seconds for pong response
close_timeout=10 # Wait up to 10 seconds for close handshake
)
logger.info("Connected to signaling server")
# Set up local media
@ -203,6 +209,11 @@ class WebRTCSignalingClient:
async def _periodic_registration_check(self):
"""Periodically check registration status and re-register if needed"""
# Add a small random delay to prevent all bots from checking at the same time
import random
initial_delay = random.uniform(0, 5)
await asyncio.sleep(initial_delay)
while not self.shutdown_requested:
try:
await asyncio.sleep(self.registration_check_interval)
@ -246,6 +257,11 @@ class WebRTCSignalingClient:
await self._send_message("status_check", {"timestamp": time.time()})
logger.debug("Registration status check sent")
return True
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"WebSocket connection closed during status check: {e}")
# Mark websocket as None to ensure reconnection attempts
self.websocket = None
return False
except Exception as e:
logger.warning(f"Failed to send status check: {e}")
return False
@ -266,17 +282,36 @@ class WebRTCSignalingClient:
if not self.websocket:
logger.info("WebSocket lost, attempting to reconnect")
await self._reconnect_websocket()
# Double-check that websocket is connected after reconnect attempt
if not self.websocket:
logger.error("Failed to establish WebSocket connection for re-registration")
return
# Re-send name and join messages
# Re-send name and join messages; abort re-registration if either send fails
name_payload: MessageData = {"name": self.session_name}
if self.name_password:
name_payload["password"] = self.name_password
logger.info("Re-sending set_name message")
await self._send_message("set_name", name_payload)
try:
await self._send_message("set_name", name_payload)
except websockets.exceptions.ConnectionClosed:
logger.warning("Connection closed while sending set_name, aborting re-registration")
return
except Exception as e:
logger.error(f"Failed to send set_name during re-registration: {e}")
return
logger.info("Re-sending join message")
await self._send_message("join", {})
try:
await self._send_message("join", {})
except websockets.exceptions.ConnectionClosed:
logger.warning("Connection closed while sending join, aborting re-registration")
return
except Exception as e:
logger.error(f"Failed to send join during re-registration: {e}")
return
# Mark as registered after successful re-join
self.is_registered = True
@ -318,7 +353,13 @@ class WebRTCSignalingClient:
ws_ssl = None
logger.info(f"Reconnecting to signaling server: {ws_url}")
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
self.websocket = await websockets.connect(
ws_url,
ssl=ws_ssl,
ping_interval=30, # Send ping every 30 seconds
ping_timeout=10, # Wait up to 10 seconds for pong response
close_timeout=10 # Wait up to 10 seconds for close handshake
)
logger.info("Successfully reconnected to signaling server")
except Exception as e:
@ -424,20 +465,31 @@ class WebRTCSignalingClient:
if not self.websocket:
logger.error("No websocket connection")
return
ws = cast(WebSocketProtocol, self.websocket)
# Build message with explicit type to avoid type narrowing
message: dict[str, object] = {"type": message_type}
if data is not None:
message["data"] = data
ws = cast(WebSocketProtocol, self.websocket)
try:
logger.debug(f"_send_message: Sending {message_type} with data: {data}")
await ws.send(json.dumps(message))
logger.debug(f"_send_message: Sent message: {message_type}")
except websockets.exceptions.ConnectionClosed as e:
logger.error(
f"_send_message: WebSocket connection closed while sending {message_type}: {e}"
)
# Mark websocket as None to trigger reconnection
self.websocket = None
self.is_registered = False
raise
except Exception as e:
logger.error(
f"_send_message: Failed to send {message_type}: {e}", exc_info=True
)
raise
async def _handle_messages(self):
"""Handle incoming messages from signaling server"""
@ -550,6 +602,10 @@ class WebRTCSignalingClient:
# Handle status check messages - these are used to verify connection
logger.debug(f"Received status check message: {data}")
# No special processing needed for status checks, just acknowledge receipt
elif msg_type == "status_ok":
# Handle status_ok response from server
logger.debug(f"Received status_ok from server: {data}")
# This confirms the connection is healthy
elif msg_type == "chat_message":
try:
validated = ChatMessageModel.model_validate(data)