Add server reconnect logic
This commit is contained in:
parent
39666eddbe
commit
2e91a4eadb
213
voicebot/main.py
213
voicebot/main.py
@ -92,6 +92,7 @@ class VoicebotArgs(BaseModel):
|
|||||||
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
|
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
|
||||||
private: bool = Field(default=False, description="Create the lobby as private")
|
private: bool = Field(default=False, description="Create the lobby as private")
|
||||||
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
|
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
|
||||||
|
registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environment(cls) -> 'VoicebotArgs':
|
def from_environment(cls) -> 'VoicebotArgs':
|
||||||
@ -110,7 +111,8 @@ class VoicebotArgs(BaseModel):
|
|||||||
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
|
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
|
||||||
password=os.getenv('VOICEBOT_PASSWORD', None),
|
password=os.getenv('VOICEBOT_PASSWORD', None),
|
||||||
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
|
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
|
||||||
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true'
|
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true',
|
||||||
|
registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0'))
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -128,7 +130,8 @@ class VoicebotArgs(BaseModel):
|
|||||||
session_id=getattr(args, 'session_id', None),
|
session_id=getattr(args, 'session_id', None),
|
||||||
password=getattr(args, 'password', None),
|
password=getattr(args, 'password', None),
|
||||||
private=getattr(args, 'private', False),
|
private=getattr(args, 'private', False),
|
||||||
insecure=getattr(args, 'insecure', False)
|
insecure=getattr(args, 'insecure', False),
|
||||||
|
registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Bot orchestration imports
|
# Bot orchestration imports
|
||||||
@ -186,6 +189,7 @@ class WebRTCSignalingClient:
|
|||||||
session_name: str,
|
session_name: str,
|
||||||
insecure: bool = False,
|
insecure: bool = False,
|
||||||
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
||||||
|
registration_check_interval: float = 30.0,
|
||||||
):
|
):
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.lobby_id = lobby_id
|
self.lobby_id = lobby_id
|
||||||
@ -212,6 +216,12 @@ class WebRTCSignalingClient:
|
|||||||
self.initiated_offer: set[str] = set()
|
self.initiated_offer: set[str] = set()
|
||||||
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
|
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
|
||||||
|
|
||||||
|
# Registration status tracking
|
||||||
|
self.is_registered: bool = False
|
||||||
|
self.last_registration_check: float = 0
|
||||||
|
self.registration_check_interval: float = registration_check_interval
|
||||||
|
self.registration_check_task: Optional[asyncio.Task[None]] = None
|
||||||
|
|
||||||
# Event callbacks
|
# Event callbacks
|
||||||
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
|
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||||
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
|
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||||
@ -267,17 +277,154 @@ class WebRTCSignalingClient:
|
|||||||
await self._send_message("set_name", name_payload)
|
await self._send_message("set_name", name_payload)
|
||||||
logger.info("Sending join message")
|
logger.info("Sending join message")
|
||||||
await self._send_message("join", {})
|
await self._send_message("join", {})
|
||||||
|
|
||||||
|
# Mark as registered after successful join
|
||||||
|
self.is_registered = True
|
||||||
|
import time
|
||||||
|
self.last_registration_check = time.time()
|
||||||
|
|
||||||
|
# Start periodic registration check
|
||||||
|
self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
|
||||||
|
|
||||||
# Start message handling
|
# Start message handling
|
||||||
logger.info("Starting message handler loop")
|
logger.info("Starting message handler loop")
|
||||||
await self._handle_messages()
|
try:
|
||||||
|
await self._handle_messages()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Message handling stopped: {e}")
|
||||||
|
self.is_registered = False
|
||||||
|
raise
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
|
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def _periodic_registration_check(self):
    """Background loop that verifies registration once per check interval.

    Each pass sleeps for ``registration_check_interval`` seconds. If less
    than one interval has elapsed since ``last_registration_check`` the pass
    is skipped; otherwise ``_check_registration_status`` is awaited and a
    failed check triggers ``_re_register``. The loop stops when the task is
    cancelled; any other exception is logged and the loop keeps going.
    """
    import time

    while True:
        try:
            await asyncio.sleep(self.registration_check_interval)

            now = time.time()
            elapsed = now - self.last_registration_check
            if elapsed < self.registration_check_interval:
                # Something else refreshed the timestamp recently; wait again.
                continue

            still_registered = await self._check_registration_status()
            if not still_registered:
                logger.warning("Registration check failed, attempting to re-register")
                await self._re_register()

            # Record the time sampled before the check ran.
            self.last_registration_check = now

        except asyncio.CancelledError:
            logger.info("Registration check task cancelled")
            return
        except Exception as exc:
            # A single failed pass must not stop future checks.
            logger.error(f"Error in periodic registration check: {exc}", exc_info=True)
