From b916db243bff2c4cfe5d4cb7b676c5cb9c0f76f6 Mon Sep 17 00:00:00 2001 From: James Ketrenos Date: Wed, 3 Sep 2025 14:33:15 -0700 Subject: [PATCH] Refactored voicebot/main.py --- Dockerfile.voicebot | 1 + server/main.py | 4 +- voicebot/REFACTORING_SUMMARY.md | 82 ++ voicebot/__init__.py | 30 + voicebot/bot_orchestrator.py | 238 +++++ voicebot/bots/whisper.py | 154 +-- voicebot/client_app.py | 127 +++ voicebot/client_main.py | 144 +++ voicebot/main.py | 1697 +------------------------------ voicebot/models.py | 121 +++ voicebot/session_manager.py | 118 +++ voicebot/utils.py | 57 ++ voicebot/webrtc_signaling.py | 894 ++++++++++++++++ 13 files changed, 1904 insertions(+), 1763 deletions(-) create mode 100644 voicebot/REFACTORING_SUMMARY.md create mode 100644 voicebot/__init__.py create mode 100644 voicebot/bot_orchestrator.py create mode 100644 voicebot/client_app.py create mode 100644 voicebot/client_main.py create mode 100644 voicebot/models.py create mode 100644 voicebot/session_manager.py create mode 100644 voicebot/utils.py create mode 100644 voicebot/webrtc_signaling.py diff --git a/Dockerfile.voicebot b/Dockerfile.voicebot index 3b5d82f..9e34ddd 100644 --- a/Dockerfile.voicebot +++ b/Dockerfile.voicebot @@ -24,6 +24,7 @@ RUN apt-get update \ libgl1 \ libglib2.0-0t64 \ git \ + iproute2 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log} diff --git a/server/main.py b/server/main.py index caed2ae..9b60f1a 100644 --- a/server/main.py +++ b/server/main.py @@ -523,7 +523,7 @@ class Session: try: if os.path.exists(cls._save_file + ".tmp"): os.remove(cls._save_file + ".tmp") - except: + except Exception as e: pass @classmethod @@ -1971,7 +1971,7 @@ async def lobby_join( ) try: await websocket.close() - except: + except Exception as e: pass diff --git a/voicebot/REFACTORING_SUMMARY.md b/voicebot/REFACTORING_SUMMARY.md new file mode 100644 index 0000000..0a5f739 --- /dev/null +++ b/voicebot/REFACTORING_SUMMARY.md @@ -0,0 +1,82 @@ +# Voicebot Module Refactoring + +The voicebot/main.py functionality has been broken down into individual Python files for better organization and maintainability: + +## New File Structure + +### Core Modules + +1. **`models.py`** - Data models and configuration + - `VoicebotArgs` - Pydantic model for CLI arguments and configuration + - `VoicebotMode` - Enum for client/provider modes + - `Peer` - WebRTC peer representation + - `JoinRequest` - Request model for joining lobbies + - `MessageData` - Type alias for message payloads + +2. **`webrtc_signaling.py`** - WebRTC signaling client functionality + - `WebRTCSignalingClient` - Main WebRTC signaling client class + - Handles peer connection management, ICE candidates, session descriptions + - Registration status tracking and reconnection logic + - Message processing and event handling + +3. **`session_manager.py`** - Session and lobby management + - `create_or_get_session()` - Session creation/retrieval + - `create_or_get_lobby()` - Lobby creation/retrieval + - HTTP API communication utilities + +4. **`bot_orchestrator.py`** - FastAPI bot orchestration service + - Bot discovery and management + - FastAPI endpoints for bot operations + - Provider registration with main server + - Bot instance lifecycle management + +5. **`client_main.py`** - Main client logic + - `main_with_args()` - Core client functionality + - `start_client_with_reload()` - Development mode with reload + - Event handlers for peer and track management + +6. **`client_app.py`** - Client FastAPI application + - `create_client_app()` - Creates FastAPI app for client mode + - Health check and status endpoints + - Process isolation and locking + +7. **`utils.py`** - Utility functions + - URL conversion utilities (`http_base_url`, `ws_url`) + - SSL context creation + - Network information logging + +8. **`main.py`** - Main orchestration and entry point + - Command-line argument parsing + - Mode selection (client vs provider) + - Entry points for both modes + +### Key Improvements + +- **Separation of Concerns**: Each file handles specific functionality +- **Better Maintainability**: Smaller, focused modules are easier to understand and modify +- **Reduced Coupling**: Dependencies between components are more explicit +- **Type Safety**: Proper type hints and Pydantic models throughout +- **Error Handling**: Centralized error handling and logging + +### Usage + +The refactored code maintains the same CLI interface: + +```bash +# Client mode +python voicebot/main.py --mode client --server-url http://localhost:8000/ai-voicebot + +# Provider mode +python voicebot/main.py --mode provider --host 0.0.0.0 --port 8788 +``` + +### Import Structure + +```python +from voicebot import VoicebotArgs, VoicebotMode, WebRTCSignalingClient +from voicebot.models import Peer, JoinRequest +from voicebot.session_manager import create_or_get_session, create_or_get_lobby +from voicebot.client_main import main_with_args +``` + +The original `main_old.py` contains the monolithic implementation for reference. diff --git a/voicebot/__init__.py b/voicebot/__init__.py new file mode 100644 index 0000000..8e64f70 --- /dev/null +++ b/voicebot/__init__.py @@ -0,0 +1,30 @@ +""" +Voicebot package. + +This package provides WebRTC signaling client functionality and bot orchestration +for AI voicebots. +""" + +import sys +import os + +# Add the parent directory to sys.path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from voicebot.models import VoicebotArgs, VoicebotMode, Peer, JoinRequest +from voicebot.webrtc_signaling import WebRTCSignalingClient +from voicebot.session_manager import create_or_get_session, create_or_get_lobby +from voicebot.client_main import main_with_args +from voicebot.bot_orchestrator import app as bot_orchestrator_app + +__all__ = [ + 'VoicebotArgs', + 'VoicebotMode', + 'Peer', + 'JoinRequest', + 'WebRTCSignalingClient', + 'create_or_get_session', + 'create_or_get_lobby', + 'main_with_args', + 'bot_orchestrator_app', +] diff --git a/voicebot/bot_orchestrator.py b/voicebot/bot_orchestrator.py new file mode 100644 index 0000000..55cfb14 --- /dev/null +++ b/voicebot/bot_orchestrator.py @@ -0,0 +1,238 @@ +""" +Bot orchestrator FastAPI service. + +This module provides the FastAPI service for bot discovery and orchestration. +""" + +import asyncio +import threading +import uuid +import importlib +import pkgutil +import sys +import os +from typing import Dict, Any + +# Add the parent directory to sys.path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import uvicorn +from fastapi import FastAPI, HTTPException + +from logger import logger +from voicebot.models import JoinRequest +from voicebot.webrtc_signaling import WebRTCSignalingClient + + +app = FastAPI(title="voicebot-bot-orchestrator") + +# Lightweight in-memory registry of running bot clients +registry: Dict[str, WebRTCSignalingClient] = {} + + +def discover_bots() -> Dict[str, Dict[str, Any]]: + """Discover bot modules under the voicebot.bots package that expose bot_info. + + This intentionally imports modules under `voicebot.bots` so heavy bot + implementations can remain in that package and be imported lazily. + """ + bots: Dict[str, Dict[str, Any]] = {} + try: + package = importlib.import_module("voicebot.bots") + package_path = package.__path__ + except Exception: + logger.exception("Failed to import voicebot.bots package") + return bots + + for _finder, name, _ispkg in pkgutil.iter_modules(package_path): + try: + mod = importlib.import_module(f"voicebot.bots.{name}") + except Exception: + logger.exception("Failed to import voicebot.bots.%s", name) + continue + info = None + create_tracks = None + if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")): + try: + info = mod.agent_info() + # Note: Keep copy as is to maintain structure + except Exception: + logger.exception("agent_info() failed for %s", name) + if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")): + create_tracks = getattr(mod, "create_agent_tracks") + + if info: + bots[info.get("name", name)] = {"module": name, "info": info, "create_tracks": create_tracks} + return bots + + +@app.get("/bots") +def list_bots() -> Dict[str, Any]: + """List available bots.""" + bots = discover_bots() + return {k: v["info"] for k, v in bots.items()} + + +@app.post("/bots/{bot_name}/join") +async def bot_join(bot_name: str, req: JoinRequest): + """Make a bot join a lobby.""" + bots = discover_bots() + bot = bots.get(bot_name) + if not bot: + raise HTTPException(status_code=404, detail="Bot not found") + + create_tracks = bot.get("create_tracks") + + # Start the WebRTCSignalingClient in a background asyncio task and register it + client = WebRTCSignalingClient( + server_url=req.server_url, + lobby_id=req.lobby_id, + session_id=req.session_id, + session_name=req.nick, + insecure=req.insecure, + create_tracks=create_tracks, + ) + + run_id = str(uuid.uuid4()) + + async def run_client(): + try: + registry[run_id] = client + await client.connect() + except Exception: + logger.exception("Bot client failed for run %s", run_id) + finally: + registry.pop(run_id, None) + + loop = asyncio.get_event_loop() + threading.Thread(target=loop.run_until_complete, args=(run_client(),), daemon=True).start() + + return {"status": "started", "bot": bot_name, "run_id": run_id} + + +@app.post("/bots/runs/{run_id}/stop") +async def stop_run(run_id: str): + """Stop a running bot.""" + client = registry.get(run_id) + if not client: + raise HTTPException(status_code=404, detail="Run not found") + try: + await client.disconnect() + except Exception: + logger.exception("Failed to stop run %s", run_id) + raise HTTPException(status_code=500, detail="Failed to stop run") + registry.pop(run_id, None) + return {"status": "stopped", "run_id": run_id} + + +@app.get("/bots/runs") +def list_runs() -> Dict[str, Any]: + """List running bot instances.""" + return { + "runs": [ + {"run_id": run_id, "session_id": client.session_id, "session_name": client.session_name} + for run_id, client in registry.items() + ] + } + + +def start_bot_api(host: str = "0.0.0.0", port: int = 8788): + """Start the bot orchestration API server""" + uvicorn.run(app, host=host, port=port) + + +async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str: + """Register this voicebot instance as a bot provider with the main server""" + try: + # Import httpx locally to avoid dependency issues + import httpx + + payload = { + "base_url": voicebot_url.rstrip('/'), + "name": "voicebot-provider", + "description": "AI voicebot provider with speech recognition and synthetic media capabilities" + } + + # Prepare SSL context if needed + verify = not insecure + + async with httpx.AsyncClient(verify=verify) as client: + response = await client.post( + f"{server_url}/api/bots/providers/register", + json=payload, + timeout=10.0 + ) + + if response.status_code == 200: + result = response.json() + provider_id = result.get("provider_id") + logger.info(f"Successfully registered with server as provider: {provider_id}") + return provider_id + else: + logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}") + raise RuntimeError(f"Registration failed: {response.status_code}") + + except Exception as e: + logger.error(f"Error registering with server: {e}") + raise + + +def start_bot_provider( + host: str = "0.0.0.0", + port: int = 8788, + server_url: str | None = None, + insecure: bool = False, + reload: bool = False +): + """Start the bot provider API server and optionally register with main server""" + import time + import socket + + # Start the FastAPI server in a background thread + # Add reload functionality for development + if reload: + server_thread = threading.Thread( + target=lambda: uvicorn.run( + app, + host=host, + port=port, + log_level="info", + reload=True, + reload_dirs=["/voicebot", "/shared"] + ), + daemon=True + ) + else: + server_thread = threading.Thread( + target=lambda: uvicorn.run(app, host=host, port=port, log_level="info"), + daemon=True + ) + logger.info(f"Starting bot provider API server on {host}:{port}...") + server_thread.start() + + # If server_url is provided, register with the main server + if server_url: + # Give the server a moment to start + time.sleep(2) + + # Construct the voicebot URL + voicebot_url = f"http://{host}:{port}" + if host == "0.0.0.0": + # Try to get a better hostname + try: + hostname = socket.gethostname() + voicebot_url = f"http://{hostname}:{port}" + except Exception: + voicebot_url = f"http://localhost:{port}" + + try: + asyncio.run(register_with_server(server_url, voicebot_url, insecure)) + except Exception as e: + logger.error(f"Failed to register with server: {e}") + + # Keep the main thread alive + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down bot provider...") diff --git a/voicebot/bots/whisper.py b/voicebot/bots/whisper.py index fd78adb..9de7d38 100644 --- a/voicebot/bots/whisper.py +++ b/voicebot/bots/whisper.py @@ -4,11 +4,11 @@ Lightweight agent descriptor; heavy model loading must be done by a controller when the agent is actually used. """ -from typing import Dict -from typing import Any +from typing import Dict, Any import librosa from logger import logger from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq +from aiortc import MediaStreamTrack AGENT_NAME = "whisper" @@ -19,91 +19,95 @@ def agent_info() -> Dict[str, str]: return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION} -def create_agent_tracks(session_name: str) -> dict: +def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]: """Whisper is not a media source - return no local tracks.""" return {} -model_ids = { - "Distil-Whisper": [ - "distil-whisper/distil-large-v2", - "distil-whisper/distil-medium.en", - "distil-whisper/distil-small.en" - ], - "Whisper": [ - "openai/whisper-large-v3", - "openai/whisper-large-v2", - "openai/whisper-large", - "openai/whisper-medium", - "openai/whisper-small", - "openai/whisper-base", - "openai/whisper-tiny", - "openai/whisper-medium.en", - "openai/whisper-small.en", - "openai/whisper-base.en", - "openai/whisper-tiny.en", - ] -} +def do_work(): + model_ids = { + "Distil-Whisper": [ + "distil-whisper/distil-large-v2", + "distil-whisper/distil-medium.en", + "distil-whisper/distil-small.en" + ], + "Whisper": [ + "openai/whisper-large-v3", + "openai/whisper-large-v2", + "openai/whisper-large", + "openai/whisper-medium", + "openai/whisper-small", + "openai/whisper-base", + "openai/whisper-tiny", + "openai/whisper-medium.en", + "openai/whisper-small.en", + "openai/whisper-base.en", + "openai/whisper-tiny.en", + ] + } -model_type = model_ids["Distil-Whisper"] + model_type = model_ids["Distil-Whisper"] -logger.info(model_type) -model_id = model_type[0] + logger.info(model_type) + model_id = model_type[0] -processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore + processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore -pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore -pt_model.eval() # type: ignore + pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore + pt_model.eval() # type: ignore -def extract_input_features(audio_array: Any, sampling_rate: int) -> Any: - """Extract input features from audio array and sampling rate.""" - input_features = processor( - audio_array, - sampling_rate=sampling_rate, - return_tensors="pt", - ).input_features - return input_features + def extract_input_features(audio_array: Any, sampling_rate: int) -> Any: + """Extract input features from audio array and sampling rate.""" + processor_output = processor( # type: ignore + audio_array, + sampling_rate=sampling_rate, + return_tensors="pt", + ) + input_features: Any = processor_output.input_features # type: ignore + return input_features # type: ignore -def load_audio_file(file_path: str) -> tuple[Any, int]: - """Load audio file from disk and return audio array and sampling rate.""" - # Whisper models expect 16kHz sample rate - target_sample_rate = 16000 - - try: - # Load audio file using librosa and resample to target rate - audio_array, original_sampling_rate = librosa.load(file_path, sr=None) # type: ignore - logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz") # type: ignore - - # Resample if necessary - if original_sampling_rate != target_sample_rate: - audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate) # type: ignore - logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz") - - return audio_array, target_sample_rate # type: ignore - except Exception as e: - logger.error(f"Error loading audio file {file_path}: {e}") - raise + def load_audio_file(file_path: str) -> tuple[Any, int]: + """Load audio file from disk and return audio array and sampling rate.""" + # Whisper models expect 16kHz sample rate + target_sample_rate = 16000 + + try: + # Load audio file using librosa and resample to target rate + audio_array, original_sampling_rate = librosa.load(file_path, sr=None) # type: ignore + logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz") # type: ignore + + # Resample if necessary + if original_sampling_rate != target_sample_rate: + audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate) # type: ignore + logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz") + + return audio_array, target_sample_rate # type: ignore + except Exception as e: + logger.error(f"Error loading audio file {file_path}: {e}") + raise -# Example usage - replace with your audio file path -audio_file_path = "/voicebot/F_0818_15y11m_1.wav" + # Example usage - replace with your audio file path + audio_file_path = "/voicebot/F_0818_15y11m_1.wav" -# Load audio from file instead of dataset -try: - audio_array, sampling_rate = load_audio_file(audio_file_path) - input_features = extract_input_features(audio_array, sampling_rate) - - predicted_ids = pt_model.generate(input_features) # type: ignore - transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) # type: ignore - - print(f"Audio file: {audio_file_path}") - print(f"Transcription: {transcription[0]}") - -except FileNotFoundError: - logger.error(f"Audio file not found: {audio_file_path}") - print("Please update the audio_file_path variable with a valid path to your wav file") -except Exception as e: - logger.error(f"Error processing audio: {e}") - print(f"Error: {e}") \ No newline at end of file + # Load audio from file instead of dataset + try: + audio_array, sampling_rate = load_audio_file(audio_file_path) + input_features = extract_input_features(audio_array, sampling_rate) + + predicted_ids = pt_model.generate(input_features) # type: ignore + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) # type: ignore + + print(f"Audio file: {audio_file_path}") + print(f"Transcription: {transcription[0]}") + + except FileNotFoundError: + logger.error(f"Audio file not found: {audio_file_path}") + print("Please update the audio_file_path variable with a valid path to your wav file") + except Exception as e: + logger.error(f"Error processing audio: {e}") + print(f"Error: {e}") + + \ No newline at end of file diff --git a/voicebot/client_app.py b/voicebot/client_app.py new file mode 100644 index 0000000..da4468b --- /dev/null +++ b/voicebot/client_app.py @@ -0,0 +1,127 @@ +""" +Client FastAPI application for voicebot. + +This module provides the FastAPI application for client mode operations. +""" + +import asyncio +import os +import fcntl +import sys +from contextlib import asynccontextmanager +from typing import Optional + +# Add the parent directory to sys.path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from fastapi import FastAPI +from logger import logger + +# Import shared models +from shared.models import ClientStatusResponse + +from voicebot.models import VoicebotArgs + + +# Global client arguments storage +_client_args: Optional[VoicebotArgs] = None + + +def create_client_app(args: VoicebotArgs) -> FastAPI: + """Create a FastAPI app for client mode that uvicorn can import.""" + + global _client_args + _client_args = args + + # Store the client task globally so we can manage it + client_task = None + lock_file = None + + @asynccontextmanager + async def lifespan(app: FastAPI): + nonlocal client_task, lock_file + # Startup + # Use a file lock to prevent multiple instances from starting + lock_file_path = "/tmp/voicebot_client.lock" + + try: + lock_file = open(lock_file_path, 'w') + # Try to acquire an exclusive lock (non-blocking) + fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + + if _client_args is None: + logger.error("Client args not initialized") + if lock_file: + lock_file.close() + lock_file = None + yield + return + + logger.info("Starting voicebot client...") + # Import here to avoid circular imports + from .client_main import main_with_args + client_task = asyncio.create_task(main_with_args(_client_args)) + + except (IOError, OSError): + # Another process already has the lock + logger.info("Another instance is already running - skipping client startup") + if lock_file: + lock_file.close() + lock_file = None + + yield + + # Shutdown + if client_task and not client_task.done(): + logger.info("Shutting down voicebot client...") + client_task.cancel() + try: + await client_task + except asyncio.CancelledError: + pass + + if lock_file: + try: + fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) + lock_file.close() + os.unlink(lock_file_path) + except Exception: + pass + + # Create the client FastAPI app + app = FastAPI(title="voicebot-client", lifespan=lifespan) + + @app.get("/health") + async def health_check(): # type: ignore + """Simple health check endpoint""" + return {"status": "running", "mode": "client"} + + @app.get("/status", response_model=ClientStatusResponse) + async def client_status() -> ClientStatusResponse: # type: ignore + """Get client status""" + return ClientStatusResponse( + client_running=client_task is not None and not client_task.done(), + session_name=_client_args.session_name if _client_args else 'unknown', + lobby=_client_args.lobby if _client_args else 'unknown', + server_url=_client_args.server_url if _client_args else 'unknown' + ) + + return app + + +def get_app() -> FastAPI: + """Get the appropriate FastAPI app based on VOICEBOT_MODE environment variable.""" + mode = os.getenv('VOICEBOT_MODE', 'provider') + + if mode == 'client': + # For client mode, we need to create the client app with args from environment + args = VoicebotArgs.from_environment() + return create_client_app(args) + else: + # Provider mode - return the main bot orchestration app + from voicebot.bot_orchestrator import app + return app + + +# Create app instance for uvicorn import +uvicorn_app = get_app() diff --git a/voicebot/client_main.py b/voicebot/client_main.py new file mode 100644 index 0000000..7ce8c27 --- /dev/null +++ b/voicebot/client_main.py @@ -0,0 +1,144 @@ +""" +Main client logic for voicebot. + +This module contains the main client functionality and entry points. +""" + +import asyncio +import sys +import os +from logger import logger +from voicebot.bots.synthetic_media import AnimatedVideoTrack + +# Add the parent directory to sys.path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from voicebot.models import VoicebotArgs, Peer +from voicebot.session_manager import create_or_get_session, create_or_get_lobby +from voicebot.webrtc_signaling import WebRTCSignalingClient +from voicebot.utils import ws_url +from aiortc import MediaStreamTrack + + +async def main_with_args(args: VoicebotArgs): + """Main voicebot client logic that accepts arguments object.""" + + # Resolve session id (create if needed) + try: + session_id = create_or_get_session( + args.server_url, args.session_id, insecure=args.insecure + ) + print(f"Using session id: {session_id}") + except Exception as e: + print(f"Failed to get/create session: {e}") + return + + # Create or get lobby id + try: + lobby_id = create_or_get_lobby( + args.server_url, + session_id, + args.lobby, + args.private, + insecure=args.insecure, + ) + print(f"Using lobby id: {lobby_id} (name={args.lobby})") + except Exception as e: + print(f"Failed to create/get lobby: {e}") + return + + # Build websocket base URL (ws:// or wss://) from server_url and pass to client so + # it constructs the final websocket path (/ws/lobby/{lobby}/{session}) itself. + ws_base = ws_url(args.server_url) + + client = WebRTCSignalingClient( + ws_base, lobby_id, session_id, args.session_name, + insecure=args.insecure, + registration_check_interval=args.registration_check_interval + ) + + # Set up event handlers + async def on_peer_added(peer: Peer): + print(f"Peer added: {peer.peer_name}") + + async def on_peer_removed(peer: Peer): + print(f"Peer removed: {peer.peer_name}") + + # Remove any video tracks from this peer from our synthetic video track + if "video" in client.local_tracks: + synthetic_video_track = client.local_tracks["video"] + if isinstance(synthetic_video_track, AnimatedVideoTrack): + # We need to identify and remove tracks from this specific peer + # Since we don't have a direct mapping, we'll need to track this differently + # For now, this is a placeholder - we might need to enhance the peer tracking + logger.info( + f"Peer {peer.peer_name} removed - may need to clean up video tracks" + ) + + async def on_track_received(peer: Peer, track: MediaStreamTrack): + print(f"Received {track.kind} track from {peer.peer_name}") + + # If it's a video track, attach it to our synthetic video track for edge detection + if track.kind == "video" and "video" in client.local_tracks: + synthetic_video_track = client.local_tracks["video"] + if isinstance(synthetic_video_track, AnimatedVideoTrack): + synthetic_video_track.add_remote_video_track(track) + logger.info( + f"Attached remote video track from {peer.peer_name} to synthetic video track" + ) + + client.on_peer_added = on_peer_added + client.on_peer_removed = on_peer_removed + client.on_track_received = on_track_received + + # Retry loop for connection resilience + max_retries = 5 + retry_delay = 5.0 # seconds + retry_count = 0 + + while retry_count < max_retries: + try: + # If a password was provided on the CLI, store it on the client for use when setting name + if args.password: + client.name_password = args.password + + print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})") + await client.connect() + + except KeyboardInterrupt: + print("Shutting down...") + break + except Exception as e: + logger.error(f"Connection failed (attempt {retry_count + 1}): {e}") + retry_count += 1 + + if retry_count < max_retries: + print(f"Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + # Exponential backoff with max delay of 60 seconds + retry_delay = min(retry_delay * 1.5, 60.0) + else: + print("Max retries exceeded. Giving up.") + break + finally: + try: + await client.disconnect() + except Exception as e: + logger.error(f"Error during disconnect: {e}") + + print("Voicebot client stopped.") + + +def start_client_with_reload(args: VoicebotArgs): + """Start the client with auto-reload functionality.""" + + logger.info("Creating client app for uvicorn...") + from voicebot.client_app import create_client_app + create_client_app(args) + + # Note: This function is called when --reload is specified + # The actual uvicorn execution should be handled by the entrypoint script + logger.info("Client app created. Uvicorn should be started by entrypoint script.") + + # Fall back to running client directly if not using uvicorn + asyncio.run(main_with_args(args)) diff --git a/voicebot/main.py b/voicebot/main.py index d49d3ac..6267394 100644 --- a/voicebot/main.py +++ b/voicebot/main.py @@ -1,1244 +1,25 @@ """ -WebRTC Media Agent for Python +Main entry point for voicebot. -This module provides WebRTC signaling server communication and peer connection management. -Synthetic audio/video track creation is handled by the bots.synthetic_media module. +This module provides the main entry point and orchestration for the voicebot application. +It can run in either client mode (connecting to a lobby) or provider mode (serving bots). """ -from __future__ import annotations - -import asyncio -import json -import websockets -from typing import ( - Dict, - Optional, - Callable, - Awaitable, - Protocol, - AsyncIterator, - cast, - Any, -) - -# test -from dataclasses import dataclass, field -from pydantic import ValidationError - -# types.SimpleNamespace removed — not used anymore after parsing candidates via aiortc.sdp import argparse -import urllib.request -import urllib.error -import urllib.parse -import ssl +import asyncio import sys import os -# Import shared models +# Add the parent directory to sys.path to allow absolute imports sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from shared.models import ( - SessionModel, - LobbyCreateResponse, - WebSocketMessageModel, - JoinStatusModel, - UserJoinedModel, - LobbyStateModel, - UpdateNameModel, - AddPeerModel, - RemovePeerModel, - SessionDescriptionModel, - IceCandidateModel, - ICECandidateDictModel, - SessionDescriptionTypedModel, - ClientStatusResponse, -) -from aiortc import ( - RTCPeerConnection, - RTCSessionDescription, - RTCIceCandidate, - MediaStreamTrack, -) -from logger import logger -from voicebot.bots.synthetic_media import create_synthetic_tracks, AnimatedVideoTrack -# Pydantic model for voicebot arguments -from pydantic import BaseModel, Field -from enum import Enum +from voicebot.models import VoicebotArgs, VoicebotMode +from voicebot.client_main import main_with_args, start_client_with_reload +from voicebot.bot_orchestrator import start_bot_provider +from voicebot.client_app import get_app -class VoicebotMode(str, Enum): - """Voicebot operation modes.""" - CLIENT = "client" - PROVIDER = "provider" - -class VoicebotArgs(BaseModel): - """Pydantic model for voicebot CLI arguments and configuration.""" - - # Mode selection - mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)") - - # Provider mode arguments - host: str = Field(default="0.0.0.0", description="Host for provider mode") - port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535) - reload: bool = Field(default=False, description="Enable auto-reload for development") - - # Client mode arguments - server_url: str = Field( - default="http://localhost:8000/ai-voicebot", - description="AI-Voicebot lobby and signaling server base URL (http:// or https://)" - ) - lobby: str = Field(default="default", description="Lobby name to create or join") - session_name: str = Field(default="Python Bot", description="Session (user) display name") - session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse") - password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name") - private: bool = Field(default=False, description="Create the lobby as private") - insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL") - registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0) - - @classmethod - def from_environment(cls) -> 'VoicebotArgs': - """Create VoicebotArgs from environment variables.""" - import os - - mode_str = os.getenv('VOICEBOT_MODE', 'client') - return cls( - mode=VoicebotMode(mode_str), - host=os.getenv('VOICEBOT_HOST', '0.0.0.0'), - port=int(os.getenv('VOICEBOT_PORT', '8788')), - reload=os.getenv('VOICEBOT_RELOAD', 'false').lower() == 'true', - server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'), - lobby=os.getenv('VOICEBOT_LOBBY', 'default'), - session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'), - session_id=os.getenv('VOICEBOT_SESSION_ID', None), - password=os.getenv('VOICEBOT_PASSWORD', None), - private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true', - insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true', - registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0')) - ) - - @classmethod - def from_argparse(cls, args: 'argparse.Namespace') -> 'VoicebotArgs': - """Create VoicebotArgs from argparse Namespace.""" - mode_str = getattr(args, 'mode', 'client') - return cls( - mode=VoicebotMode(mode_str), - host=getattr(args, 'host', '0.0.0.0'), - port=getattr(args, 'port', 8788), - reload=getattr(args, 'reload', False), - server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'), - lobby=getattr(args, 'lobby', 'default'), - session_name=getattr(args, 'session_name', 'Python Bot'), - session_id=getattr(args, 'session_id', None), - password=getattr(args, 'password', None), - private=getattr(args, 'private', False), - insecure=getattr(args, 'insecure', False), - registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0)) - ) - -# Bot orchestration imports -import importlib -import pkgutil -import uuid -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -import threading -import uvicorn - - -# import debug_aioice - -# Generic message payload type -MessageData = dict[str, object] - - -class WebSocketProtocol(Protocol): - def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... - def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... - def __aiter__(self) -> AsyncIterator[str]: ... - - -def _default_attributes() -> Dict[str, object]: - return {} - - -@dataclass -class Peer: - """Represents a WebRTC peer in the session""" - - session_id: str - peer_name: str - # Generic attributes bag. Values can be tracks or simple metadata. - attributes: Dict[str, object] = field(default_factory=_default_attributes) - muted: bool = False - video_on: bool = True - local: bool = False - dead: bool = False - connection: Optional[RTCPeerConnection] = None - - -class WebRTCSignalingClient: - """ - WebRTC signaling client that communicates with the FastAPI signaling server. - Handles peer-to-peer connection establishment and media streaming. - """ - - def __init__( - self, - server_url: str, - lobby_id: str, - session_id: str, - session_name: str, - insecure: bool = False, - create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None, - registration_check_interval: float = 30.0, - ): - self.server_url = server_url - self.lobby_id = lobby_id - self.session_id = session_id - self.session_name = session_name - self.insecure = insecure - - # Optional factory to create local media tracks for this client (bot provided) - self.create_tracks = create_tracks - - # WebSocket client protocol instance (typed as object to avoid Any) - self.websocket: Optional[object] = None - - # Optional password to register or takeover a name - self.name_password: Optional[str] = None - - self.peers: dict[str, Peer] = {} - self.peer_connections: dict[str, RTCPeerConnection] = {} - self.local_tracks: dict[str, MediaStreamTrack] = {} - - # State management - self.is_negotiating: dict[str, bool] = {} - self.making_offer: dict[str, bool] = {} - self.initiated_offer: set[str] = set() - self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {} - - # Registration status tracking - self.is_registered: bool = False - self.last_registration_check: float = 0 - self.registration_check_interval: float = registration_check_interval - self.registration_check_task: Optional[asyncio.Task[None]] = None - - # Event callbacks - self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None - self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None - self.on_track_received: Optional[ - Callable[[Peer, MediaStreamTrack], Awaitable[None]] - ] = None - - async def connect(self): - """Connect to the signaling server""" - ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" - logger.info(f"Connecting to signaling server: {ws_url}") - - # Log network information for debugging - try: - import socket - - hostname = socket.gethostname() - local_ip = socket.gethostbyname(hostname) - logger.info(f"Container hostname: {hostname}, local IP: {local_ip}") - - # Get all network interfaces - import subprocess - - result = subprocess.run( - ["ip", "addr", "show"], capture_output=True, text=True - ) - logger.info(f"Network interfaces:\n{result.stdout}") - except Exception as e: - logger.warning(f"Could not get network info: {e}") - - try: - # If insecure (self-signed certs), create an SSL context for the websocket - ws_ssl = None - if self.insecure: - ws_ssl = ssl.create_default_context() - ws_ssl.check_hostname = False - ws_ssl.verify_mode = ssl.CERT_NONE - - logger.info( - f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}" - ) - self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) - logger.info("Connected to signaling server") - - # Set up local media - await self._setup_local_media() - - # Set name and join lobby - name_payload: MessageData = {"name": self.session_name} - if self.name_password: - name_payload["password"] = self.name_password - logger.info(f"Sending set_name: {name_payload}") - await self._send_message("set_name", name_payload) - logger.info("Sending join message") - await self._send_message("join", {}) - - # Mark as registered after successful join - self.is_registered = True - import time - self.last_registration_check = time.time() - - # Start periodic registration check - self.registration_check_task = asyncio.create_task(self._periodic_registration_check()) - - # Start message handling - logger.info("Starting message handler loop") - try: - await self._handle_messages() - except Exception as e: - logger.error(f"Message handling stopped: {e}") - self.is_registered = False - raise - - except Exception as e: - logger.error(f"Failed to connect to signaling server: {e}", exc_info=True) - raise - - async def _periodic_registration_check(self): - """Periodically check registration status and re-register if needed""" - import time - - while True: - try: - await asyncio.sleep(self.registration_check_interval) - - current_time = time.time() - if current_time - self.last_registration_check < self.registration_check_interval: - continue - - # Check if we're still connected and registered - if not await self._check_registration_status(): - logger.warning("Registration check failed, attempting to re-register") - await self._re_register() - - self.last_registration_check = current_time - - except asyncio.CancelledError: - logger.info("Registration check task cancelled") - break - except Exception as e: - logger.error(f"Error in periodic registration check: {e}", exc_info=True) - # Continue checking even if one iteration fails - continue - - async def _check_registration_status(self) -> bool: - """Check if the voicebot is still registered with the server""" - try: - # First check if websocket is still connected - if not self.websocket: - logger.warning("WebSocket connection lost") - return False - - # Try to send a ping/status check message to verify connection - # We'll use a simple status message to check connectivity - try: - import time - await self._send_message("status_check", {"timestamp": time.time()}) - logger.debug("Registration status check sent") - return True - except Exception as e: - logger.warning(f"Failed to send status check: {e}") - return False - - except Exception as e: - logger.error(f"Error checking registration status: {e}") - return False - - async def _re_register(self): - """Attempt to re-register with the server""" - try: - logger.info("Attempting to re-register with server") - - # Mark as not registered during re-registration attempt - self.is_registered = False - - # Try to reconnect the websocket if it's lost - if not self.websocket: - logger.info("WebSocket lost, attempting to reconnect") - await self._reconnect_websocket() - - # Re-send name and join messages - name_payload: MessageData = {"name": self.session_name} - if self.name_password: - name_payload["password"] = self.name_password - - logger.info("Re-sending set_name message") - await self._send_message("set_name", name_payload) - - logger.info("Re-sending join message") - await self._send_message("join", {}) - - # Mark as registered after successful re-join - self.is_registered = True - import time - self.last_registration_check = time.time() - - logger.info("Successfully re-registered with server") - - except Exception as e: - logger.error(f"Failed to re-register with server: {e}", exc_info=True) - # Will try again on next check interval - - async def _reconnect_websocket(self): - """Reconnect the WebSocket connection""" - try: - # Close existing connection if any - if self.websocket: - try: - ws = cast(WebSocketProtocol, self.websocket) - await ws.close() - except Exception: - pass - self.websocket = None - - # Reconnect - ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" - - # If insecure (self-signed certs), create an SSL context for the websocket - ws_ssl = None - if self.insecure: - ws_ssl = ssl.create_default_context() - ws_ssl.check_hostname = False - ws_ssl.verify_mode = ssl.CERT_NONE - - logger.info(f"Reconnecting to signaling server: {ws_url}") - self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) - logger.info("Successfully reconnected to signaling server") - - except Exception as e: - logger.error(f"Failed to reconnect websocket: {e}", exc_info=True) - raise - - async def disconnect(self): - """Disconnect from signaling server and cleanup""" - # Cancel the registration check task - if self.registration_check_task and not self.registration_check_task.done(): - self.registration_check_task.cancel() - try: - await self.registration_check_task - except asyncio.CancelledError: - pass - self.registration_check_task = None - - if self.websocket: - ws = cast(WebSocketProtocol, self.websocket) - await ws.close() - - # Close all peer connections - for pc in self.peer_connections.values(): - await pc.close() - - # Stop local tracks - for track in self.local_tracks.values(): - track.stop() - - # Reset registration status - self.is_registered = False - - logger.info("Disconnected from signaling server") - - async def _setup_local_media(self): - """Create local media tracks""" - # If a bot provided a create_tracks callable, use it to create tracks. - # Otherwise, use default synthetic tracks. - try: - if self.create_tracks: - tracks = self.create_tracks(self.session_name) - self.local_tracks.update(tracks) - else: - # Default fallback to synthetic tracks - tracks = create_synthetic_tracks(self.session_name) - self.local_tracks.update(tracks) - except Exception: - logger.exception("Failed to create local tracks using bot factory") - - # Add local peer to peers dict - local_peer = Peer( - session_id=self.session_id, - peer_name=self.session_name, - local=True, - attributes={"tracks": self.local_tracks}, - ) - self.peers[self.session_id] = local_peer - - logger.info("Local media tracks created") - - async def _send_message( - self, message_type: str, data: Optional[MessageData] = None - ): - """Send message to signaling server""" - if not self.websocket: - logger.error("No websocket connection") - return - # Build message with explicit type to avoid type narrowing - message: dict[str, object] = {"type": message_type} - if data is not None: - message["data"] = data - - ws = cast(WebSocketProtocol, self.websocket) - try: - logger.debug(f"_send_message: Sending {message_type} with data: {data}") - await ws.send(json.dumps(message)) - logger.debug(f"_send_message: Sent message: {message_type}") - except Exception as e: - logger.error( - f"_send_message: Failed to send {message_type}: {e}", exc_info=True - ) - - async def _handle_messages(self): - """Handle incoming messages from signaling server""" - try: - ws = cast(WebSocketProtocol, self.websocket) - async for message in ws: - logger.debug(f"_handle_messages: Received raw message: {message}") - try: - data = cast(MessageData, json.loads(message)) - except Exception as e: - logger.error( - f"_handle_messages: Failed to parse message: {e}", exc_info=True - ) - continue - await self._process_message(data) - except websockets.exceptions.ConnectionClosed as e: - logger.warning(f"WebSocket connection closed: {e}") - self.is_registered = False - # The periodic registration check will detect this and attempt reconnection - except Exception as e: - logger.error(f"Error handling messages: {e}", exc_info=True) - self.is_registered = False - - async def _process_message(self, message: MessageData): - """Process incoming signaling messages""" - try: - # Validate the base message structure first - validated_message = WebSocketMessageModel.model_validate(message) - msg_type = validated_message.type - data = validated_message.data - except ValidationError as e: - logger.error(f"Invalid message structure: {e}", exc_info=True) - return - - logger.debug( - f"_process_message: Received message type: {msg_type} with data: {data}" - ) - - if msg_type == "addPeer": - try: - validated = AddPeerModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid addPeer payload: {e}", exc_info=True) - return - await self._handle_add_peer(validated) - elif msg_type == "removePeer": - try: - validated = RemovePeerModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid removePeer payload: {e}", exc_info=True) - return - await self._handle_remove_peer(validated) - elif msg_type == "sessionDescription": - try: - validated = SessionDescriptionModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid sessionDescription payload: {e}", exc_info=True) - return - await self._handle_session_description(validated) - elif msg_type == "iceCandidate": - try: - validated = IceCandidateModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid iceCandidate payload: {e}", exc_info=True) - return - await self._handle_ice_candidate(validated) - elif msg_type == "join_status": - try: - validated = JoinStatusModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid join_status payload: {e}", exc_info=True) - return - logger.info(f"Join status: {validated.status} - {validated.message}") - elif msg_type == "user_joined": - try: - validated = UserJoinedModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid user_joined payload: {e}", exc_info=True) - return - logger.info( - f"User joined: {validated.name} (session: {validated.session_id})" - ) - elif msg_type == "lobby_state": - try: - validated = LobbyStateModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid lobby_state payload: {e}", exc_info=True) - return - participants = validated.participants - logger.info(f"Lobby state updated: {len(participants)} participants") - elif msg_type == "update_name": - try: - validated = UpdateNameModel.model_validate(data) - except ValidationError as e: - logger.error(f"Invalid update payload: {e}", exc_info=True) - return - logger.info(f"Received update message: {validated}") - else: - logger.info(f"Unhandled message type: {msg_type} with data: {data}") - - async def _handle_add_peer(self, data: AddPeerModel): - """Handle addPeer message - create new peer connection""" - peer_id = data.peer_id - peer_name = data.peer_name - should_create_offer = data.should_create_offer - - logger.info( - f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})" - ) - logger.debug( - f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}" - ) - - # Check if peer already exists - if peer_id in self.peer_connections: - pc = self.peer_connections[peer_id] - logger.debug( - f"_handle_add_peer: Existing connection state: {pc.connectionState}" - ) - if pc.connectionState in ["new", "connected", "connecting"]: - logger.info(f"Peer connection already exists for {peer_name}") - return - else: - # Clean up stale connection - logger.debug( - f"_handle_add_peer: Closing stale connection for {peer_name}" - ) - await pc.close() - del self.peer_connections[peer_id] - - # Create new peer - peer = Peer(session_id=peer_id, peer_name=peer_name, local=False) - self.peers[peer_id] = peer - - # Create RTCPeerConnection - from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer - - config = RTCConfiguration( - iceServers=[ - RTCIceServer(urls="stun:ketrenos.com:3478"), - RTCIceServer( - urls="turns:ketrenos.com:5349", - username="ketra", - credential="ketran", - ), - # Add Google's public STUN server as fallback - RTCIceServer(urls="stun:stun.l.google.com:19302"), - ], - ) - logger.debug( - f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}" - ) - pc = RTCPeerConnection(configuration=config) - - # Add ICE gathering state change handler (explicit registration to satisfy static analyzers) - def on_ice_gathering_state_change() -> None: - logger.info(f"ICE gathering state: {pc.iceGatheringState}") - # Debug: Check if we have any local candidates when gathering is complete - if pc.iceGatheringState == "complete": - logger.info( - f"ICE gathering complete for {peer_name} - checking if candidates were generated..." - ) - - pc.on("icegatheringstatechange")(on_ice_gathering_state_change) - - # Add connection state change handler (explicit registration to satisfy static analyzers) - def on_connection_state_change() -> None: - logger.info(f"Connection state: {pc.connectionState}") - - pc.on("connectionstatechange")(on_connection_state_change) - - self.peer_connections[peer_id] = pc - peer.connection = pc - - # Set up event handlers - def on_track(track: MediaStreamTrack) -> None: - logger.info(f"Received {track.kind} track from {peer_name}") - logger.info(f"on_track: {track.kind} from {peer_name}, track={track}") - peer.attributes[f"{track.kind}_track"] = track - if self.on_track_received: - asyncio.ensure_future(self.on_track_received(peer, track)) - - pc.on("track")(on_track) - - def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: - logger.info(f"on_ice_candidate: {candidate}") - logger.info( - f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}" - ) - if not candidate: - logger.info( - f"on_ice_candidate: End of candidates signal for {peer_name}" - ) - return - - # Raw SDP fragment for the candidate - raw = getattr(candidate, "candidate", None) - - # Try to infer candidate type from the SDP string (host/srflx/relay/prflx) - def _parse_type(s: Optional[str]) -> str: - if not s: - return "eoc" - import re - - m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s) - return m.group(1) if m else "unknown" - - cand_type = _parse_type(raw) - protocol = getattr(candidate, "protocol", "unknown") - logger.info( - f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}" - ) - - candidate_model = ICECandidateDictModel( - candidate=raw, - sdpMid=getattr(candidate, "sdpMid", None), - sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None), - ) - payload_model = IceCandidateModel( - peer_id=peer_id, peer_name=peer_name, candidate=candidate_model - ) - logger.info( - f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}" - ) - asyncio.ensure_future( - self._send_message("relayICECandidate", payload_model.model_dump()) - ) - - pc.on("icecandidate")(on_ice_candidate) - - # Add local tracks - for track in self.local_tracks.values(): - logger.debug( - f"_handle_add_peer: Adding local track {track.kind} to {peer_name}" - ) - pc.addTrack(track) - - # Create offer if needed - if should_create_offer: - self.initiated_offer.add(peer_id) - self.making_offer[peer_id] = True - self.is_negotiating[peer_id] = True - - try: - logger.debug(f"_handle_add_peer: Creating offer for {peer_name}") - offer = await pc.createOffer() - logger.debug( - f"_handle_add_peer: Offer created for {peer_name}: {offer}" - ) - await pc.setLocalDescription(offer) - logger.debug(f"_handle_add_peer: Local description set for {peer_name}") - - # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) - # Use Method 2: Complete SDP approach to extract ICE candidates - logger.debug( - f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}" - ) - while pc.iceGatheringState != "complete": - await asyncio.sleep(0.1) - - logger.debug( - f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}" - ) - - # Parse ICE candidates from the local SDP - sdp_lines = pc.localDescription.sdp.split("\n") - candidate_lines = [ - line for line in sdp_lines if line.startswith("a=candidate:") - ] - - # Track which media section we're in to determine sdpMid and sdpMLineIndex - current_media_index = -1 - current_mid = None - - for line in sdp_lines: - if line.startswith("m="): # Media section - current_media_index += 1 - elif line.startswith("a=mid:"): # Media ID - current_mid = line.split(":", 1)[1].strip() - elif line.startswith("a=candidate:"): - candidate_sdp = line[2:] # Remove 'a=' prefix - - candidate_model = ICECandidateDictModel( - candidate=candidate_sdp, - sdpMid=current_mid, - sdpMLineIndex=current_media_index, - ) - payload_candidate = IceCandidateModel( - peer_id=peer_id, - peer_name=peer_name, - candidate=candidate_model, - ) - - logger.debug( - f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..." - ) - await self._send_message( - "relayICECandidate", payload_candidate.model_dump() - ) - - # Send end-of-candidates signal (empty candidate) - end_candidate_model = ICECandidateDictModel( - candidate="", - sdpMid=None, - sdpMLineIndex=None, - ) - payload_end = IceCandidateModel( - peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model - ) - logger.debug( - f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}" - ) - await self._send_message("relayICECandidate", payload_end.model_dump()) - - logger.debug( - f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}" - ) - - session_desc_typed = SessionDescriptionTypedModel( - type=offer.type, sdp=offer.sdp - ) - session_desc_model = SessionDescriptionModel( - peer_id=peer_id, - peer_name=peer_name, - session_description=session_desc_typed, - ) - await self._send_message( - "relaySessionDescription", - session_desc_model.model_dump(), - ) - - logger.info(f"Offer sent to {peer_name}") - except Exception as e: - logger.error( - f"Failed to create/send offer to {peer_name}: {e}", exc_info=True - ) - finally: - self.making_offer[peer_id] = False - - if self.on_peer_added: - await self.on_peer_added(peer) - - async def _handle_remove_peer(self, data: RemovePeerModel): - """Handle removePeer message""" - peer_id = data.peer_id - peer_name = data.peer_name - - logger.info(f"Removing peer: {peer_name}") - - # Close peer connection - if peer_id in self.peer_connections: - pc = self.peer_connections[peer_id] - await pc.close() - del self.peer_connections[peer_id] - - # Clean up state - self.is_negotiating.pop(peer_id, None) - self.making_offer.pop(peer_id, None) - self.initiated_offer.discard(peer_id) - self.pending_ice_candidates.pop(peer_id, None) - - # Remove peer - peer = self.peers.pop(peer_id, None) - if peer and self.on_peer_removed: - await self.on_peer_removed(peer) - - async def _handle_session_description(self, data: SessionDescriptionModel): - """Handle sessionDescription message""" - peer_id = data.peer_id - peer_name = data.peer_name - session_description = data.session_description.model_dump() - - logger.info(f"Received {session_description['type']} from {peer_name}") - - pc = self.peer_connections.get(peer_id) - if not pc: - logger.error(f"No peer connection for {peer_name}") - return - - desc = RTCSessionDescription( - sdp=session_description["sdp"], type=session_description["type"] - ) - - # Handle offer collision (polite peer pattern) - making_offer = self.making_offer.get(peer_id, False) - offer_collision = desc.type == "offer" and ( - making_offer or pc.signalingState != "stable" - ) - we_initiated = peer_id in self.initiated_offer - ignore_offer = we_initiated and offer_collision - - if ignore_offer: - logger.info(f"Ignoring offer from {peer_name} due to collision") - return - - try: - await pc.setRemoteDescription(desc) - self.is_negotiating[peer_id] = False - logger.info(f"Remote description set for {peer_name}") - - # Process queued ICE candidates - pending_candidates = self.pending_ice_candidates.pop(peer_id, []) - from aiortc.sdp import candidate_from_sdp - - for candidate_data in pending_candidates: - # candidate_data is an ICECandidateDictModel Pydantic model - cand = candidate_data.candidate - # handle end-of-candidates marker - if not cand: - await pc.addIceCandidate(None) - logger.info(f"Added queued end-of-candidates for {peer_name}") - continue - - # cand may be the full "candidate:..." string or the inner SDP part - if cand and cand.startswith("candidate:"): - sdp_part = cand.split(":", 1)[1] - else: - sdp_part = cand - - try: - rtc_candidate = candidate_from_sdp(sdp_part) - rtc_candidate.sdpMid = candidate_data.sdpMid - rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex - await pc.addIceCandidate(rtc_candidate) - logger.info(f"Added queued ICE candidate for {peer_name}") - except Exception as e: - logger.error( - f"Failed to add queued ICE candidate for {peer_name}: {e}" - ) - - except Exception as e: - logger.error(f"Failed to set remote description for {peer_name}: {e}") - return - - # Create answer if this was an offer - if session_description["type"] == "offer": - try: - answer = await pc.createAnswer() - await pc.setLocalDescription(answer) - - # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) - # Use Method 2: Complete SDP approach to extract ICE candidates - logger.debug( - f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)" - ) - while pc.iceGatheringState != "complete": - await asyncio.sleep(0.1) - - logger.debug( - f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)" - ) - - # Parse ICE candidates from the local SDP - sdp_lines = pc.localDescription.sdp.split("\n") - candidate_lines = [ - line for line in sdp_lines if line.startswith("a=candidate:") - ] - - # Track which media section we're in to determine sdpMid and sdpMLineIndex - current_media_index = -1 - current_mid = None - - for line in sdp_lines: - if line.startswith("m="): # Media section - current_media_index += 1 - elif line.startswith("a=mid:"): # Media ID - current_mid = line.split(":", 1)[1].strip() - elif line.startswith("a=candidate:"): - candidate_sdp = line[2:] # Remove 'a=' prefix - - candidate_model = ICECandidateDictModel( - candidate=candidate_sdp, - sdpMid=current_mid, - sdpMLineIndex=current_media_index, - ) - payload_candidate = IceCandidateModel( - peer_id=peer_id, - peer_name=peer_name, - candidate=candidate_model, - ) - - logger.debug( - f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}..." - ) - await self._send_message( - "relayICECandidate", payload_candidate.model_dump() - ) - - # Send end-of-candidates signal (empty candidate) - end_candidate_model = ICECandidateDictModel( - candidate="", - sdpMid=None, - sdpMLineIndex=None, - ) - payload_end = IceCandidateModel( - peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model - ) - logger.debug( - f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)" - ) - await self._send_message("relayICECandidate", payload_end.model_dump()) - - logger.debug( - f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)" - ) - - session_desc_typed = SessionDescriptionTypedModel( - type=answer.type, sdp=answer.sdp - ) - session_desc_model = SessionDescriptionModel( - peer_id=peer_id, - peer_name=peer_name, - session_description=session_desc_typed, - ) - await self._send_message( - "relaySessionDescription", - session_desc_model.model_dump(), - ) - - logger.info(f"Answer sent to {peer_name}") - except Exception as e: - logger.error(f"Failed to create/send answer to {peer_name}: {e}") - - async def _handle_ice_candidate(self, data: IceCandidateModel): - """Handle iceCandidate message""" - peer_id = data.peer_id - peer_name = data.peer_name - candidate_data = data.candidate - - logger.info(f"Received ICE candidate from {peer_name}") - - pc = self.peer_connections.get(peer_id) - if not pc: - logger.error(f"No peer connection for {peer_name}") - return - - # Queue candidate if remote description not set - if not pc.remoteDescription: - logger.info( - f"Remote description not set, queuing ICE candidate for {peer_name}" - ) - if peer_id not in self.pending_ice_candidates: - self.pending_ice_candidates[peer_id] = [] - # candidate_data is an ICECandidateDictModel Pydantic model - self.pending_ice_candidates[peer_id].append(candidate_data) - return - - try: - from aiortc.sdp import candidate_from_sdp - - cand = candidate_data.candidate - if not cand: - # end-of-candidates - await pc.addIceCandidate(None) - logger.info(f"End-of-candidates added for {peer_name}") - return - - if cand and cand.startswith("candidate:"): - sdp_part = cand.split(":", 1)[1] - else: - sdp_part = cand - - # Detect type for logging - try: - import re - - m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part) - cand_type = m.group(1) if m else "unknown" - except Exception: - cand_type = "unknown" - - try: - rtc_candidate = candidate_from_sdp(sdp_part) - rtc_candidate.sdpMid = candidate_data.sdpMid - rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex - - # aiortc expects an object with attributes (RTCIceCandidate) - await pc.addIceCandidate(rtc_candidate) - logger.info(f"ICE candidate added for {peer_name}: type={cand_type}") - except Exception as e: - logger.error( - f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'", - exc_info=True, - ) - except Exception as e: - logger.error( - f"Unexpected error handling ICE candidate for {peer_name}: {e}", - exc_info=True, - ) - - async def _handle_ice_connection_failure(self, peer_id: str, peer_name: str): - """Handle ICE connection failure by logging details""" - logger.info(f"ICE connection failure detected for {peer_name}") - - pc = self.peer_connections.get(peer_id) - if not pc: - logger.error( - f"No peer connection found for {peer_name} during ICE failure recovery" - ) - return - - logger.error( - f"ICE connection failed for {peer_name}. Connection state: {pc.connectionState}, ICE state: {pc.iceConnectionState}" - ) - # In a real implementation, you might want to notify the user or attempt reconnection - - async def _schedule_ice_timeout(self, peer_id: str, peer_name: str): - """Schedule a timeout for ICE connection checking""" - await asyncio.sleep(30) # Wait 30 seconds - - pc = self.peer_connections.get(peer_id) - if not pc: - return - - if pc.iceConnectionState == "checking": - logger.warning( - f"ICE connection timeout for {peer_name} - still in checking state after 30 seconds" - ) - logger.warning( - f"Final connection state: {pc.connectionState}, ICE state: {pc.iceConnectionState}" - ) - logger.warning( - "This might be due to network connectivity issues between the browser and Docker container" - ) - logger.warning( - "Consider checking: 1) Port forwarding 2) TURN server config 3) Docker network mode" - ) - elif pc.iceConnectionState in ["failed", "closed"]: - logger.info( - f"ICE connection for {peer_name} resolved to: {pc.iceConnectionState}" - ) - else: - logger.info( - f"ICE connection for {peer_name} established: {pc.iceConnectionState}" - ) - - -# Example usage -def _http_base_url(server_url: str) -> str: - # Convert ws:// or wss:// to http(s) and ensure no trailing slash - if server_url.startswith("ws://"): - return "http://" + server_url[len("ws://") :].rstrip("/") - if server_url.startswith("wss://"): - return "https://" + server_url[len("wss://") :].rstrip("/") - return server_url.rstrip("/") - - -def _ws_url(server_url: str) -> str: - # Convert http(s) to ws(s) if needed - if server_url.startswith("http://"): - return "ws://" + server_url[len("http://") :].rstrip("/") - if server_url.startswith("https://"): - return "wss://" + server_url[len("https://") :].rstrip("/") - return server_url.rstrip("/") - - -def create_or_get_session( - server_url: str, session_id: str | None = None, insecure: bool = False -) -> str: - """Call GET /api/session to obtain a session_id (unless one was provided). - - Uses urllib so no extra runtime deps are required. - """ - if session_id: - return session_id - - http_base = _http_base_url(server_url) - url = f"{http_base}/api/session" - req = urllib.request.Request(url, method="GET") - # Prepare SSL context if requested (accept self-signed certs) - ssl_ctx = None - if insecure: - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - - try: - with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: - body = resp.read() - data = json.loads(body) - - # Validate response shape using Pydantic - try: - session = SessionModel.model_validate(data) - except ValidationError as e: - raise RuntimeError(f"Invalid session response from {url}: {e}") - - sid = session.id - if not sid: - raise RuntimeError(f"No session id returned from {url}: {data}") - return sid - except urllib.error.HTTPError as e: - raise RuntimeError(f"HTTP error getting session: {e}") - except Exception as e: - raise RuntimeError(f"Error getting session: {e}") - - -def create_or_get_lobby( - server_url: str, - session_id: str, - lobby_name: str, - private: bool = False, - insecure: bool = False, -) -> str: - """Call POST /api/lobby/{session_id} to create or lookup a lobby by name. - - Returns the lobby id. - """ - http_base = _http_base_url(server_url) - url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}" - payload = json.dumps( - { - "type": "lobby_create", - "data": {"name": lobby_name, "private": private}, - } - ).encode("utf-8") - req = urllib.request.Request( - url, data=payload, headers={"Content-Type": "application/json"}, method="POST" - ) - # Prepare SSL context if requested (accept self-signed certs) - ssl_ctx = None - if insecure: - ssl_ctx = ssl.create_default_context() - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - - try: - with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: - body = resp.read() - data = json.loads(body) - # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}} - try: - lobby_resp = LobbyCreateResponse.model_validate(data) - except ValidationError as e: - raise RuntimeError(f"Invalid lobby response from {url}: {e}") - - lobby_id = lobby_resp.data.id - if not lobby_id: - raise RuntimeError(f"No lobby id returned from {url}: {data}") - return lobby_id - except urllib.error.HTTPError as e: - # Try to include response body for.infoging - try: - body = e.read() - msg = body.decode("utf-8", errors="ignore") - except Exception: - msg = str(e) - raise RuntimeError(f"HTTP error creating lobby: {msg}") - except Exception as e: - raise RuntimeError(f"Error creating lobby: {e}") +# Create app instance for uvicorn import +uvicorn_app = get_app() async def main(): @@ -1278,465 +59,10 @@ async def main(): await main_with_args(voicebot_args) -async def main_with_args(args: VoicebotArgs): - """Main voicebot client logic that accepts arguments object.""" - - # Resolve session id (create if needed) - try: - session_id = create_or_get_session( - args.server_url, args.session_id, insecure=args.insecure - ) - print(f"Using session id: {session_id}") - except Exception as e: - print(f"Failed to get/create session: {e}") - return - - # Create or get lobby id - try: - lobby_id = create_or_get_lobby( - args.server_url, - session_id, - args.lobby, - args.private, - insecure=args.insecure, - ) - print(f"Using lobby id: {lobby_id} (name={args.lobby})") - except Exception as e: - print(f"Failed to create/get lobby: {e}") - return - - # Build websocket base URL (ws:// or wss://) from server_url and pass to client so - # it constructs the final websocket path (/ws/lobby/{lobby}/{session}) itself. - ws_base = _ws_url(args.server_url) - - client = WebRTCSignalingClient( - ws_base, lobby_id, session_id, args.session_name, - insecure=args.insecure, - registration_check_interval=args.registration_check_interval - ) - - # Set up event handlers - async def on_peer_added(peer: Peer): - print(f"Peer added: {peer.peer_name}") - - async def on_peer_removed(peer: Peer): - print(f"Peer removed: {peer.peer_name}") - - # Remove any video tracks from this peer from our synthetic video track - if "video" in client.local_tracks: - synthetic_video_track = client.local_tracks["video"] - if isinstance(synthetic_video_track, AnimatedVideoTrack): - # We need to identify and remove tracks from this specific peer - # Since we don't have a direct mapping, we'll need to track this differently - # For now, this is a placeholder - we might need to enhance the peer tracking - logger.info( - f"Peer {peer.peer_name} removed - may need to clean up video tracks" - ) - - async def on_track_received(peer: Peer, track: MediaStreamTrack): - print(f"Received {track.kind} track from {peer.peer_name}") - - # If it's a video track, attach it to our synthetic video track for edge detection - if track.kind == "video" and "video" in client.local_tracks: - synthetic_video_track = client.local_tracks["video"] - if isinstance(synthetic_video_track, AnimatedVideoTrack): - synthetic_video_track.add_remote_video_track(track) - logger.info( - f"Attached remote video track from {peer.peer_name} to synthetic video track" - ) - - client.on_peer_added = on_peer_added - client.on_peer_removed = on_peer_removed - client.on_track_received = on_track_received - - # Retry loop for connection resilience - max_retries = 5 - retry_delay = 5.0 # seconds - retry_count = 0 - - while retry_count < max_retries: - try: - # If a password was provided on the CLI, store it on the client for use when setting name - if args.password: - client.name_password = args.password - - print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})") - await client.connect() - - except KeyboardInterrupt: - print("Shutting down...") - break - except Exception as e: - logger.error(f"Connection failed (attempt {retry_count + 1}): {e}") - retry_count += 1 - - if retry_count < max_retries: - print(f"Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) - # Exponential backoff with max delay of 60 seconds - retry_delay = min(retry_delay * 1.5, 60.0) - else: - print("Max retries exceeded. Giving up.") - break - finally: - try: - await client.disconnect() - except Exception as e: - logger.error(f"Error during disconnect: {e}") - - print("Voicebot client stopped.") - - -# --- FastAPI service for bot discovery and orchestration ------------------- - -app = FastAPI(title="voicebot-bot-orchestrator") - -# Global client app instance for uvicorn import -client_app = None - -# Global client arguments storage -_client_args: Optional[VoicebotArgs] = None - -def create_client_app(args: VoicebotArgs): - """Create a FastAPI app for client mode that uvicorn can import.""" - import asyncio - from contextlib import asynccontextmanager - import os - import fcntl - - global _client_args - _client_args = args - - # Store the client task globally so we can manage it - client_task = None - lock_file = None - - @asynccontextmanager - async def lifespan(app: FastAPI): - nonlocal client_task, lock_file - # Startup - # Use a file lock to prevent multiple instances from starting - lock_file_path = "/tmp/voicebot_client.lock" - - try: - lock_file = open(lock_file_path, 'w') - # Try to acquire an exclusive lock (non-blocking) - fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) - - if _client_args is None: - logger.error("Client args not initialized") - if lock_file: - lock_file.close() - lock_file = None - yield - return - - logger.info("Starting voicebot client...") - client_task = asyncio.create_task(main_with_args(_client_args)) - - except (IOError, OSError): - # Another process already has the lock - logger.info("Another instance is already running - skipping client startup") - if lock_file: - lock_file.close() - lock_file = None - - yield - - # Shutdown - if client_task and not client_task.done(): - logger.info("Shutting down voicebot client...") - client_task.cancel() - try: - await client_task - except asyncio.CancelledError: - pass - - if lock_file: - try: - fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN) - lock_file.close() - os.unlink(lock_file_path) - except Exception: - pass - - # Create the client FastAPI app - app = FastAPI(title="voicebot-client", lifespan=lifespan) - - @app.get("/health") - async def health_check():# pyright: ignore - """Simple health check endpoint""" - return {"status": "running", "mode": "client"} - - @app.get("/status", response_model=ClientStatusResponse) - async def client_status() -> ClientStatusResponse:# pyright: ignore - """Get client status""" - return ClientStatusResponse( - client_running=client_task is not None and not client_task.done(), - session_name=_client_args.session_name if _client_args else 'unknown', - lobby=_client_args.lobby if _client_args else 'unknown', - server_url=_client_args.server_url if _client_args else 'unknown' - ) - - return app - -# Function to get the appropriate app based on environment variable -def get_app(): - """Get the appropriate FastAPI app based on VOICEBOT_MODE environment variable.""" - import os - mode = os.getenv('VOICEBOT_MODE', 'provider') - - if mode == 'client': - # For client mode, we need to create the client app with args from environment - args = VoicebotArgs.from_environment() - return create_client_app(args) - else: - # Provider mode - return the main bot orchestration app - return app - -# Create app instance for uvicorn import -uvicorn_app = get_app() - - -class JoinRequest(BaseModel): - lobby_id: str - session_id: str - nick: str - server_url: str - insecure: bool = False - - -def discover_bots() -> Dict[str, Dict[str, Any]]: - """Discover bot modules under the voicebot.bots package that expose bot_info. - - This intentionally imports modules under `voicebot.bots` so heavy bot - implementations can remain in that package and be imported lazily. - """ - bots: Dict[str, Dict[str, Any]] = {} - try: - package = importlib.import_module("voicebot.bots") - package_path = package.__path__ - except Exception: - logger.exception("Failed to import voicebot.bots package") - return bots - - for _finder, name, _ispkg in pkgutil.iter_modules(package_path): - try: - mod = importlib.import_module(f"voicebot.bots.{name}") - except Exception: - logger.exception("Failed to import voicebot.bots.%s", name) - continue - info = None - create_tracks = None - if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")): - try: - info = mod.agent_info() - # Note: Keep copy as is to maintain structure - except Exception: - logger.exception("agent_info() failed for %s", name) - if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")): - create_tracks = getattr(mod, "create_agent_tracks") - - if info: - bots[info.get("name", name)] = {"module": name, "info": info, "create_tracks": create_tracks} - return bots - - -@app.get("/bots") -def list_bots() -> Dict[str, Any]: - bots = discover_bots() - return {k: v["info"] for k, v in bots.items()} - - -@app.post("/bots/{bot_name}/join") -async def bot_join(bot_name: str, req: JoinRequest): - bots = discover_bots() - bot = bots.get(bot_name) - if not bot: - raise HTTPException(status_code=404, detail="Bot not found") - - create_tracks = bot.get("create_tracks") - - # Start the WebRTCSignalingClient in a background asyncio task and register it - client = WebRTCSignalingClient( - server_url=req.server_url, - lobby_id=req.lobby_id, - session_id=req.session_id, - session_name=req.nick, - insecure=req.insecure, - create_tracks=create_tracks, - ) - - run_id = str(uuid.uuid4()) - - async def run_client(): - try: - registry[run_id] = client - await client.connect() - except Exception: - logger.exception("Bot client failed for run %s", run_id) - finally: - registry.pop(run_id, None) - - loop = asyncio.get_event_loop() - threading.Thread(target=loop.run_until_complete, args=(run_client(),), daemon=True).start() - - return {"status": "started", "bot": bot_name, "run_id": run_id} - - -# Lightweight in-memory registry of running bot clients -registry: Dict[str, WebRTCSignalingClient] = {} - - -@app.post("/bots/runs/{run_id}/stop") -async def stop_run(run_id: str): - client = registry.get(run_id) - if not client: - raise HTTPException(status_code=404, detail="Run not found") - try: - await client.disconnect() - except Exception: - logger.exception("Failed to stop run %s", run_id) - raise HTTPException(status_code=500, detail="Failed to stop run") - registry.pop(run_id, None) - return {"status": "stopped", "run_id": run_id} - - -@app.get("/bots/runs") -def list_runs() -> Dict[str, Any]: - return { - "runs": [ - {"run_id": run_id, "session_id": client.session_id, "session_name": client.session_name} - for run_id, client in registry.items() - ] - } - - -def start_bot_api(host: str = "0.0.0.0", port: int = 8788): - """Start the bot orchestration API server""" - uvicorn.run(app, host=host, port=port) - - -async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str: - """Register this voicebot instance as a bot provider with the main server""" - try: - # Import httpx locally to avoid dependency issues - import httpx - - payload = { - "base_url": voicebot_url.rstrip('/'), - "name": "voicebot-provider", - "description": "AI voicebot provider with speech recognition and synthetic media capabilities" - } - - # Prepare SSL context if needed - verify = not insecure - - async with httpx.AsyncClient(verify=verify) as client: - response = await client.post( - f"{server_url}/api/bots/providers/register", - json=payload, - timeout=10.0 - ) - - if response.status_code == 200: - result = response.json() - provider_id = result.get("provider_id") - logger.info(f"Successfully registered with server as provider: {provider_id}") - return provider_id - else: - logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}") - raise RuntimeError(f"Registration failed: {response.status_code}") - - except Exception as e: - logger.error(f"Error registering with server: {e}") - raise - - -def start_bot_provider( - host: str = "0.0.0.0", - port: int = 8788, - server_url: str | None = None, - insecure: bool = False, - reload: bool = False -): - """Start the bot provider API server and optionally register with main server""" - import time - - # Start the FastAPI server in a background thread - import threading - - # Add reload functionality for development - if reload: - server_thread = threading.Thread( - target=lambda: uvicorn.run( - app, - host=host, - port=port, - log_level="info", - reload=True, - reload_dirs=["/voicebot", "/shared"] - ), - daemon=True - ) - else: - server_thread = threading.Thread( - target=lambda: uvicorn.run(app, host=host, port=port, log_level="info"), - daemon=True - ) - - server_thread.start() - - # If server_url is provided, register with the main server - if server_url: - # Give the server a moment to start - time.sleep(2) - - # Construct the voicebot URL - voicebot_url = f"http://{host}:{port}" - if host == "0.0.0.0": - # Try to get a better hostname - import socket - try: - hostname = socket.gethostname() - voicebot_url = f"http://{hostname}:{port}" - except Exception: - voicebot_url = f"http://localhost:{port}" - - try: - asyncio.run(register_with_server(server_url, voicebot_url, insecure)) - except Exception as e: - logger.error(f"Failed to register with server: {e}") - - # Keep the main thread alive - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Shutting down bot provider...") - - -def start_client_with_reload(args: VoicebotArgs): - """Start the client with auto-reload functionality.""" - global client_app - - logger.info("Creating client app for uvicorn...") - client_app = create_client_app(args) - - # Note: This function is called when --reload is specified - # The actual uvicorn execution should be handled by the entrypoint script - logger.info("Client app created. Uvicorn should be started by entrypoint script.") - - # Fall back to running client directly if not using uvicorn - asyncio.run(main_with_args(args)) - - if __name__ == "__main__": # Install required packages: # pip install aiortc websockets opencv-python numpy - import argparse - # Check if we're being run as a bot provider or as a client parser = argparse.ArgumentParser(description="AI Voicebot - WebRTC client or bot provider") parser.add_argument("--mode", choices=["client", "provider"], default="client", @@ -1780,4 +106,3 @@ if __name__ == "__main__": start_client_with_reload(voicebot_args) else: asyncio.run(main_with_args(voicebot_args)) - diff --git a/voicebot/models.py b/voicebot/models.py new file mode 100644 index 0000000..9caea93 --- /dev/null +++ b/voicebot/models.py @@ -0,0 +1,121 @@ +""" +Data models and configuration for voicebot. + +This module provides Pydantic models for configuration and data structures +used throughout the voicebot application. +""" + +from __future__ import annotations + +import argparse +from enum import Enum +from typing import Dict, Optional, TYPE_CHECKING +from dataclasses import dataclass, field +from pydantic import BaseModel, Field + +if TYPE_CHECKING: + from aiortc import RTCPeerConnection + + +class VoicebotMode(str, Enum): + """Voicebot operation modes.""" + CLIENT = "client" + PROVIDER = "provider" + + +class VoicebotArgs(BaseModel): + """Pydantic model for voicebot CLI arguments and configuration.""" + + # Mode selection + mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)") + + # Provider mode arguments + host: str = Field(default="0.0.0.0", description="Host for provider mode") + port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535) + reload: bool = Field(default=False, description="Enable auto-reload for development") + + # Client mode arguments + server_url: str = Field( + default="http://localhost:8000/ai-voicebot", + description="AI-Voicebot lobby and signaling server base URL (http:// or https://)" + ) + lobby: str = Field(default="default", description="Lobby name to create or join") + session_name: str = Field(default="Python Bot", description="Session (user) display name") + session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse") + password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name") + private: bool = Field(default=False, description="Create the lobby as private") + insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL") + registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0) + + @classmethod + def from_environment(cls) -> 'VoicebotArgs': + """Create VoicebotArgs from environment variables.""" + import os + + mode_str = os.getenv('VOICEBOT_MODE', 'client') + return cls( + mode=VoicebotMode(mode_str), + host=os.getenv('VOICEBOT_HOST', '0.0.0.0'), + port=int(os.getenv('VOICEBOT_PORT', '8788')), + reload=os.getenv('VOICEBOT_RELOAD', 'false').lower() == 'true', + server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'), + lobby=os.getenv('VOICEBOT_LOBBY', 'default'), + session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'), + session_id=os.getenv('VOICEBOT_SESSION_ID', None), + password=os.getenv('VOICEBOT_PASSWORD', None), + private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true', + insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true', + registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0')) + ) + + @classmethod + def from_argparse(cls, args: argparse.Namespace) -> 'VoicebotArgs': + """Create VoicebotArgs from argparse Namespace.""" + mode_str = getattr(args, 'mode', 'client') + return cls( + mode=VoicebotMode(mode_str), + host=getattr(args, 'host', '0.0.0.0'), + port=getattr(args, 'port', 8788), + reload=getattr(args, 'reload', False), + server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'), + lobby=getattr(args, 'lobby', 'default'), + session_name=getattr(args, 'session_name', 'Python Bot'), + session_id=getattr(args, 'session_id', None), + password=getattr(args, 'password', None), + private=getattr(args, 'private', False), + insecure=getattr(args, 'insecure', False), + registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0)) + ) + + +class JoinRequest(BaseModel): + """Request model for joining a lobby.""" + lobby_id: str + session_id: str + nick: str + server_url: str + insecure: bool = False + + +def _default_attributes() -> Dict[str, object]: + """Default factory for peer attributes.""" + return {} + + +@dataclass +class Peer: + """Represents a WebRTC peer in the session""" + + session_id: str + peer_name: str + # Generic attributes bag. Values can be tracks or simple metadata. + attributes: Dict[str, object] = field(default_factory=_default_attributes) + muted: bool = False + video_on: bool = True + local: bool = False + dead: bool = False + connection: Optional['RTCPeerConnection'] = None + + +# Generic message payload type +MessageData = dict[str, object] diff --git a/voicebot/session_manager.py b/voicebot/session_manager.py new file mode 100644 index 0000000..a4251ce --- /dev/null +++ b/voicebot/session_manager.py @@ -0,0 +1,118 @@ +""" +Session and lobby management for voicebot. + +This module handles session creation and lobby management functionality. +""" + +import json +import ssl +import urllib.request +import urllib.error +import urllib.parse +import sys +import os +from pydantic import ValidationError + +# Add the parent directory to sys.path to allow absolute imports +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Import shared models +from shared.models import SessionModel, LobbyCreateResponse + +from voicebot.utils import http_base_url + + +def create_or_get_session( + server_url: str, session_id: str | None = None, insecure: bool = False +) -> str: + """Call GET /api/session to obtain a session_id (unless one was provided). + + Uses urllib so no extra runtime deps are required. + """ + if session_id: + return session_id + + http_base = http_base_url(server_url) + url = f"{http_base}/api/session" + req = urllib.request.Request(url, method="GET") + # Prepare SSL context if requested (accept self-signed certs) + ssl_ctx = None + if insecure: + ssl_ctx = ssl.create_default_context() + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + + try: + with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: + body = resp.read() + data = json.loads(body) + + # Validate response shape using Pydantic + try: + session = SessionModel.model_validate(data) + except ValidationError as e: + raise RuntimeError(f"Invalid session response from {url}: {e}") + + sid = session.id + if not sid: + raise RuntimeError(f"No session id returned from {url}: {data}") + return sid + except urllib.error.HTTPError as e: + raise RuntimeError(f"HTTP error getting session: {e}") + except Exception as e: + raise RuntimeError(f"Error getting session: {e}") + + +def create_or_get_lobby( + server_url: str, + session_id: str, + lobby_name: str, + private: bool = False, + insecure: bool = False, +) -> str: + """Call POST /api/lobby/{session_id} to create or lookup a lobby by name. + + Returns the lobby id. + """ + http_base = http_base_url(server_url) + url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}" + payload = json.dumps( + { + "type": "lobby_create", + "data": {"name": lobby_name, "private": private}, + } + ).encode("utf-8") + req = urllib.request.Request( + url, data=payload, headers={"Content-Type": "application/json"}, method="POST" + ) + # Prepare SSL context if requested (accept self-signed certs) + ssl_ctx = None + if insecure: + ssl_ctx = ssl.create_default_context() + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + + try: + with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp: + body = resp.read() + data = json.loads(body) + # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}} + try: + lobby_resp = LobbyCreateResponse.model_validate(data) + except ValidationError as e: + raise RuntimeError(f"Invalid lobby response from {url}: {e}") + + lobby_id = lobby_resp.data.id + if not lobby_id: + raise RuntimeError(f"No lobby id returned from {url}: {data}") + return lobby_id + except urllib.error.HTTPError as e: + # Try to include response body for debugging + try: + body = e.read() + msg = body.decode("utf-8", errors="ignore") + except Exception: + msg = str(e) + raise RuntimeError(f"HTTP error creating lobby: {msg}") + except Exception as e: + raise RuntimeError(f"Error creating lobby: {e}") diff --git a/voicebot/utils.py b/voicebot/utils.py new file mode 100644 index 0000000..299f681 --- /dev/null +++ b/voicebot/utils.py @@ -0,0 +1,57 @@ +""" +Utility functions for voicebot. + +This module provides common utility functions used throughout the application. +""" + +import ssl + + +def http_base_url(server_url: str) -> str: + """Convert ws:// or wss:// to http(s) and ensure no trailing slash.""" + if server_url.startswith("ws://"): + return "http://" + server_url[len("ws://") :].rstrip("/") + if server_url.startswith("wss://"): + return "https://" + server_url[len("wss://") :].rstrip("/") + return server_url.rstrip("/") + + +def ws_url(server_url: str) -> str: + """Convert http(s) to ws(s) if needed.""" + if server_url.startswith("http://"): + return "ws://" + server_url[len("http://") :].rstrip("/") + if server_url.startswith("https://"): + return "wss://" + server_url[len("https://") :].rstrip("/") + return server_url.rstrip("/") + + +def create_ssl_context(insecure: bool = False) -> ssl.SSLContext | None: + """Create SSL context for connections.""" + if not insecure: + return None + + ssl_ctx = ssl.create_default_context() + ssl_ctx.check_hostname = False + ssl_ctx.verify_mode = ssl.CERT_NONE + return ssl_ctx + + +def log_network_info(): + """Log network information for debugging.""" + from logger import logger + + try: + import socket + import subprocess + + hostname = socket.gethostname() + local_ip = socket.gethostbyname(hostname) + logger.info(f"Container hostname: {hostname}, local IP: {local_ip}") + + # Get all network interfaces + result = subprocess.run( + ["ip", "addr", "show"], capture_output=True, text=True + ) + logger.info(f"Network interfaces:\n{result.stdout}") + except Exception as e: + logger.warning(f"Could not get network info: {e}") diff --git a/voicebot/webrtc_signaling.py b/voicebot/webrtc_signaling.py new file mode 100644 index 0000000..61436a0 --- /dev/null +++ b/voicebot/webrtc_signaling.py @@ -0,0 +1,894 @@ +""" +WebRTC signaling client for voicebot. + +This module provides WebRTC signaling server communication and peer connection management. +Synthetic audio/video track creation is handled by the bots.synthetic_media module. +""" + +from __future__ import annotations + +import asyncio +import json +import websockets +import time +import re +from typing import ( + Dict, + Optional, + Callable, + Awaitable, + Protocol, + AsyncIterator, + cast, +) + +# Add the parent directory to sys.path to allow absolute imports + +from pydantic import ValidationError +from aiortc import ( + RTCPeerConnection, + RTCSessionDescription, + RTCIceCandidate, + MediaStreamTrack, +) +from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer +from aiortc.sdp import candidate_from_sdp + +# Import shared models +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from shared.models import ( + WebSocketMessageModel, + JoinStatusModel, + UserJoinedModel, + LobbyStateModel, + UpdateNameModel, + AddPeerModel, + RemovePeerModel, + SessionDescriptionModel, + IceCandidateModel, + ICECandidateDictModel, + SessionDescriptionTypedModel, +) + +from logger import logger +from voicebot.bots.synthetic_media import create_synthetic_tracks +from voicebot.models import Peer, MessageData +from voicebot.utils import create_ssl_context, log_network_info + + +class WebSocketProtocol(Protocol): + def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ... + def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ... + def __aiter__(self) -> AsyncIterator[str]: ... + + +class WebRTCSignalingClient: + """ + WebRTC signaling client that communicates with the FastAPI signaling server. + Handles peer-to-peer connection establishment and media streaming. + """ + + def __init__( + self, + server_url: str, + lobby_id: str, + session_id: str, + session_name: str, + insecure: bool = False, + create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None, + registration_check_interval: float = 30.0, + ): + self.server_url = server_url + self.lobby_id = lobby_id + self.session_id = session_id + self.session_name = session_name + self.insecure = insecure + + # Optional factory to create local media tracks for this client (bot provided) + self.create_tracks = create_tracks + + # WebSocket client protocol instance (typed as object to avoid Any) + self.websocket: Optional[object] = None + + # Optional password to register or takeover a name + self.name_password: Optional[str] = None + + self.peers: dict[str, Peer] = {} + self.peer_connections: dict[str, RTCPeerConnection] = {} + self.local_tracks: dict[str, MediaStreamTrack] = {} + + # State management + self.is_negotiating: dict[str, bool] = {} + self.making_offer: dict[str, bool] = {} + self.initiated_offer: set[str] = set() + self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {} + + # Registration status tracking + self.is_registered: bool = False + self.last_registration_check: float = 0 + self.registration_check_interval: float = registration_check_interval + self.registration_check_task: Optional[asyncio.Task[None]] = None + + # Event callbacks + self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None + self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None + self.on_track_received: Optional[ + Callable[[Peer, MediaStreamTrack], Awaitable[None]] + ] = None + + async def connect(self): + """Connect to the signaling server""" + ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" + logger.info(f"Connecting to signaling server: {ws_url}") + + # Log network information for debugging + log_network_info() + + try: + # If insecure (self-signed certs), create an SSL context for the websocket + ws_ssl = create_ssl_context(self.insecure) + + logger.info( + f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}" + ) + self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) + logger.info("Connected to signaling server") + + # Set up local media + await self._setup_local_media() + + # Set name and join lobby + name_payload: MessageData = {"name": self.session_name} + if self.name_password: + name_payload["password"] = self.name_password + logger.info(f"Sending set_name: {name_payload}") + await self._send_message("set_name", name_payload) + logger.info("Sending join message") + await self._send_message("join", {}) + + # Mark as registered after successful join + self.is_registered = True + self.last_registration_check = time.time() + + # Start periodic registration check + self.registration_check_task = asyncio.create_task(self._periodic_registration_check()) + + # Start message handling + logger.info("Starting message handler loop") + try: + await self._handle_messages() + except Exception as e: + logger.error(f"Message handling stopped: {e}") + self.is_registered = False + raise + + except Exception as e: + logger.error(f"Failed to connect to signaling server: {e}", exc_info=True) + raise + + async def _periodic_registration_check(self): + """Periodically check registration status and re-register if needed""" + + while True: + try: + await asyncio.sleep(self.registration_check_interval) + + current_time = time.time() + if current_time - self.last_registration_check < self.registration_check_interval: + continue + + # Check if we're still connected and registered + if not await self._check_registration_status(): + logger.warning("Registration check failed, attempting to re-register") + await self._re_register() + + self.last_registration_check = current_time + + except asyncio.CancelledError: + logger.info("Registration check task cancelled") + break + except Exception as e: + logger.error(f"Error in periodic registration check: {e}", exc_info=True) + # Continue checking even if one iteration fails + continue + + async def _check_registration_status(self) -> bool: + """Check if the voicebot is still registered with the server""" + try: + # First check if websocket is still connected + if not self.websocket: + logger.warning("WebSocket connection lost") + return False + + # Try to send a ping/status check message to verify connection + # We'll use a simple status message to check connectivity + try: + await self._send_message("status_check", {"timestamp": time.time()}) + logger.debug("Registration status check sent") + return True + except Exception as e: + logger.warning(f"Failed to send status check: {e}") + return False + + except Exception as e: + logger.error(f"Error checking registration status: {e}") + return False + + async def _re_register(self): + """Attempt to re-register with the server""" + try: + logger.info("Attempting to re-register with server") + + # Mark as not registered during re-registration attempt + self.is_registered = False + + # Try to reconnect the websocket if it's lost + if not self.websocket: + logger.info("WebSocket lost, attempting to reconnect") + await self._reconnect_websocket() + + # Re-send name and join messages + name_payload: MessageData = {"name": self.session_name} + if self.name_password: + name_payload["password"] = self.name_password + + logger.info("Re-sending set_name message") + await self._send_message("set_name", name_payload) + + logger.info("Re-sending join message") + await self._send_message("join", {}) + + # Mark as registered after successful re-join + self.is_registered = True + self.last_registration_check = time.time() + + logger.info("Successfully re-registered with server") + + except Exception as e: + logger.error(f"Failed to re-register with server: {e}", exc_info=True) + # Will try again on next check interval + + async def _reconnect_websocket(self): + """Reconnect the WebSocket connection""" + try: + # Close existing connection if any + if self.websocket: + try: + ws = cast(WebSocketProtocol, self.websocket) + await ws.close() + except Exception: + pass + self.websocket = None + + # Reconnect + ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}" + + # If insecure (self-signed certs), create an SSL context for the websocket + ws_ssl = create_ssl_context(self.insecure) + + logger.info(f"Reconnecting to signaling server: {ws_url}") + self.websocket = await websockets.connect(ws_url, ssl=ws_ssl) + logger.info("Successfully reconnected to signaling server") + + except Exception as e: + logger.error(f"Failed to reconnect websocket: {e}", exc_info=True) + raise + + async def disconnect(self): + """Disconnect from signaling server and cleanup""" + # Cancel the registration check task + if self.registration_check_task and not self.registration_check_task.done(): + self.registration_check_task.cancel() + try: + await self.registration_check_task + except asyncio.CancelledError: + pass + self.registration_check_task = None + + if self.websocket: + ws = cast(WebSocketProtocol, self.websocket) + await ws.close() + + # Close all peer connections + for pc in self.peer_connections.values(): + await pc.close() + + # Stop local tracks + for track in self.local_tracks.values(): + track.stop() + + # Reset registration status + self.is_registered = False + + logger.info("Disconnected from signaling server") + + async def _setup_local_media(self): + """Create local media tracks""" + # If a bot provided a create_tracks callable, use it to create tracks. + # Otherwise, use default synthetic tracks. + try: + if self.create_tracks: + tracks = self.create_tracks(self.session_name) + self.local_tracks.update(tracks) + else: + # Default fallback to synthetic tracks + tracks = create_synthetic_tracks(self.session_name) + self.local_tracks.update(tracks) + except Exception: + logger.exception("Failed to create local tracks using bot factory") + + # Add local peer to peers dict + local_peer = Peer( + session_id=self.session_id, + peer_name=self.session_name, + local=True, + attributes={"tracks": self.local_tracks}, + ) + self.peers[self.session_id] = local_peer + + logger.info("Local media tracks created") + + async def _send_message( + self, message_type: str, data: Optional[MessageData] = None + ): + """Send message to signaling server""" + if not self.websocket: + logger.error("No websocket connection") + return + # Build message with explicit type to avoid type narrowing + message: dict[str, object] = {"type": message_type} + if data is not None: + message["data"] = data + + ws = cast(WebSocketProtocol, self.websocket) + try: + logger.debug(f"_send_message: Sending {message_type} with data: {data}") + await ws.send(json.dumps(message)) + logger.debug(f"_send_message: Sent message: {message_type}") + except Exception as e: + logger.error( + f"_send_message: Failed to send {message_type}: {e}", exc_info=True + ) + + async def _handle_messages(self): + """Handle incoming messages from signaling server""" + try: + ws = cast(WebSocketProtocol, self.websocket) + async for message in ws: + logger.debug(f"_handle_messages: Received raw message: {message}") + try: + data = cast(MessageData, json.loads(message)) + except Exception as e: + logger.error( + f"_handle_messages: Failed to parse message: {e}", exc_info=True + ) + continue + await self._process_message(data) + except websockets.exceptions.ConnectionClosed as e: + logger.warning(f"WebSocket connection closed: {e}") + self.is_registered = False + # The periodic registration check will detect this and attempt reconnection + except Exception as e: + logger.error(f"Error handling messages: {e}", exc_info=True) + self.is_registered = False + + async def _process_message(self, message: MessageData): + """Process incoming signaling messages""" + try: + # Validate the base message structure first + validated_message = WebSocketMessageModel.model_validate(message) + msg_type = validated_message.type + data = validated_message.data + except ValidationError as e: + logger.error(f"Invalid message structure: {e}", exc_info=True) + return + + logger.debug( + f"_process_message: Received message type: {msg_type} with data: {data}" + ) + + if msg_type == "addPeer": + try: + validated = AddPeerModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid addPeer payload: {e}", exc_info=True) + return + await self._handle_add_peer(validated) + elif msg_type == "removePeer": + try: + validated = RemovePeerModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid removePeer payload: {e}", exc_info=True) + return + await self._handle_remove_peer(validated) + elif msg_type == "sessionDescription": + try: + validated = SessionDescriptionModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid sessionDescription payload: {e}", exc_info=True) + return + await self._handle_session_description(validated) + elif msg_type == "iceCandidate": + try: + validated = IceCandidateModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid iceCandidate payload: {e}", exc_info=True) + return + await self._handle_ice_candidate(validated) + elif msg_type == "join_status": + try: + validated = JoinStatusModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid join_status payload: {e}", exc_info=True) + return + logger.info(f"Join status: {validated.status} - {validated.message}") + elif msg_type == "user_joined": + try: + validated = UserJoinedModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid user_joined payload: {e}", exc_info=True) + return + logger.info( + f"User joined: {validated.name} (session: {validated.session_id})" + ) + elif msg_type == "lobby_state": + try: + validated = LobbyStateModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid lobby_state payload: {e}", exc_info=True) + return + participants = validated.participants + logger.info(f"Lobby state updated: {len(participants)} participants") + elif msg_type == "update_name": + try: + validated = UpdateNameModel.model_validate(data) + except ValidationError as e: + logger.error(f"Invalid update payload: {e}", exc_info=True) + return + logger.info(f"Received update message: {validated}") + else: + logger.info(f"Unhandled message type: {msg_type} with data: {data}") + + # Continue with more methods in the next part... + + async def _handle_add_peer(self, data: AddPeerModel): + """Handle addPeer message - create new peer connection""" + peer_id = data.peer_id + peer_name = data.peer_name + should_create_offer = data.should_create_offer + + logger.info( + f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})" + ) + logger.debug( + f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}" + ) + + # Check if peer already exists + if peer_id in self.peer_connections: + pc = self.peer_connections[peer_id] + logger.debug( + f"_handle_add_peer: Existing connection state: {pc.connectionState}" + ) + if pc.connectionState in ["new", "connected", "connecting"]: + logger.info(f"Peer connection already exists for {peer_name}") + return + else: + # Clean up stale connection + logger.debug( + f"_handle_add_peer: Closing stale connection for {peer_name}" + ) + await pc.close() + del self.peer_connections[peer_id] + + # Create new peer + peer = Peer(session_id=peer_id, peer_name=peer_name, local=False) + self.peers[peer_id] = peer + + # Create RTCPeerConnection + config = RTCConfiguration( + iceServers=[ + RTCIceServer(urls="stun:ketrenos.com:3478"), + RTCIceServer( + urls="turns:ketrenos.com:5349", + username="ketra", + credential="ketran", + ), + # Add Google's public STUN server as fallback + RTCIceServer(urls="stun:stun.l.google.com:19302"), + ], + ) + logger.debug( + f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}" + ) + pc = RTCPeerConnection(configuration=config) + + # Add ICE gathering state change handler + def on_ice_gathering_state_change() -> None: + logger.info(f"ICE gathering state: {pc.iceGatheringState}") + if pc.iceGatheringState == "complete": + logger.info( + f"ICE gathering complete for {peer_name} - checking if candidates were generated..." + ) + + pc.on("icegatheringstatechange")(on_ice_gathering_state_change) + + # Add connection state change handler + def on_connection_state_change() -> None: + logger.info(f"Connection state: {pc.connectionState}") + + pc.on("connectionstatechange")(on_connection_state_change) + + self.peer_connections[peer_id] = pc + peer.connection = pc + + # Set up event handlers + def on_track(track: MediaStreamTrack) -> None: + logger.info(f"Received {track.kind} track from {peer_name}") + logger.info(f"on_track: {track.kind} from {peer_name}, track={track}") + peer.attributes[f"{track.kind}_track"] = track + if self.on_track_received: + asyncio.ensure_future(self.on_track_received(peer, track)) + + pc.on("track")(on_track) + + def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None: + logger.info(f"on_ice_candidate: {candidate}") + logger.info( + f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}" + ) + if not candidate: + logger.info( + f"on_ice_candidate: End of candidates signal for {peer_name}" + ) + return + + # Raw SDP fragment for the candidate + raw = getattr(candidate, "candidate", None) + + # Try to infer candidate type from the SDP string (host/srflx/relay/prflx) + def _parse_type(s: Optional[str]) -> str: + if not s: + return "eoc" + m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s) + return m.group(1) if m else "unknown" + + cand_type = _parse_type(raw) + protocol = getattr(candidate, "protocol", "unknown") + logger.info( + f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}" + ) + + candidate_model = ICECandidateDictModel( + candidate=raw, + sdpMid=getattr(candidate, "sdpMid", None), + sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None), + ) + payload_model = IceCandidateModel( + peer_id=peer_id, peer_name=peer_name, candidate=candidate_model + ) + logger.info( + f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}" + ) + asyncio.ensure_future( + self._send_message("relayICECandidate", payload_model.model_dump()) + ) + + pc.on("icecandidate")(on_ice_candidate) + + # Add local tracks + for track in self.local_tracks.values(): + logger.debug( + f"_handle_add_peer: Adding local track {track.kind} to {peer_name}" + ) + pc.addTrack(track) + + # Create offer if needed + if should_create_offer: + await self._create_and_send_offer(peer_id, peer_name, pc) + + if self.on_peer_added: + await self.on_peer_added(peer) + + async def _create_and_send_offer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection): + """Create and send an offer to a peer""" + self.initiated_offer.add(peer_id) + self.making_offer[peer_id] = True + self.is_negotiating[peer_id] = True + + try: + logger.debug(f"_handle_add_peer: Creating offer for {peer_name}") + offer = await pc.createOffer() + logger.debug( + f"_handle_add_peer: Offer created for {peer_name}: {offer}" + ) + await pc.setLocalDescription(offer) + logger.debug(f"_handle_add_peer: Local description set for {peer_name}") + + # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) + # Use Method 2: Complete SDP approach to extract ICE candidates + logger.debug( + f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}" + ) + while pc.iceGatheringState != "complete": + await asyncio.sleep(0.1) + + logger.debug( + f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}" + ) + + await self._extract_and_send_candidates(peer_id, peer_name, pc) + + session_desc_typed = SessionDescriptionTypedModel( + type=offer.type, sdp=offer.sdp + ) + session_desc_model = SessionDescriptionModel( + peer_id=peer_id, + peer_name=peer_name, + session_description=session_desc_typed, + ) + await self._send_message( + "relaySessionDescription", + session_desc_model.model_dump(), + ) + + logger.info(f"Offer sent to {peer_name}") + except Exception as e: + logger.error( + f"Failed to create/send offer to {peer_name}: {e}", exc_info=True + ) + finally: + self.making_offer[peer_id] = False + + async def _extract_and_send_candidates(self, peer_id: str, peer_name: str, pc: RTCPeerConnection): + """Extract ICE candidates from SDP and send them""" + # Parse ICE candidates from the local SDP + sdp_lines = pc.localDescription.sdp.split("\n") + candidate_lines = [ + line for line in sdp_lines if line.startswith("a=candidate:") + ] + + # Track which media section we're in to determine sdpMid and sdpMLineIndex + current_media_index = -1 + current_mid = None + + for line in sdp_lines: + if line.startswith("m="): # Media section + current_media_index += 1 + elif line.startswith("a=mid:"): # Media ID + current_mid = line.split(":", 1)[1].strip() + elif line.startswith("a=candidate:"): + candidate_sdp = line[2:] # Remove 'a=' prefix + + candidate_model = ICECandidateDictModel( + candidate=candidate_sdp, + sdpMid=current_mid, + sdpMLineIndex=current_media_index, + ) + payload_candidate = IceCandidateModel( + peer_id=peer_id, + peer_name=peer_name, + candidate=candidate_model, + ) + + logger.debug( + f"_extract_and_send_candidates: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..." + ) + await self._send_message( + "relayICECandidate", payload_candidate.model_dump() + ) + + # Send end-of-candidates signal (empty candidate) + end_candidate_model = ICECandidateDictModel( + candidate="", + sdpMid=None, + sdpMLineIndex=None, + ) + payload_end = IceCandidateModel( + peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model + ) + logger.debug( + f"_extract_and_send_candidates: Sending end-of-candidates signal for {peer_name}" + ) + await self._send_message("relayICECandidate", payload_end.model_dump()) + + logger.debug( + f"_extract_and_send_candidates: Sent {len(candidate_lines)} ICE candidates to {peer_name}" + ) + + async def _handle_remove_peer(self, data: RemovePeerModel): + """Handle removePeer message""" + peer_id = data.peer_id + peer_name = data.peer_name + + logger.info(f"Removing peer: {peer_name}") + + # Close peer connection + if peer_id in self.peer_connections: + pc = self.peer_connections[peer_id] + await pc.close() + del self.peer_connections[peer_id] + + # Clean up state + self.is_negotiating.pop(peer_id, None) + self.making_offer.pop(peer_id, None) + self.initiated_offer.discard(peer_id) + self.pending_ice_candidates.pop(peer_id, None) + + # Remove peer + peer = self.peers.pop(peer_id, None) + if peer and self.on_peer_removed: + await self.on_peer_removed(peer) + + async def _handle_session_description(self, data: SessionDescriptionModel): + """Handle sessionDescription message""" + peer_id = data.peer_id + peer_name = data.peer_name + session_description = data.session_description.model_dump() + + logger.info(f"Received {session_description['type']} from {peer_name}") + + pc = self.peer_connections.get(peer_id) + if not pc: + logger.error(f"No peer connection for {peer_name}") + return + + desc = RTCSessionDescription( + sdp=session_description["sdp"], type=session_description["type"] + ) + + # Handle offer collision (polite peer pattern) + making_offer = self.making_offer.get(peer_id, False) + offer_collision = desc.type == "offer" and ( + making_offer or pc.signalingState != "stable" + ) + we_initiated = peer_id in self.initiated_offer + ignore_offer = we_initiated and offer_collision + + if ignore_offer: + logger.info(f"Ignoring offer from {peer_name} due to collision") + return + + try: + await pc.setRemoteDescription(desc) + self.is_negotiating[peer_id] = False + logger.info(f"Remote description set for {peer_name}") + + # Process queued ICE candidates + pending_candidates = self.pending_ice_candidates.pop(peer_id, []) + + for candidate_data in pending_candidates: + # candidate_data is an ICECandidateDictModel Pydantic model + cand = candidate_data.candidate + # handle end-of-candidates marker + if not cand: + await pc.addIceCandidate(None) + logger.info(f"Added queued end-of-candidates for {peer_name}") + continue + + # cand may be the full "candidate:..." string or the inner SDP part + if cand and cand.startswith("candidate:"): + sdp_part = cand.split(":", 1)[1] + else: + sdp_part = cand + + try: + rtc_candidate = candidate_from_sdp(sdp_part) + rtc_candidate.sdpMid = candidate_data.sdpMid + rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex + await pc.addIceCandidate(rtc_candidate) + logger.info(f"Added queued ICE candidate for {peer_name}") + except Exception as e: + logger.error( + f"Failed to add queued ICE candidate for {peer_name}: {e}" + ) + + except Exception as e: + logger.error(f"Failed to set remote description for {peer_name}: {e}") + return + + # Create answer if this was an offer + if session_description["type"] == "offer": + await self._create_and_send_answer(peer_id, peer_name, pc) + + async def _create_and_send_answer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection): + """Create and send an answer to a peer""" + try: + answer = await pc.createAnswer() + await pc.setLocalDescription(answer) + + # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344) + # Use Method 2: Complete SDP approach to extract ICE candidates + logger.debug( + f"_create_and_send_answer: Waiting for ICE gathering to complete for {peer_name} (answer)" + ) + while pc.iceGatheringState != "complete": + await asyncio.sleep(0.1) + + logger.debug( + f"_create_and_send_answer: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)" + ) + + await self._extract_and_send_candidates(peer_id, peer_name, pc) + + session_desc_typed = SessionDescriptionTypedModel( + type=answer.type, sdp=answer.sdp + ) + session_desc_model = SessionDescriptionModel( + peer_id=peer_id, + peer_name=peer_name, + session_description=session_desc_typed, + ) + await self._send_message( + "relaySessionDescription", + session_desc_model.model_dump(), + ) + + logger.info(f"Answer sent to {peer_name}") + except Exception as e: + logger.error(f"Failed to create/send answer to {peer_name}: {e}") + + async def _handle_ice_candidate(self, data: IceCandidateModel): + """Handle iceCandidate message""" + peer_id = data.peer_id + peer_name = data.peer_name + candidate_data = data.candidate + + logger.info(f"Received ICE candidate from {peer_name}") + + pc = self.peer_connections.get(peer_id) + if not pc: + logger.error(f"No peer connection for {peer_name}") + return + + # Queue candidate if remote description not set + if not pc.remoteDescription: + logger.info( + f"Remote description not set, queuing ICE candidate for {peer_name}" + ) + if peer_id not in self.pending_ice_candidates: + self.pending_ice_candidates[peer_id] = [] + # candidate_data is an ICECandidateDictModel Pydantic model + self.pending_ice_candidates[peer_id].append(candidate_data) + return + + try: + cand = candidate_data.candidate + if not cand: + # end-of-candidates + await pc.addIceCandidate(None) + logger.info(f"End-of-candidates added for {peer_name}") + return + + if cand and cand.startswith("candidate:"): + sdp_part = cand.split(":", 1)[1] + else: + sdp_part = cand + + # Detect type for logging + try: + m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part) + cand_type = m.group(1) if m else "unknown" + except Exception: + cand_type = "unknown" + + try: + rtc_candidate = candidate_from_sdp(sdp_part) + rtc_candidate.sdpMid = candidate_data.sdpMid + rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex + + # aiortc expects an object with attributes (RTCIceCandidate) + await pc.addIceCandidate(rtc_candidate) + logger.info(f"ICE candidate added for {peer_name}: type={cand_type}") + except Exception as e: + logger.error( + f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'", + exc_info=True, + ) + except Exception as e: + logger.error( + f"Unexpected error handling ICE candidate for {peer_name}: {e}", + exc_info=True, + )