Add server reconnect logic
This commit is contained in:
parent
39666eddbe
commit
2e91a4eadb
213
voicebot/main.py
213
voicebot/main.py
@ -92,6 +92,7 @@ class VoicebotArgs(BaseModel):
|
||||
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
|
||||
private: bool = Field(default=False, description="Create the lobby as private")
|
||||
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
|
||||
registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)
|
||||
|
||||
@classmethod
|
||||
def from_environment(cls) -> 'VoicebotArgs':
|
||||
@ -110,7 +111,8 @@ class VoicebotArgs(BaseModel):
|
||||
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
|
||||
password=os.getenv('VOICEBOT_PASSWORD', None),
|
||||
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
|
||||
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true'
|
||||
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true',
|
||||
registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0'))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -128,7 +130,8 @@ class VoicebotArgs(BaseModel):
|
||||
session_id=getattr(args, 'session_id', None),
|
||||
password=getattr(args, 'password', None),
|
||||
private=getattr(args, 'private', False),
|
||||
insecure=getattr(args, 'insecure', False)
|
||||
insecure=getattr(args, 'insecure', False),
|
||||
registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0))
|
||||
)
|
||||
|
||||
# Bot orchestration imports
|
||||
@ -186,6 +189,7 @@ class WebRTCSignalingClient:
|
||||
session_name: str,
|
||||
insecure: bool = False,
|
||||
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
||||
registration_check_interval: float = 30.0,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.lobby_id = lobby_id
|
||||
@ -212,6 +216,12 @@ class WebRTCSignalingClient:
|
||||
self.initiated_offer: set[str] = set()
|
||||
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
|
||||
|
||||
# Registration status tracking
|
||||
self.is_registered: bool = False
|
||||
self.last_registration_check: float = 0
|
||||
self.registration_check_interval: float = registration_check_interval
|
||||
self.registration_check_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# Event callbacks
|
||||
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||
@ -267,17 +277,154 @@ class WebRTCSignalingClient:
|
||||
await self._send_message("set_name", name_payload)
|
||||
logger.info("Sending join message")
|
||||
await self._send_message("join", {})
|
||||
|
||||
# Mark as registered after successful join
|
||||
self.is_registered = True
|
||||
import time
|
||||
self.last_registration_check = time.time()
|
||||
|
||||
# Start periodic registration check
|
||||
self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
|
||||
|
||||
# Start message handling
|
||||
logger.info("Starting message handler loop")
|
||||
await self._handle_messages()
|
||||
try:
|
||||
await self._handle_messages()
|
||||
except Exception as e:
|
||||
logger.error(f"Message handling stopped: {e}")
|
||||
self.is_registered = False
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def _periodic_registration_check(self):
|
||||
"""Periodically check registration status and re-register if needed"""
|
||||
import time
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.registration_check_interval)
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self.last_registration_check < self.registration_check_interval:
|
||||
continue
|
||||
|
||||
# Check if we're still connected and registered
|
||||
if not await self._check_registration_status():
|
||||
logger.warning("Registration check failed, attempting to re-register")
|
||||
await self._re_register()
|
||||
|
||||
self.last_registration_check = current_time
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Registration check task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in periodic registration check: {e}", exc_info=True)
|
||||
# Continue checking even if one iteration fails
|
||||
continue
|
||||
|
||||
async def _check_registration_status(self) -> bool:
|
||||
"""Check if the voicebot is still registered with the server"""
|
||||
try:
|
||||
# First check if websocket is still connected
|
||||
if not self.websocket:
|
||||
logger.warning("WebSocket connection lost")
|
||||
return False
|
||||
|
||||
# Try to send a ping/status check message to verify connection
|
||||
# We'll use a simple status message to check connectivity
|
||||
try:
|
||||
import time
|
||||
await self._send_message("status_check", {"timestamp": time.time()})
|
||||
logger.debug("Registration status check sent")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send status check: {e}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking registration status: {e}")
|
||||
return False
|
||||
|
||||
async def _re_register(self):
|
||||
"""Attempt to re-register with the server"""
|
||||
try:
|
||||
logger.info("Attempting to re-register with server")
|
||||
|
||||
# Mark as not registered during re-registration attempt
|
||||
self.is_registered = False
|
||||
|
||||
# Try to reconnect the websocket if it's lost
|
||||
if not self.websocket:
|
||||
logger.info("WebSocket lost, attempting to reconnect")
|
||||
await self._reconnect_websocket()
|
||||
|
||||
# Re-send name and join messages
|
||||
name_payload: MessageData = {"name": self.session_name}
|
||||
if self.name_password:
|
||||
name_payload["password"] = self.name_password
|
||||
|
||||
logger.info("Re-sending set_name message")
|
||||
await self._send_message("set_name", name_payload)
|
||||
|
||||
logger.info("Re-sending join message")
|
||||
await self._send_message("join", {})
|
||||
|
||||
# Mark as registered after successful re-join
|
||||
self.is_registered = True
|
||||
import time
|
||||
self.last_registration_check = time.time()
|
||||
|
||||
logger.info("Successfully re-registered with server")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to re-register with server: {e}", exc_info=True)
|
||||
# Will try again on next check interval
|
||||
|
||||
async def _reconnect_websocket(self):
|
||||
"""Reconnect the WebSocket connection"""
|
||||
try:
|
||||
# Close existing connection if any
|
||||
if self.websocket:
|
||||
try:
|
||||
ws = cast(WebSocketProtocol, self.websocket)
|
||||
await ws.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.websocket = None
|
||||
|
||||
# Reconnect
|
||||
ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
|
||||
|
||||
# If insecure (self-signed certs), create an SSL context for the websocket
|
||||
ws_ssl = None
|
||||
if self.insecure:
|
||||
ws_ssl = ssl.create_default_context()
|
||||
ws_ssl.check_hostname = False
|
||||
ws_ssl.verify_mode = ssl.CERT_NONE
|
||||
|
||||
logger.info(f"Reconnecting to signaling server: {ws_url}")
|
||||
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
|
||||
logger.info("Successfully reconnected to signaling server")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect from signaling server and cleanup"""
|
||||
# Cancel the registration check task
|
||||
if self.registration_check_task and not self.registration_check_task.done():
|
||||
self.registration_check_task.cancel()
|
||||
try:
|
||||
await self.registration_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self.registration_check_task = None
|
||||
|
||||
if self.websocket:
|
||||
ws = cast(WebSocketProtocol, self.websocket)
|
||||
await ws.close()
|
||||
@ -290,6 +437,9 @@ class WebRTCSignalingClient:
|
||||
for track in self.local_tracks.values():
|
||||
track.stop()
|
||||
|
||||
# Reset registration status
|
||||
self.is_registered = False
|
||||
|
||||
logger.info("Disconnected from signaling server")
|
||||
|
||||
async def _setup_local_media(self):
|
||||
@ -355,9 +505,12 @@ class WebRTCSignalingClient:
|
||||
continue
|
||||
await self._process_message(data)
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.info(f"WebSocket connection closed: {e}")
|
||||
logger.warning(f"WebSocket connection closed: {e}")
|
||||
self.is_registered = False
|
||||
# The periodic registration check will detect this and attempt reconnection
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling messages: {e}", exc_info=True)
|
||||
self.is_registered = False
|
||||
|
||||
async def _process_message(self, message: MessageData):
|
||||
"""Process incoming signaling messages"""
|
||||
@ -1157,7 +1310,9 @@ async def main_with_args(args: VoicebotArgs):
|
||||
ws_base = _ws_url(args.server_url)
|
||||
|
||||
client = WebRTCSignalingClient(
|
||||
ws_base, lobby_id, session_id, args.session_name, insecure=args.insecure
|
||||
ws_base, lobby_id, session_id, args.session_name,
|
||||
insecure=args.insecure,
|
||||
registration_check_interval=args.registration_check_interval
|
||||
)
|
||||
|
||||
# Set up event handlers
|
||||
@ -1194,16 +1349,42 @@ async def main_with_args(args: VoicebotArgs):
|
||||
client.on_peer_removed = on_peer_removed
|
||||
client.on_track_received = on_track_received
|
||||
|
||||
try:
|
||||
# Connect and run
|
||||
# If a password was provided on the CLI, store it on the client for use when setting name
|
||||
if args.password:
|
||||
client.name_password = args.password
|
||||
await client.connect()
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down...")
|
||||
finally:
|
||||
await client.disconnect()
|
||||
# Retry loop for connection resilience
|
||||
max_retries = 5
|
||||
retry_delay = 5.0 # seconds
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
# If a password was provided on the CLI, store it on the client for use when setting name
|
||||
if args.password:
|
||||
client.name_password = args.password
|
||||
|
||||
print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
|
||||
await client.connect()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down...")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
|
||||
retry_count += 1
|
||||
|
||||
if retry_count < max_retries:
|
||||
print(f"Retrying in {retry_delay} seconds...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
# Exponential backoff with max delay of 60 seconds
|
||||
retry_delay = min(retry_delay * 1.5, 60.0)
|
||||
else:
|
||||
print("Max retries exceeded. Giving up.")
|
||||
break
|
||||
finally:
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
|
||||
print("Voicebot client stopped.")
|
||||
|
||||
|
||||
# --- FastAPI service for bot discovery and orchestration -------------------
|
||||
@ -1578,6 +1759,8 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
|
||||
parser.add_argument("--insecure", action="store_true",
|
||||
help="Allow insecure connections")
|
||||
parser.add_argument("--registration-check-interval", type=float, default=30.0,
|
||||
help="Interval in seconds for checking registration status (5.0-300.0)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user