Race conditions
This commit is contained in:
parent
25b14d7928
commit
eb12d95bc5
@ -1,18 +1,22 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Any
|
||||||
import warnings
|
import warnings
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
|
logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper()
|
||||||
|
|
||||||
|
|
||||||
class RelativePathFormatter(logging.Formatter):
|
class RelativePathFormatter(logging.Formatter):
|
||||||
def __init__(self, fmt=None, datefmt=None, remove_prefix=None):
|
def __init__(
|
||||||
|
self, fmt: Any = None, datefmt: Any = None, remove_prefix: str | None = None
|
||||||
|
):
|
||||||
super().__init__(fmt, datefmt)
|
super().__init__(fmt, datefmt)
|
||||||
self.remove_prefix = remove_prefix or os.getcwd()
|
self.remove_prefix = remove_prefix or os.getcwd()
|
||||||
# Ensure the prefix ends with a separator
|
# Ensure the prefix ends with a separator
|
||||||
if not self.remove_prefix.endswith(os.sep):
|
if not self.remove_prefix.endswith(os.sep):
|
||||||
self.remove_prefix += os.sep
|
self.remove_prefix += os.sep
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record: logging.LogRecord):
|
||||||
# Make a copy of the record to avoid modifying the original
|
# Make a copy of the record to avoid modifying the original
|
||||||
record = logging.makeLogRecord(record.__dict__)
|
record = logging.makeLogRecord(record.__dict__)
|
||||||
|
|
||||||
@ -23,9 +27,11 @@ class RelativePathFormatter(logging.Formatter):
|
|||||||
return super().format(record)
|
return super().format(record)
|
||||||
|
|
||||||
|
|
||||||
def _setup_logging(level=logging_level) -> logging.Logger:
|
def _setup_logging(level: Any = logging_level) -> logging.Logger:
|
||||||
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"
|
||||||
warnings.filterwarnings("ignore", message="Overriding a previously registered kernel")
|
warnings.filterwarnings(
|
||||||
|
"ignore", message="Overriding a previously registered kernel"
|
||||||
|
)
|
||||||
warnings.filterwarnings("ignore", message="Warning only once for all operators")
|
warnings.filterwarnings("ignore", message="Warning only once for all operators")
|
||||||
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
|
warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*")
|
||||||
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
|
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to")
|
||||||
@ -38,7 +44,8 @@ def _setup_logging(level=logging_level) -> logging.Logger:
|
|||||||
|
|
||||||
# Create a custom formatter
|
# Create a custom formatter
|
||||||
formatter = RelativePathFormatter(
|
formatter = RelativePathFormatter(
|
||||||
fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a handler (e.g., StreamHandler for console output)
|
# Create a handler (e.g., StreamHandler for console output)
|
||||||
@ -59,6 +66,7 @@ def _setup_logging(level=logging_level) -> logging.Logger:
|
|||||||
"uvicorn.access",
|
"uvicorn.access",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"starlette",
|
"starlette",
|
||||||
|
"httpx",
|
||||||
):
|
):
|
||||||
logger = logging.getLogger(noisy_logger)
|
logger = logging.getLogger(noisy_logger)
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.WARNING)
|
||||||
|
@ -478,17 +478,18 @@ class Lobby:
|
|||||||
async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
|
async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None:
|
||||||
"""Broadcast a chat message to all connected sessions in the lobby"""
|
"""Broadcast a chat message to all connected sessions in the lobby"""
|
||||||
failed_sessions: list[Session] = []
|
failed_sessions: list[Session] = []
|
||||||
for session in self.sessions.values():
|
for peer in self.sessions.values():
|
||||||
if session.ws:
|
if peer.ws:
|
||||||
try:
|
try:
|
||||||
await session.ws.send_json(
|
logger.info(f"{self.getName()} -> chat_message({peer.getName()})")
|
||||||
|
await peer.ws.send_json(
|
||||||
{"type": "chat_message", "data": chat_message.model_dump()}
|
{"type": "chat_message", "data": chat_message.model_dump()}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to send chat message to {session.getName()}: {e}"
|
f"Failed to send chat message to {peer.getName()}: {e}"
|
||||||
)
|
)
|
||||||
failed_sessions.append(session)
|
failed_sessions.append(peer)
|
||||||
|
|
||||||
# Clean up failed sessions
|
# Clean up failed sessions
|
||||||
for failed_session in failed_sessions:
|
for failed_session in failed_sessions:
|
||||||
@ -1260,7 +1261,8 @@ def getSession(session_id: str) -> Session | None:
|
|||||||
def getLobby(lobby_id: str) -> Lobby:
|
def getLobby(lobby_id: str) -> Lobby:
|
||||||
lobby = lobbies.get(lobby_id, None)
|
lobby = lobbies.get(lobby_id, None)
|
||||||
if not lobby:
|
if not lobby:
|
||||||
logger.error(f"Lobby not found: {lobby_id}")
|
# Check if this might be a stale reference after cleanup
|
||||||
|
logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)")
|
||||||
raise Exception(f"Lobby not found: {lobby_id}")
|
raise Exception(f"Lobby not found: {lobby_id}")
|
||||||
return lobby
|
return lobby
|
||||||
|
|
||||||
@ -2162,6 +2164,13 @@ async def lobby_join(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to relay session description: {e}")
|
logger.warning(f"Failed to relay session description: {e}")
|
||||||
|
|
||||||
|
case "status_check":
|
||||||
|
# Simple status check - just respond with success to keep connection alive
|
||||||
|
logger.debug(f"{session.getName()} <- status_check")
|
||||||
|
await websocket.send_json(
|
||||||
|
{"type": "status_ok", "timestamp": time.time()}
|
||||||
|
)
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
await websocket.send_json(
|
await websocket.send_json(
|
||||||
{
|
{
|
||||||
|
@ -162,7 +162,13 @@ class WebRTCSignalingClient:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}"
|
f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}"
|
||||||
)
|
)
|
||||||
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
|
self.websocket = await websockets.connect(
|
||||||
|
ws_url,
|
||||||
|
ssl=ws_ssl,
|
||||||
|
ping_interval=30, # Send ping every 30 seconds
|
||||||
|
ping_timeout=10, # Wait up to 10 seconds for pong response
|
||||||
|
close_timeout=10 # Wait up to 10 seconds for close handshake
|
||||||
|
)
|
||||||
logger.info("Connected to signaling server")
|
logger.info("Connected to signaling server")
|
||||||
|
|
||||||
# Set up local media
|
# Set up local media
|
||||||
@ -203,6 +209,11 @@ class WebRTCSignalingClient:
|
|||||||
async def _periodic_registration_check(self):
|
async def _periodic_registration_check(self):
|
||||||
"""Periodically check registration status and re-register if needed"""
|
"""Periodically check registration status and re-register if needed"""
|
||||||
|
|
||||||
|
# Add a small random delay to prevent all bots from checking at the same time
|
||||||
|
import random
|
||||||
|
initial_delay = random.uniform(0, 5)
|
||||||
|
await asyncio.sleep(initial_delay)
|
||||||
|
|
||||||
while not self.shutdown_requested:
|
while not self.shutdown_requested:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(self.registration_check_interval)
|
await asyncio.sleep(self.registration_check_interval)
|
||||||
@ -246,6 +257,11 @@ class WebRTCSignalingClient:
|
|||||||
await self._send_message("status_check", {"timestamp": time.time()})
|
await self._send_message("status_check", {"timestamp": time.time()})
|
||||||
logger.debug("Registration status check sent")
|
logger.debug("Registration status check sent")
|
||||||
return True
|
return True
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.warning(f"WebSocket connection closed during status check: {e}")
|
||||||
|
# Mark websocket as None to ensure reconnection attempts
|
||||||
|
self.websocket = None
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to send status check: {e}")
|
logger.warning(f"Failed to send status check: {e}")
|
||||||
return False
|
return False
|
||||||
@ -267,16 +283,35 @@ class WebRTCSignalingClient:
|
|||||||
logger.info("WebSocket lost, attempting to reconnect")
|
logger.info("WebSocket lost, attempting to reconnect")
|
||||||
await self._reconnect_websocket()
|
await self._reconnect_websocket()
|
||||||
|
|
||||||
# Re-send name and join messages
|
# Double-check that websocket is connected after reconnect attempt
|
||||||
|
if not self.websocket:
|
||||||
|
logger.error("Failed to establish WebSocket connection for re-registration")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Re-send name and join messages with retry logic
|
||||||
name_payload: MessageData = {"name": self.session_name}
|
name_payload: MessageData = {"name": self.session_name}
|
||||||
if self.name_password:
|
if self.name_password:
|
||||||
name_payload["password"] = self.name_password
|
name_payload["password"] = self.name_password
|
||||||
|
|
||||||
logger.info("Re-sending set_name message")
|
logger.info("Re-sending set_name message")
|
||||||
await self._send_message("set_name", name_payload)
|
try:
|
||||||
|
await self._send_message("set_name", name_payload)
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logger.warning("Connection closed while sending set_name, aborting re-registration")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send set_name during re-registration: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info("Re-sending join message")
|
logger.info("Re-sending join message")
|
||||||
await self._send_message("join", {})
|
try:
|
||||||
|
await self._send_message("join", {})
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logger.warning("Connection closed while sending join, aborting re-registration")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send join during re-registration: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
# Mark as registered after successful re-join
|
# Mark as registered after successful re-join
|
||||||
self.is_registered = True
|
self.is_registered = True
|
||||||
@ -318,7 +353,13 @@ class WebRTCSignalingClient:
|
|||||||
ws_ssl = None
|
ws_ssl = None
|
||||||
|
|
||||||
logger.info(f"Reconnecting to signaling server: {ws_url}")
|
logger.info(f"Reconnecting to signaling server: {ws_url}")
|
||||||
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
|
self.websocket = await websockets.connect(
|
||||||
|
ws_url,
|
||||||
|
ssl=ws_ssl,
|
||||||
|
ping_interval=30, # Send ping every 30 seconds
|
||||||
|
ping_timeout=10, # Wait up to 10 seconds for pong response
|
||||||
|
close_timeout=10 # Wait up to 10 seconds for close handshake
|
||||||
|
)
|
||||||
logger.info("Successfully reconnected to signaling server")
|
logger.info("Successfully reconnected to signaling server")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -424,20 +465,31 @@ class WebRTCSignalingClient:
|
|||||||
if not self.websocket:
|
if not self.websocket:
|
||||||
logger.error("No websocket connection")
|
logger.error("No websocket connection")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ws = cast(WebSocketProtocol, self.websocket)
|
||||||
|
|
||||||
# Build message with explicit type to avoid type narrowing
|
# Build message with explicit type to avoid type narrowing
|
||||||
message: dict[str, object] = {"type": message_type}
|
message: dict[str, object] = {"type": message_type}
|
||||||
if data is not None:
|
if data is not None:
|
||||||
message["data"] = data
|
message["data"] = data
|
||||||
|
|
||||||
ws = cast(WebSocketProtocol, self.websocket)
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"_send_message: Sending {message_type} with data: {data}")
|
logger.debug(f"_send_message: Sending {message_type} with data: {data}")
|
||||||
await ws.send(json.dumps(message))
|
await ws.send(json.dumps(message))
|
||||||
logger.debug(f"_send_message: Sent message: {message_type}")
|
logger.debug(f"_send_message: Sent message: {message_type}")
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.error(
|
||||||
|
f"_send_message: WebSocket connection closed while sending {message_type}: {e}"
|
||||||
|
)
|
||||||
|
# Mark websocket as None to trigger reconnection
|
||||||
|
self.websocket = None
|
||||||
|
self.is_registered = False
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"_send_message: Failed to send {message_type}: {e}", exc_info=True
|
f"_send_message: Failed to send {message_type}: {e}", exc_info=True
|
||||||
)
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
async def _handle_messages(self):
|
async def _handle_messages(self):
|
||||||
"""Handle incoming messages from signaling server"""
|
"""Handle incoming messages from signaling server"""
|
||||||
@ -550,6 +602,10 @@ class WebRTCSignalingClient:
|
|||||||
# Handle status check messages - these are used to verify connection
|
# Handle status check messages - these are used to verify connection
|
||||||
logger.debug(f"Received status check message: {data}")
|
logger.debug(f"Received status check message: {data}")
|
||||||
# No special processing needed for status checks, just acknowledge receipt
|
# No special processing needed for status checks, just acknowledge receipt
|
||||||
|
elif msg_type == "status_ok":
|
||||||
|
# Handle status_ok response from server
|
||||||
|
logger.debug(f"Received status_ok from server: {data}")
|
||||||
|
# This confirms the connection is healthy
|
||||||
elif msg_type == "chat_message":
|
elif msg_type == "chat_message":
|
||||||
try:
|
try:
|
||||||
validated = ChatMessageModel.model_validate(data)
|
validated = ChatMessageModel.model_validate(data)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user