145 lines
5.3 KiB
Python
145 lines
5.3 KiB
Python
"""
Main client logic for voicebot.

This module contains the main client functionality and entry points.
"""
|
|
|
|
import asyncio
import os
import sys

# Add the parent directory to sys.path to allow absolute imports.
# This must run BEFORE any project-local import (shared.*, voicebot.*);
# previously two of those imports ran ahead of it and could not benefit.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from aiortc import MediaStreamTrack

from shared.logger import logger
from voicebot.bots.synthetic_media import AnimatedVideoTrack
from voicebot.models import VoicebotArgs, Peer
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
from voicebot.webrtc_signaling import WebRTCSignalingClient
from voicebot.utils import ws_url
|
|
|
|
|
|
def _register_event_handlers(client: WebRTCSignalingClient) -> None:
    """Attach the peer/track event callbacks used by the voicebot to *client*.

    The callbacks print/log peer and track events. Incoming video tracks are
    forwarded to the client's local "video" track when that track is an
    ``AnimatedVideoTrack``.
    """

    async def on_peer_added(peer: Peer):
        print(f"Peer added: {peer.peer_name}")

    async def on_peer_removed(peer: Peer):
        print(f"Peer removed: {peer.peer_name}")

        # Remove any video tracks from this peer from our synthetic video track
        if "video" in client.local_tracks:
            synthetic_video_track = client.local_tracks["video"]
            if isinstance(synthetic_video_track, AnimatedVideoTrack):
                # TODO: no peer -> remote-track mapping is kept here, so tracks
                # attached in on_track_received cannot be detached yet; for now
                # this only logs. Peer tracking needs to be enhanced.
                logger.info(
                    f"Peer {peer.peer_name} removed - may need to clean up video tracks"
                )

    async def on_track_received(peer: Peer, track: MediaStreamTrack):
        print(f"Received {track.kind} track from {peer.peer_name}")

        # If it's a video track, attach it to our synthetic video track for edge detection
        if track.kind == "video" and "video" in client.local_tracks:
            synthetic_video_track = client.local_tracks["video"]
            if isinstance(synthetic_video_track, AnimatedVideoTrack):
                synthetic_video_track.add_remote_video_track(track)
                logger.info(
                    f"Attached remote video track from {peer.peer_name} to synthetic video track"
                )

    client.on_peer_added = on_peer_added
    client.on_peer_removed = on_peer_removed
    client.on_track_received = on_track_received


async def _connect_with_retries(
    client: WebRTCSignalingClient,
    password=None,
    *,
    max_retries: int = 5,
    retry_delay: float = 5.0,
    backoff_factor: float = 1.5,
    max_retry_delay: float = 60.0,
) -> None:
    """Run ``client.connect()``, retrying failed attempts with exponential backoff.

    Args:
        client: The signaling client to connect.
        password: If truthy, stored on ``client.name_password`` before each attempt.
        max_retries: Number of failed attempts tolerated before giving up.
        retry_delay: Initial delay (seconds) before the first retry.
        backoff_factor: Multiplier applied to the delay after each retry.
        max_retry_delay: Upper bound (seconds) for the delay.

    ``client.disconnect()`` is awaited after every attempt (in ``finally``);
    errors from it are logged, not raised.
    """
    retry_count = 0

    # NOTE(review): retry_count is only incremented on failure and never reset,
    # and a normal return from connect() loops straight into another attempt
    # with no delay. Confirm this is intended for connect()'s semantics.
    while retry_count < max_retries:
        try:
            # Re-applied on every attempt (as before), in case connect() or
            # the client consumes/clears it.
            if password:
                client.name_password = password

            print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
            await client.connect()

        except KeyboardInterrupt:
            # NOTE(review): under asyncio.run() on Python 3.11+, Ctrl-C cancels
            # the main task (CancelledError) instead, so this branch is rarely hit.
            print("Shutting down...")
            break

        except Exception as e:
            logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
            retry_count += 1

            if retry_count < max_retries:
                print(f"Retrying in {retry_delay} seconds...")
                await asyncio.sleep(retry_delay)
                # Exponential backoff, capped at max_retry_delay
                retry_delay = min(retry_delay * backoff_factor, max_retry_delay)
            else:
                print("Max retries exceeded. Giving up.")
                break

        finally:
            try:
                await client.disconnect()
            except Exception as e:
                logger.error(f"Error during disconnect: {e}")


async def main_with_args(args: VoicebotArgs):
    """Main voicebot client logic that accepts arguments object.

    Resolves (or creates) the session and the lobby on ``args.server_url``,
    builds a ``WebRTCSignalingClient`` for them, registers the peer/track
    handlers, and then connects, retrying failed attempts with exponential
    backoff.

    If the session or lobby cannot be obtained, the error is printed and the
    function returns without raising.
    """
    # Resolve session id (create if needed)
    try:
        session_id = create_or_get_session(
            args.server_url, args.session_id, insecure=args.insecure
        )
        print(f"Using session id: {session_id}")
    except Exception as e:
        print(f"Failed to get/create session: {e}")
        return

    # Create or get lobby id
    try:
        lobby_id = create_or_get_lobby(
            args.server_url,
            session_id,
            args.lobby,
            args.private,
            insecure=args.insecure,
        )
        print(f"Using lobby id: {lobby_id} (name={args.lobby})")
    except Exception as e:
        print(f"Failed to create/get lobby: {e}")
        return

    # Build websocket base URL (ws:// or wss://) from server_url and pass it to
    # the client, which constructs the final path (/ws/lobby/{lobby}/{session}).
    ws_base = ws_url(args.server_url)

    client = WebRTCSignalingClient(
        ws_base,
        lobby_id,
        session_id,
        args.session_name,
        insecure=args.insecure,
        registration_check_interval=args.registration_check_interval,
    )

    _register_event_handlers(client)

    await _connect_with_retries(client, args.password)

    print("Voicebot client stopped.")
|
|
|
|
|
|
def start_client_with_reload(args: VoicebotArgs):
    """Create the uvicorn client app for *args*, then run the client directly.

    Called when ``--reload`` is specified. Nothing in this function watches
    files or reloads anything itself: the object returned by
    ``create_client_app`` is not kept, and ``main_with_args(args)`` is then
    run to completion with ``asyncio.run``. Actual uvicorn execution is
    expected to be handled by the entrypoint script.
    """
    logger.info("Creating client app for uvicorn...")
    from voicebot.client_app import create_client_app as build_client_app

    # NOTE(review): the returned app is discarded; presumably the call is made
    # for its side effects only — confirm against client_app.
    build_client_app(args)
    logger.info("Client app created. Uvicorn should be started by entrypoint script.")

    # NOTE(review): this direct run is unconditional — no "not using uvicorn"
    # check is performed before it.
    client_coroutine = main_with_args(args)
    asyncio.run(client_coroutine)
|