|
||||||
|
|
||||||
|
async def _check_registration_status(self) -> bool:
    """Report whether the client still looks registered with the server.

    Returns:
        ``False`` when there is no websocket, when ``is_registered`` has
        been cleared, or when sending the ``status_check`` message fails;
        ``True`` once the message has been sent.

    A ``True`` result only means the send did not raise; no reply is
    awaited here. Send failures are logged and reported as ``False``
    rather than raised.
    """
    import time

    # First check if websocket is still connected
    if not self.websocket:
        logger.warning("WebSocket connection lost")
        return False

    # The message loop clears this flag when the connection closes or the
    # loop fails, so honour it instead of relying on a send error alone.
    if not self.is_registered:
        logger.warning("Client is no longer marked as registered")
        return False

    # Send a simple status message to verify the socket still accepts writes.
    try:
        await self._send_message("status_check", {"timestamp": time.time()})
        logger.debug("Registration status check sent")
    except Exception as e:
        logger.warning(f"Failed to send status check: {e}")
        return False

    return True
|
||||||
|
|
||||||
|
async def _re_register(self):
    """Replace the websocket and re-send the registration messages.

    Clears ``is_registered``, reconnects through ``_reconnect_websocket``,
    then re-sends ``set_name`` (including the password when
    ``name_password`` is set) followed by ``join``. On success
    ``is_registered`` is set again and ``last_registration_check`` is
    refreshed. Any failure is logged and swallowed, leaving
    ``is_registered`` False.
    """
    import time

    try:
        logger.info("Attempting to re-register with server")

        # Mark as not registered during re-registration attempt
        self.is_registered = False

        # The periodic check calls this only after a status check failed, so
        # the current socket is either missing or unusable. A closed socket
        # object is still truthy, so testing `self.websocket` alone would
        # keep a dead connection and make every send below fail.
        # _reconnect_websocket() closes any leftover socket before dialing.
        if self.websocket:
            logger.info("Existing WebSocket is unusable, replacing it")
        else:
            logger.info("WebSocket lost, attempting to reconnect")
        await self._reconnect_websocket()
        # NOTE(review): nothing in this method starts reading from the new
        # socket - confirm the message handler loop is restarted elsewhere.

        # Re-send name and join messages
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password

        logger.info("Re-sending set_name message")
        await self._send_message("set_name", name_payload)

        logger.info("Re-sending join message")
        await self._send_message("join", {})

        # Mark as registered after successful re-join
        self.is_registered = True
        self.last_registration_check = time.time()

        logger.info("Successfully re-registered with server")

    except Exception as e:
        logger.error(f"Failed to re-register with server: {e}", exc_info=True)
        # Will try again on next check interval
|
||||||
|
|
||||||
|
async def _reconnect_websocket(self):
    """Drop the current websocket, if any, and open a fresh one.

    The URL is built from ``server_url``, ``lobby_id`` and ``session_id``.
    When ``insecure`` is set, the connection uses an SSL context with
    hostname checking and certificate verification turned off.

    Raises:
        Exception: whatever the connection attempt raised, after logging it.
    """
    try:
        if self.websocket:
            try:
                await cast(WebSocketProtocol, self.websocket).close()
            except Exception:
                # Best effort: the old socket is being discarded anyway.
                pass
            self.websocket = None

        target_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"

        if self.insecure:
            # Self-signed certificates: skip hostname and certificate checks.
            tls_context = ssl.create_default_context()
            tls_context.check_hostname = False
            tls_context.verify_mode = ssl.CERT_NONE
        else:
            tls_context = None

        logger.info(f"Reconnecting to signaling server: {target_url}")
        self.websocket = await websockets.connect(target_url, ssl=tls_context)
        logger.info("Successfully reconnected to signaling server")

    except Exception as exc:
        logger.error(f"Failed to reconnect websocket: {exc}", exc_info=True)
        raise
