Add server reconnect logic

This commit is contained in:
James Ketr 2025-09-03 14:10:47 -07:00
parent 39666eddbe
commit 2e91a4eadb

View File

@ -92,6 +92,7 @@ class VoicebotArgs(BaseModel):
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
private: bool = Field(default=False, description="Create the lobby as private")
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)
@classmethod
def from_environment(cls) -> 'VoicebotArgs':
@ -110,7 +111,8 @@ class VoicebotArgs(BaseModel):
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
password=os.getenv('VOICEBOT_PASSWORD', None),
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true'
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true',
registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0'))
)
@classmethod
@ -128,7 +130,8 @@ class VoicebotArgs(BaseModel):
session_id=getattr(args, 'session_id', None),
password=getattr(args, 'password', None),
private=getattr(args, 'private', False),
insecure=getattr(args, 'insecure', False)
insecure=getattr(args, 'insecure', False),
registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0))
)
# Bot orchestration imports
@ -186,6 +189,7 @@ class WebRTCSignalingClient:
session_name: str,
insecure: bool = False,
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
registration_check_interval: float = 30.0,
):
self.server_url = server_url
self.lobby_id = lobby_id
@ -212,6 +216,12 @@ class WebRTCSignalingClient:
self.initiated_offer: set[str] = set()
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
# Registration status tracking
self.is_registered: bool = False
self.last_registration_check: float = 0
self.registration_check_interval: float = registration_check_interval
self.registration_check_task: Optional[asyncio.Task[None]] = None
# Event callbacks
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
@ -267,17 +277,154 @@ class WebRTCSignalingClient:
await self._send_message("set_name", name_payload)
logger.info("Sending join message")
await self._send_message("join", {})
# Mark as registered after successful join
self.is_registered = True
import time
self.last_registration_check = time.time()
# Start periodic registration check
self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
# Start message handling
logger.info("Starting message handler loop")
await self._handle_messages()
try:
await self._handle_messages()
except Exception as e:
logger.error(f"Message handling stopped: {e}")
self.is_registered = False
raise
except Exception as e:
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
raise
async def _periodic_registration_check(self):
"""Periodically check registration status and re-register if needed"""
import time
while True:
try:
await asyncio.sleep(self.registration_check_interval)
current_time = time.time()
if current_time - self.last_registration_check < self.registration_check_interval:
continue
# Check if we're still connected and registered
if not await self._check_registration_status():
logger.warning("Registration check failed, attempting to re-register")
await self._re_register()
self.last_registration_check = current_time
except asyncio.CancelledError:
logger.info("Registration check task cancelled")
break
except Exception as e:
logger.error(f"Error in periodic registration check: {e}", exc_info=True)
# Continue checking even if one iteration fails
continue
async def _check_registration_status(self) -> bool:
"""Check if the voicebot is still registered with the server"""
try:
# First check if websocket is still connected
if not self.websocket:
logger.warning("WebSocket connection lost")
return False
# Try to send a ping/status check message to verify connection
# We'll use a simple status message to check connectivity
try:
import time
await self._send_message("status_check", {"timestamp": time.time()})
logger.debug("Registration status check sent")
return True
except Exception as e:
logger.warning(f"Failed to send status check: {e}")
return False
except Exception as e:
logger.error(f"Error checking registration status: {e}")
return False
async def _re_register(self):
"""Attempt to re-register with the server"""
try:
logger.info("Attempting to re-register with server")
# Mark as not registered during re-registration attempt
self.is_registered = False
# Try to reconnect the websocket if it's lost
if not self.websocket:
logger.info("WebSocket lost, attempting to reconnect")
await self._reconnect_websocket()
# Re-send name and join messages
name_payload: MessageData = {"name": self.session_name}
if self.name_password:
name_payload["password"] = self.name_password
logger.info("Re-sending set_name message")
await self._send_message("set_name", name_payload)
logger.info("Re-sending join message")
await self._send_message("join", {})
# Mark as registered after successful re-join
self.is_registered = True
import time
self.last_registration_check = time.time()
logger.info("Successfully re-registered with server")
except Exception as e:
logger.error(f"Failed to re-register with server: {e}", exc_info=True)
# Will try again on next check interval
async def _reconnect_websocket(self):
"""Reconnect the WebSocket connection"""
try:
# Close existing connection if any
if self.websocket:
try:
ws = cast(WebSocketProtocol, self.websocket)
await ws.close()
except Exception:
pass
self.websocket = None
# Reconnect
ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
# If insecure (self-signed certs), create an SSL context for the websocket
ws_ssl = None
if self.insecure:
ws_ssl = ssl.create_default_context()
ws_ssl.check_hostname = False
ws_ssl.verify_mode = ssl.CERT_NONE
logger.info(f"Reconnecting to signaling server: {ws_url}")
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
logger.info("Successfully reconnected to signaling server")
except Exception as e:
logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
raise
async def disconnect(self):
"""Disconnect from signaling server and cleanup"""
# Cancel the registration check task
if self.registration_check_task and not self.registration_check_task.done():
self.registration_check_task.cancel()
try:
await self.registration_check_task
except asyncio.CancelledError:
pass
self.registration_check_task = None
if self.websocket:
ws = cast(WebSocketProtocol, self.websocket)
await ws.close()
@ -290,6 +437,9 @@ class WebRTCSignalingClient:
for track in self.local_tracks.values():
track.stop()
# Reset registration status
self.is_registered = False
logger.info("Disconnected from signaling server")
async def _setup_local_media(self):
@ -355,9 +505,12 @@ class WebRTCSignalingClient:
continue
await self._process_message(data)
except websockets.exceptions.ConnectionClosed as e:
logger.info(f"WebSocket connection closed: {e}")
logger.warning(f"WebSocket connection closed: {e}")
self.is_registered = False
# The periodic registration check will detect this and attempt reconnection
except Exception as e:
logger.error(f"Error handling messages: {e}", exc_info=True)
self.is_registered = False
async def _process_message(self, message: MessageData):
"""Process incoming signaling messages"""
@ -1157,7 +1310,9 @@ async def main_with_args(args: VoicebotArgs):
ws_base = _ws_url(args.server_url)
client = WebRTCSignalingClient(
ws_base, lobby_id, session_id, args.session_name, insecure=args.insecure
ws_base, lobby_id, session_id, args.session_name,
insecure=args.insecure,
registration_check_interval=args.registration_check_interval
)
# Set up event handlers
@ -1194,16 +1349,42 @@ async def main_with_args(args: VoicebotArgs):
client.on_peer_removed = on_peer_removed
client.on_track_received = on_track_received
try:
# Connect and run
# If a password was provided on the CLI, store it on the client for use when setting name
if args.password:
client.name_password = args.password
await client.connect()
except KeyboardInterrupt:
print("Shutting down...")
finally:
await client.disconnect()
# Retry loop for connection resilience
max_retries = 5
retry_delay = 5.0 # seconds
retry_count = 0
while retry_count < max_retries:
try:
# If a password was provided on the CLI, store it on the client for use when setting name
if args.password:
client.name_password = args.password
print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
await client.connect()
except KeyboardInterrupt:
print("Shutting down...")
break
except Exception as e:
logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
retry_count += 1
if retry_count < max_retries:
print(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
# Exponential backoff with max delay of 60 seconds
retry_delay = min(retry_delay * 1.5, 60.0)
else:
print("Max retries exceeded. Giving up.")
break
finally:
try:
await client.disconnect()
except Exception as e:
logger.error(f"Error during disconnect: {e}")
print("Voicebot client stopped.")
# --- FastAPI service for bot discovery and orchestration -------------------
@ -1578,6 +1759,8 @@ if __name__ == "__main__":
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
parser.add_argument("--insecure", action="store_true",
help="Allow insecure connections")
parser.add_argument("--registration-check-interval", type=float, default=30.0,
help="Interval in seconds for checking registration status (5.0-300.0)")
args = parser.parse_args()