From 2e91a4eadb469bb038375fbda7b565fe35742676 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 3 Sep 2025 14:10:47 -0700 Subject: [PATCH] Add server reconnect logic --- voicebot/main.py | 213 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 15 deletions(-) diff --git a/voicebot/main.py b/voicebot/main.py index 61e1f51..d49d3ac 100644 --- a/voicebot/main.py +++ b/voicebot/main.py @@ -92,6 +92,7 @@ class VoicebotArgs(BaseModel): password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name") private: bool = Field(default=False, description="Create the lobby as private") insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL") + registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0) @classmethod def from_environment(cls) -> 'VoicebotArgs': @@ -110,7 +111,8 @@ class VoicebotArgs(BaseModel): session_id=os.getenv('VOICEBOT_SESSION_ID', None), password=os.getenv('VOICEBOT_PASSWORD', None), private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true', - insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true' + insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true', + registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0')) ) @classmethod @@ -128,7 +130,8 @@ class VoicebotArgs(BaseModel): session_id=getattr(args, 'session_id', None), password=getattr(args, 'password', None), private=getattr(args, 'private', False), - insecure=getattr(args, 'insecure', False) + insecure=getattr(args, 'insecure', False), + registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0)) ) # Bot orchestration imports @@ -186,6 +189,7 @@ class WebRTCSignalingClient: session_name: str, insecure: bool = False, create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None, + registration_check_interval: float = 30.0, ): self.server_url = server_url self.lobby_id = lobby_id @@ -212,6 +216,12 @@ class WebRTCSignalingClient: self.initiated_offer: set[str] = set() self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {} + # Registration status tracking + self.is_registered: bool = False + self.last_registration_check: float = 0 + self.registration_check_interval: float = registration_check_interval + self.registration_check_task: Optional[asyncio.Task[None]] = None + # Event callbacks self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None @@ -267,17 +277,154 @@ class WebRTCSignalingClient: await self._send_message("set_name", name_payload) logger.info("Sending join message") await self._send_message("join", {}) + + # Mark as registered after successful join + self.is_registered = True + import time + self.last_registration_check = time.time() + + # Start periodic registration check + self.registration_check_task = asyncio.create_task(self._periodic_registration_check()) # Start message handling logger.info("Starting message handler loop") - await self._handle_messages() + try: + await self._handle_messages() + except Exception as e: + logger.error(f"Message handling stopped: {e}") + self.is_registered = False + raise except Exception as e: logger.error(f"Failed to connect to signaling server: {e}", exc_info=True) raise + async def _periodic_registration_check(self): + """Periodically check registration status and re-register if needed""" + import time + + while True: + try: + await asyncio.sleep(self.registration_check_interval) + + current_time = time.time() + if current_time - self.last_registration_check < self.registration_check_interval: + continue + + # Check if we're still connected and registered + if not await self._check_registration_status(): + logger.warning("Registration check failed, attempting to re-register") + await self._re_register() + + self.last_registration_check = current_time + + except asyncio.CancelledError: + logger.info("Registration check task cancelled") + break + except Exception as e: + logger.error(f"Error in periodic registration check: {e}", exc_info=True) + # Continue checking even if one iteration fails + continue + + async def _check_registration_status(self) -> bool: + """Check if the voicebot is still registered with the server""" + try: + # First check if websocket is still connected + if not self.websocket: + logger.warning("WebSocket connection lost") + return False + + # Try to send a ping/status check message to verify connection + # We'll use a simple status message to check connectivity + try: + import time + await self._send_message("status_check", {"timestamp": time.time()}) + logger.debug("Registration status check sent") + return True + except Exception as e: + logger.warning(f"Failed to send status check: {e}") + return False + + except Exception as e: + logger.error(f"Error checking registration status: {e}") + return False + + async def _re_register(self): + """Attempt to re-register with the server""" + try: + logger.info("Attempting to re-register with server") + + # Mark as not registered during re-registration attempt + self.is_registered = False + + # Try to reconnect the websocket if it's lost + if not self.websocket: + logger.info("WebSocket lost, attempting to reconnect") + await self._reconnect_websocket() + + # Re-send name and join messages + name_payload: MessageData = {"name": self.session_name} + if self.name_password: + name_payload["password"] = self.name_password + + logger.info("Re-sending set_name message") + await self._send_message("set_name", name_payload) + + logger.info("Re-sending join message") + await self._send_message("join", {}) + + # Mark as registered after successful re-join + self.is_registered = True + import time + self.last_registration_check = time.time() + + logger.info("Successfully re-registered with server") + + except Exception as e: + logger.error(f"Failed to re-register with server: {e}", exc_info=True) + # Will try again on next check interval + + async def _reconnect_websocket(self): + """Reconnect the WebSocket connection""" + try: + # Close existing connection if any + if self.websocket: + try: + ws = cast(WebSocketProtocol, self.websocket) + await ws.close() + except Exception: + pass + self.websocket = None + + # Reconnect + ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" + + # If insecure (self-signed certs), create an SSL context for the websocket + ws_ssl = None + if self.insecure: + ws_ssl = ssl.create_default_context() + ws_ssl.check_hostname = False + ws_ssl.verify_mode = ssl.CERT_NONE + + logger.info(f"Reconnecting to signaling server: {ws_url}") + self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) + logger.info("Successfully reconnected to signaling server") + + except Exception as e: + logger.error(f"Failed to reconnect websocket: {e}", exc_info=True) + raise + async def disconnect(self): """Disconnect from signaling server and cleanup""" + # Cancel the registration check task + if self.registration_check_task and not self.registration_check_task.done(): + self.registration_check_task.cancel() + try: + await self.registration_check_task + except asyncio.CancelledError: + pass + self.registration_check_task = None + if self.websocket: ws = cast(WebSocketProtocol, self.websocket) await ws.close() @@ -290,6 +437,9 @@ class WebRTCSignalingClient: for track in self.local_tracks.values(): track.stop() + # Reset registration status + self.is_registered = False + logger.info("Disconnected from signaling server") async def _setup_local_media(self): @@ -355,9 +505,12 @@ class WebRTCSignalingClient: continue await self._process_message(data) except websockets.exceptions.ConnectionClosed as e: - logger.info(f"WebSocket connection closed: {e}") + logger.warning(f"WebSocket connection closed: {e}") + self.is_registered = False + # The periodic registration check will detect this and attempt reconnection except Exception as e: logger.error(f"Error handling messages: {e}", exc_info=True) + self.is_registered = False async def _process_message(self, message: MessageData): """Process incoming signaling messages""" @@ -1157,7 +1310,9 @@ async def main_with_args(args: VoicebotArgs): ws_base = _ws_url(args.server_url) client = WebRTCSignalingClient( - ws_base, lobby_id, session_id, args.session_name, insecure=args.insecure + ws_base, lobby_id, session_id, args.session_name, + insecure=args.insecure, + registration_check_interval=args.registration_check_interval ) # Set up event handlers @@ -1194,16 +1349,42 @@ async def main_with_args(args: VoicebotArgs): client.on_peer_removed = on_peer_removed client.on_track_received = on_track_received - try: - # Connect and run - # If a password was provided on the CLI, store it on the client for use when setting name - if args.password: - client.name_password = args.password - await client.connect() - except KeyboardInterrupt: - print("Shutting down...") - finally: - await client.disconnect() + # Retry loop for connection resilience + max_retries = 5 + retry_delay = 5.0 # seconds + retry_count = 0 + + while retry_count < max_retries: + try: + # If a password was provided on the CLI, store it on the client for use when setting name + if args.password: + client.name_password = args.password + + print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})") + await client.connect() + + except KeyboardInterrupt: + print("Shutting down...") + break + except Exception as e: + logger.error(f"Connection failed (attempt {retry_count + 1}): {e}") + retry_count += 1 + + if retry_count < max_retries: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + # Exponential backoff with max delay of 60 seconds + retry_delay = min(retry_delay * 1.5, 60.0) + else: + print("Max retries exceeded. Giving up.") + break + finally: + try: + await client.disconnect() + except Exception as e: + logger.error(f"Error during disconnect: {e}") + + print("Voicebot client stopped.") # --- FastAPI service for bot discovery and orchestration ------------------- @@ -1578,6 +1759,8 @@ if __name__ == "__main__": parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)") parser.add_argument("--insecure", action="store_true", help="Allow insecure connections") + parser.add_argument("--registration-check-interval", type=float, default=30.0, + help="Interval in seconds for checking registration status (5.0-300.0)") args = parser.parse_args()