Improv leave / part for bots

This commit is contained in:
James Ketr 2025-09-03 16:07:55 -07:00
parent 8ef309d4f1
commit 1bd0a5ab71
2 changed files with 36 additions and 4 deletions

View File

@ -162,13 +162,23 @@ async def stop_run(run_id: str):
client = registry.get(run_id)
if not client:
raise HTTPException(status_code=404, detail="Run not found")
try:
await client.disconnect()
# Request graceful shutdown instead of awaiting disconnect from different loop
client.request_shutdown()
# Give the client a moment to shutdown gracefully
await asyncio.sleep(0.5)
# Remove from registry
registry.pop(run_id, None)
return {"status": "stopped", "run_id": run_id}
except Exception:
logger.exception("Failed to stop run %s", run_id)
# Still remove from registry even if shutdown failed
registry.pop(run_id, None)
raise HTTPException(status_code=500, detail="Failed to stop run")
registry.pop(run_id, None)
return {"status": "stopped", "run_id": run_id}
@app.get("/bots/runs")

View File

@ -122,6 +122,9 @@ class WebRTCSignalingClient:
self.registration_check_interval: float = registration_check_interval
self.registration_check_task: Optional[asyncio.Task[None]] = None
# Shutdown flag for graceful termination
self.shutdown_requested: bool = False
# Event callbacks
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
@ -185,6 +188,9 @@ class WebRTCSignalingClient:
logger.error(f"Message handling stopped: {e}")
self.is_registered = False
raise
finally:
# Clean disconnect when exiting
await self.disconnect()
except Exception as e:
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
@ -193,10 +199,14 @@ class WebRTCSignalingClient:
async def _periodic_registration_check(self):
"""Periodically check registration status and re-register if needed"""
while True:
while not self.shutdown_requested:
try:
await asyncio.sleep(self.registration_check_interval)
# Check shutdown flag again after sleep
if self.shutdown_requested:
break
current_time = time.time()
if current_time - self.last_registration_check < self.registration_check_interval:
continue
@ -215,6 +225,8 @@ class WebRTCSignalingClient:
logger.error(f"Error in periodic registration check: {e}", exc_info=True)
# Continue checking even if one iteration fails
continue
logger.info("Registration check loop ended")
async def _check_registration_status(self) -> bool:
"""Check if the voicebot is still registered with the server"""
@ -354,6 +366,11 @@ class WebRTCSignalingClient:
logger.info("Disconnected from signaling server")
def request_shutdown(self):
"""Request graceful shutdown - can be called from any thread"""
self.shutdown_requested = True
logger.info("Shutdown requested for WebRTC signaling client")
async def _setup_local_media(self):
"""Create local media tracks"""
# If a bot provided a create_tracks callable, use it to create tracks.
@ -407,6 +424,11 @@ class WebRTCSignalingClient:
try:
ws = cast(WebSocketProtocol, self.websocket)
async for message in ws:
# Check for shutdown request
if self.shutdown_requested:
logger.info("Shutdown requested, breaking message loop")
break
logger.debug(f"_handle_messages: Received raw message: {message}")
try:
data = cast(MessageData, json.loads(message))