diff --git a/server/logger.py b/server/logger.py index 506fe89..c36454c 100644 --- a/server/logger.py +++ b/server/logger.py @@ -1,18 +1,22 @@ import os +from typing import Any import warnings import logging logging_level = os.getenv("LOGGING_LEVEL", "INFO").upper() + class RelativePathFormatter(logging.Formatter): - def __init__(self, fmt=None, datefmt=None, remove_prefix=None): + def __init__( + self, fmt: Any = None, datefmt: Any = None, remove_prefix: str | None = None + ): super().__init__(fmt, datefmt) self.remove_prefix = remove_prefix or os.getcwd() # Ensure the prefix ends with a separator if not self.remove_prefix.endswith(os.sep): self.remove_prefix += os.sep - def format(self, record): + def format(self, record: logging.LogRecord): # Make a copy of the record to avoid modifying the original record = logging.makeLogRecord(record.__dict__) @@ -23,9 +27,11 @@ class RelativePathFormatter(logging.Formatter): return super().format(record) -def _setup_logging(level=logging_level) -> logging.Logger: +def _setup_logging(level: Any = logging_level) -> logging.Logger: os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" - warnings.filterwarnings("ignore", message="Overriding a previously registered kernel") + warnings.filterwarnings( + "ignore", message="Overriding a previously registered kernel" + ) warnings.filterwarnings("ignore", message="Warning only once for all operators") warnings.filterwarnings("ignore", message=".*Couldn't find ffmpeg or avconv.*") warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to") @@ -38,7 +44,8 @@ def _setup_logging(level=logging_level) -> logging.Logger: # Create a custom formatter formatter = RelativePathFormatter( - fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + fmt="%(levelname)s - %(pathname)s:%(lineno)d - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) # Create a handler (e.g., StreamHandler for console output) @@ -59,6 +66,7 @@ def _setup_logging(level=logging_level) -> logging.Logger: "uvicorn.access", "fastapi", "starlette", + "httpx", ): logger = logging.getLogger(noisy_logger) logger.setLevel(logging.WARNING) diff --git a/server/main.py b/server/main.py index d4d7579..af27d5d 100644 --- a/server/main.py +++ b/server/main.py @@ -478,17 +478,18 @@ class Lobby: async def broadcast_chat_message(self, chat_message: ChatMessageModel) -> None: """Broadcast a chat message to all connected sessions in the lobby""" failed_sessions: list[Session] = [] - for session in self.sessions.values(): - if session.ws: + for peer in self.sessions.values(): + if peer.ws: try: - await session.ws.send_json( + logger.info(f"{self.getName()} -> chat_message({peer.getName()})") + await peer.ws.send_json( {"type": "chat_message", "data": chat_message.model_dump()} ) except Exception as e: logger.warning( - f"Failed to send chat message to {session.getName()}: {e}" + f"Failed to send chat message to {peer.getName()}: {e}" ) - failed_sessions.append(session) + failed_sessions.append(peer) # Clean up failed sessions for failed_session in failed_sessions: @@ -1260,7 +1261,8 @@ def getSession(session_id: str) -> Session | None: def getLobby(lobby_id: str) -> Lobby: lobby = lobbies.get(lobby_id, None) if not lobby: - logger.error(f"Lobby not found: {lobby_id}") + # Check if this might be a stale reference after cleanup + logger.warning(f"Lobby not found: {lobby_id} (may have been cleaned up)") raise Exception(f"Lobby not found: {lobby_id}") return lobby @@ -2162,6 +2164,13 @@ async def lobby_join( except Exception as e: logger.warning(f"Failed to relay session description: {e}") + case "status_check": + # Simple status check - just respond with success to keep connection alive + logger.debug(f"{session.getName()} <- status_check") + await websocket.send_json( + {"type": "status_ok", "timestamp": time.time()} + ) + case _: await websocket.send_json( { diff --git a/voicebot/webrtc_signaling.py b/voicebot/webrtc_signaling.py index 218e1f2..f469d5b 100644 --- a/voicebot/webrtc_signaling.py +++ b/voicebot/webrtc_signaling.py @@ -162,7 +162,13 @@ class WebRTCSignalingClient: logger.info( f"Attempting websocket connection to {ws_url} with ssl={ws_ssl}" ) - self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) + self.websocket = await websockets.connect( + ws_url, + ssl=ws_ssl, + ping_interval=30, # Send ping every 30 seconds + ping_timeout=10, # Wait up to 10 seconds for pong response + close_timeout=10 # Wait up to 10 seconds for close handshake + ) logger.info("Connected to signaling server") # Set up local media @@ -203,6 +209,11 @@ class WebRTCSignalingClient: async def _periodic_registration_check(self): """Periodically check registration status and re-register if needed""" + # Add a small random delay to prevent all bots from checking at the same time + import random + initial_delay = random.uniform(0, 5) + await asyncio.sleep(initial_delay) + while not self.shutdown_requested: try: await asyncio.sleep(self.registration_check_interval) @@ -246,6 +257,11 @@ class WebRTCSignalingClient: await self._send_message("status_check", {"timestamp": time.time()}) logger.debug("Registration status check sent") return True + except websockets.exceptions.ConnectionClosed as e: + logger.warning(f"WebSocket connection closed during status check: {e}") + # Mark websocket as None to ensure reconnection attempts + self.websocket = None + return False except Exception as e: logger.warning(f"Failed to send status check: {e}") return False @@ -266,17 +282,36 @@ class WebRTCSignalingClient: if not self.websocket: logger.info("WebSocket lost, attempting to reconnect") await self._reconnect_websocket() + + # Double-check that websocket is connected after reconnect attempt + if not self.websocket: + logger.error("Failed to establish WebSocket connection for re-registration") + return - # Re-send name and join messages + # Re-send name and join messages with retry logic name_payload: MessageData = {"name": self.session_name} if self.name_password: name_payload["password"] = self.name_password logger.info("Re-sending set_name message") - await self._send_message("set_name", name_payload) + try: + await self._send_message("set_name", name_payload) + except websockets.exceptions.ConnectionClosed: + logger.warning("Connection closed while sending set_name, aborting re-registration") + return + except Exception as e: + logger.error(f"Failed to send set_name during re-registration: {e}") + return logger.info("Re-sending join message") - await self._send_message("join", {}) + try: + await self._send_message("join", {}) + except websockets.exceptions.ConnectionClosed: + logger.warning("Connection closed while sending join, aborting re-registration") + return + except Exception as e: + logger.error(f"Failed to send join during re-registration: {e}") + return # Mark as registered after successful re-join self.is_registered = True @@ -318,7 +353,13 @@ class WebRTCSignalingClient: ws_ssl = None logger.info(f"Reconnecting to signaling server: {ws_url}") - self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) + self.websocket = await websockets.connect( + ws_url, + ssl=ws_ssl, + ping_interval=30, # Send ping every 30 seconds + ping_timeout=10, # Wait up to 10 seconds for pong response + close_timeout=10 # Wait up to 10 seconds for close handshake + ) logger.info("Successfully reconnected to signaling server") except Exception as e: @@ -424,20 +465,31 @@ class WebRTCSignalingClient: if not self.websocket: logger.error("No websocket connection") return + + ws = cast(WebSocketProtocol, self.websocket) + # Build message with explicit type to avoid type narrowing message: dict[str, object] = {"type": message_type} if data is not None: message["data"] = data - ws = cast(WebSocketProtocol, self.websocket) try: logger.debug(f"_send_message: Sending {message_type} with data: {data}") await ws.send(json.dumps(message)) logger.debug(f"_send_message: Sent message: {message_type}") + except websockets.exceptions.ConnectionClosed as e: + logger.error( + f"_send_message: WebSocket connection closed while sending {message_type}: {e}" + ) + # Mark websocket as None to trigger reconnection + self.websocket = None + self.is_registered = False + raise except Exception as e: logger.error( f"_send_message: Failed to send {message_type}: {e}", exc_info=True ) + raise async def _handle_messages(self): """Handle incoming messages from signaling server""" @@ -550,6 +602,10 @@ class WebRTCSignalingClient: # Handle status check messages - these are used to verify connection logger.debug(f"Received status check message: {data}") # No special processing needed for status checks, just acknowledge receipt + elif msg_type == "status_ok": + # Handle status_ok response from server + logger.debug(f"Received status_ok from server: {data}") + # This confirms the connection is healthy elif msg_type == "chat_message": try: validated = ChatMessageModel.model_validate(data)