|
||||||
|
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
"""Disconnect from signaling server and cleanup"""
|
"""Disconnect from signaling server and cleanup"""
|
||||||
|
# Cancel the registration check task
|
||||||
|
if self.registration_check_task and not self.registration_check_task.done():
|
||||||
|
self.registration_check_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.registration_check_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self.registration_check_task = None
|
||||||
|
|
||||||
if self.websocket:
|
if self.websocket:
|
||||||
ws = cast(WebSocketProtocol, self.websocket)
|
ws = cast(WebSocketProtocol, self.websocket)
|
||||||
await ws.close()
|
await ws.close()
|
||||||
@ -290,6 +437,9 @@ class WebRTCSignalingClient:
|
|||||||
for track in self.local_tracks.values():
|
for track in self.local_tracks.values():
|
||||||
track.stop()
|
track.stop()
|
||||||
|
|
||||||
|
# Reset registration status
|
||||||
|
self.is_registered = False
|
||||||
|
|
||||||
logger.info("Disconnected from signaling server")
|
logger.info("Disconnected from signaling server")
|
||||||
|
|
||||||
async def _setup_local_media(self):
|
async def _setup_local_media(self):
|
||||||
@ -355,9 +505,12 @@ class WebRTCSignalingClient:
|
|||||||
continue
|
continue
|
||||||
await self._process_message(data)
|
await self._process_message(data)
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
logger.info(f"WebSocket connection closed: {e}")
|
logger.warning(f"WebSocket connection closed: {e}")
|
||||||
|
self.is_registered = False
|
||||||
|
# The periodic registration check will detect this and attempt reconnection
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error handling messages: {e}", exc_info=True)
|
logger.error(f"Error handling messages: {e}", exc_info=True)
|
||||||
|
self.is_registered = False
|
||||||
|
|
||||||
async def _process_message(self, message: MessageData):
|
async def _process_message(self, message: MessageData):
|
||||||
"""Process incoming signaling messages"""
|
"""Process incoming signaling messages"""
|
||||||
@ -1157,7 +1310,9 @@ async def main_with_args(args: VoicebotArgs):
|
|||||||
ws_base = _ws_url(args.server_url)
|
ws_base = _ws_url(args.server_url)
|
||||||
|
|
||||||
client = WebRTCSignalingClient(
|
client = WebRTCSignalingClient(
|
||||||
ws_base, lobby_id, session_id, args.session_name, insecure=args.insecure
|
ws_base, lobby_id, session_id, args.session_name,
|
||||||
|
insecure=args.insecure,
|
||||||
|
registration_check_interval=args.registration_check_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set up event handlers
|
# Set up event handlers
|
||||||
@ -1194,16 +1349,42 @@ async def main_with_args(args: VoicebotArgs):
|
|||||||
client.on_peer_removed = on_peer_removed
|
client.on_peer_removed = on_peer_removed
|
||||||
client.on_track_received = on_track_received
|
client.on_track_received = on_track_received
|
||||||
|
|
||||||
try:
|
# Retry loop for connection resilience
|
||||||
# Connect and run
|
max_retries = 5
|
||||||
# If a password was provided on the CLI, store it on the client for use when setting name
|
retry_delay = 5.0 # seconds
|
||||||
if args.password:
|
retry_count = 0
|
||||||
client.name_password = args.password
|
|
||||||
await client.connect()
|
while retry_count < max_retries:
|
||||||
except KeyboardInterrupt:
|
try:
|
||||||
print("Shutting down...")
|
# If a password was provided on the CLI, store it on the client for use when setting name
|
||||||
finally:
|
if args.password:
|
||||||
await client.disconnect()
|
client.name_password = args.password
|
||||||
|
|
||||||
|
print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
|
||||||
|
await client.connect()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Shutting down...")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
if retry_count < max_retries:
|
||||||
|
print(f"Retrying in {retry_delay} seconds...")
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
# Exponential backoff with max delay of 60 seconds
|
||||||
|
retry_delay = min(retry_delay * 1.5, 60.0)
|
||||||
|
else:
|
||||||
|
print("Max retries exceeded. Giving up.")
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await client.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during disconnect: {e}")
|
||||||
|
|
||||||
|
print("Voicebot client stopped.")
|
||||||
|
|
||||||
|
|
||||||
# --- FastAPI service for bot discovery and orchestration -------------------
|
# --- FastAPI service for bot discovery and orchestration -------------------
|
||||||
@ -1578,6 +1759,8 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
|
parser.add_argument("--private", action="store_true", help="Create the lobby as private (client mode)")
|
||||||
parser.add_argument("--insecure", action="store_true",
|
parser.add_argument("--insecure", action="store_true",
|
||||||
help="Allow insecure connections")
|
help="Allow insecure connections")
|
||||||
|
parser.add_argument("--registration-check-interval", type=float, default=30.0,
|
||||||
|
help="Interval in seconds for checking registration status (5.0-300.0)")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user