Refactored voicebot/main.py

This commit is contained in:
James Ketr 2025-09-03 14:33:15 -07:00
parent 2e91a4eadb
commit b916db243b
13 changed files with 1904 additions and 1763 deletions

View File

@@ -24,6 +24,7 @@ RUN apt-get update \
libgl1 \
libglib2.0-0t64 \
git \
iproute2 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}

View File

@@ -523,7 +523,7 @@ class Session:
try:
if os.path.exists(cls._save_file + ".tmp"):
os.remove(cls._save_file + ".tmp")
except:
except Exception as e:
pass
@classmethod
@@ -1971,7 +1971,7 @@ async def lobby_join(
)
try:
await websocket.close()
except:
except Exception as e:
pass

View File

@@ -0,0 +1,82 @@
# Voicebot Module Refactoring
The voicebot/main.py functionality has been broken down into individual Python files for better organization and maintainability:
## New File Structure
### Core Modules
1. **`models.py`** - Data models and configuration
- `VoicebotArgs` - Pydantic model for CLI arguments and configuration
- `VoicebotMode` - Enum for client/provider modes
- `Peer` - WebRTC peer representation
- `JoinRequest` - Request model for joining lobbies
- `MessageData` - Type alias for message payloads
2. **`webrtc_signaling.py`** - WebRTC signaling client functionality
- `WebRTCSignalingClient` - Main WebRTC signaling client class
- Handles peer connection management, ICE candidates, session descriptions
- Registration status tracking and reconnection logic
- Message processing and event handling
3. **`session_manager.py`** - Session and lobby management
- `create_or_get_session()` - Session creation/retrieval
- `create_or_get_lobby()` - Lobby creation/retrieval
- HTTP API communication utilities
4. **`bot_orchestrator.py`** - FastAPI bot orchestration service
- Bot discovery and management
- FastAPI endpoints for bot operations
- Provider registration with main server
- Bot instance lifecycle management
5. **`client_main.py`** - Main client logic
- `main_with_args()` - Core client functionality
- `start_client_with_reload()` - Development mode with reload
- Event handlers for peer and track management
6. **`client_app.py`** - Client FastAPI application
- `create_client_app()` - Creates FastAPI app for client mode
- Health check and status endpoints
- Process isolation and locking
7. **`utils.py`** - Utility functions
- URL conversion utilities (`http_base_url`, `ws_url`)
- SSL context creation
- Network information logging
8. **`main.py`** - Main orchestration and entry point
- Command-line argument parsing
- Mode selection (client vs provider)
- Entry points for both modes
### Key Improvements
- **Separation of Concerns**: Each file handles specific functionality
- **Better Maintainability**: Smaller, focused modules are easier to understand and modify
- **Reduced Coupling**: Dependencies between components are more explicit
- **Type Safety**: Proper type hints and Pydantic models throughout
- **Error Handling**: Centralized error handling and logging
### Usage
The refactored code maintains the same CLI interface:
```bash
# Client mode
python voicebot/main.py --mode client --server-url http://localhost:8000/ai-voicebot
# Provider mode
python voicebot/main.py --mode provider --host 0.0.0.0 --port 8788
```
### Import Structure
```python
from voicebot import VoicebotArgs, VoicebotMode, WebRTCSignalingClient
from voicebot.models import Peer, JoinRequest
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
from voicebot.client_main import main_with_args
```
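For programmatic use, the same configuration object can be built directly and handed to the client entry point; a minimal sketch using the defaults from `models.py` (the values shown are illustrative):

```python
import asyncio

from voicebot.models import VoicebotArgs, VoicebotMode
from voicebot.client_main import main_with_args

# Build the configuration object instead of parsing CLI flags.
args = VoicebotArgs(
    mode=VoicebotMode.CLIENT,
    server_url="http://localhost:8000/ai-voicebot",
    lobby="default",
    session_name="Python Bot",
    insecure=False,
)

# main_with_args() resolves the session, creates/joins the lobby and runs
# the WebRTC signaling client until it disconnects.
asyncio.run(main_with_args(args))
```

The same fields can also be populated from `VOICEBOT_*` environment variables via `VoicebotArgs.from_environment()`.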
The original `main_old.py` contains the monolithic implementation for reference.

voicebot/__init__.py (new file, 30 lines added)
View File

@@ -0,0 +1,30 @@
"""
Voicebot package.
This package provides WebRTC signaling client functionality and bot orchestration
for AI voicebots.
"""
import sys
import os
# Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from voicebot.models import VoicebotArgs, VoicebotMode, Peer, JoinRequest
from voicebot.webrtc_signaling import WebRTCSignalingClient
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
from voicebot.client_main import main_with_args
from voicebot.bot_orchestrator import app as bot_orchestrator_app
__all__ = [
'VoicebotArgs',
'VoicebotMode',
'Peer',
'JoinRequest',
'WebRTCSignalingClient',
'create_or_get_session',
'create_or_get_lobby',
'main_with_args',
'bot_orchestrator_app',
]

View File

@@ -0,0 +1,238 @@
"""
Bot orchestrator FastAPI service.
This module provides the FastAPI service for bot discovery and orchestration.
"""
import asyncio
import threading
import uuid
import importlib
import pkgutil
import sys
import os
from typing import Dict, Any
# Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import uvicorn
from fastapi import FastAPI, HTTPException
from logger import logger
from voicebot.models import JoinRequest
from voicebot.webrtc_signaling import WebRTCSignalingClient
app = FastAPI(title="voicebot-bot-orchestrator")
# Lightweight in-memory registry of running bot clients
registry: Dict[str, WebRTCSignalingClient] = {}
def discover_bots() -> Dict[str, Dict[str, Any]]:
"""Discover bot modules under the voicebot.bots package that expose bot_info.
This intentionally imports modules under `voicebot.bots` so heavy bot
implementations can remain in that package and be imported lazily.
"""
bots: Dict[str, Dict[str, Any]] = {}
try:
package = importlib.import_module("voicebot.bots")
package_path = package.__path__
except Exception:
logger.exception("Failed to import voicebot.bots package")
return bots
for _finder, name, _ispkg in pkgutil.iter_modules(package_path):
try:
mod = importlib.import_module(f"voicebot.bots.{name}")
except Exception:
logger.exception("Failed to import voicebot.bots.%s", name)
continue
info = None
create_tracks = None
if hasattr(mod, "agent_info") and callable(getattr(mod, "agent_info")):
try:
info = mod.agent_info()
# Note: Keep copy as is to maintain structure
except Exception:
logger.exception("agent_info() failed for %s", name)
if hasattr(mod, "create_agent_tracks") and callable(getattr(mod, "create_agent_tracks")):
create_tracks = getattr(mod, "create_agent_tracks")
if info:
bots[info.get("name", name)] = {"module": name, "info": info, "create_tracks": create_tracks}
return bots
@app.get("/bots")
def list_bots() -> Dict[str, Any]:
"""List available bots."""
bots = discover_bots()
return {k: v["info"] for k, v in bots.items()}
@app.post("/bots/{bot_name}/join")
async def bot_join(bot_name: str, req: JoinRequest):
"""Make a bot join a lobby."""
bots = discover_bots()
bot = bots.get(bot_name)
if not bot:
raise HTTPException(status_code=404, detail="Bot not found")
create_tracks = bot.get("create_tracks")
# Start the WebRTCSignalingClient in a background asyncio task and register it
client = WebRTCSignalingClient(
server_url=req.server_url,
lobby_id=req.lobby_id,
session_id=req.session_id,
session_name=req.nick,
insecure=req.insecure,
create_tracks=create_tracks,
)
run_id = str(uuid.uuid4())
async def run_client():
try:
registry[run_id] = client
await client.connect()
except Exception:
logger.exception("Bot client failed for run %s", run_id)
finally:
registry.pop(run_id, None)
# Schedule the client on the already-running event loop (bot_join is async)
asyncio.create_task(run_client())
return {"status": "started", "bot": bot_name, "run_id": run_id}
@app.post("/bots/runs/{run_id}/stop")
async def stop_run(run_id: str):
"""Stop a running bot."""
client = registry.get(run_id)
if not client:
raise HTTPException(status_code=404, detail="Run not found")
try:
await client.disconnect()
except Exception:
logger.exception("Failed to stop run %s", run_id)
raise HTTPException(status_code=500, detail="Failed to stop run")
registry.pop(run_id, None)
return {"status": "stopped", "run_id": run_id}
@app.get("/bots/runs")
def list_runs() -> Dict[str, Any]:
"""List running bot instances."""
return {
"runs": [
{"run_id": run_id, "session_id": client.session_id, "session_name": client.session_name}
for run_id, client in registry.items()
]
}
def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
"""Start the bot orchestration API server"""
uvicorn.run(app, host=host, port=port)
async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str:
"""Register this voicebot instance as a bot provider with the main server"""
try:
# Import httpx locally to avoid dependency issues
import httpx
payload = {
"base_url": voicebot_url.rstrip('/'),
"name": "voicebot-provider",
"description": "AI voicebot provider with speech recognition and synthetic media capabilities"
}
# Prepare SSL context if needed
verify = not insecure
async with httpx.AsyncClient(verify=verify) as client:
response = await client.post(
f"{server_url}/api/bots/providers/register",
json=payload,
timeout=10.0
)
if response.status_code == 200:
result = response.json()
provider_id = result.get("provider_id")
logger.info(f"Successfully registered with server as provider: {provider_id}")
return provider_id
else:
logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}")
raise RuntimeError(f"Registration failed: {response.status_code}")
except Exception as e:
logger.error(f"Error registering with server: {e}")
raise
def start_bot_provider(
host: str = "0.0.0.0",
port: int = 8788,
server_url: str | None = None,
insecure: bool = False,
reload: bool = False
):
"""Start the bot provider API server and optionally register with main server"""
import time
import socket
# Start the FastAPI server in a background thread
# Add reload functionality for development
if reload:
server_thread = threading.Thread(
target=lambda: uvicorn.run(
app,
host=host,
port=port,
log_level="info",
reload=True,
reload_dirs=["/voicebot", "/shared"]
),
daemon=True
)
else:
server_thread = threading.Thread(
target=lambda: uvicorn.run(app, host=host, port=port, log_level="info"),
daemon=True
)
logger.info(f"Starting bot provider API server on {host}:{port}...")
server_thread.start()
# If server_url is provided, register with the main server
if server_url:
# Give the server a moment to start
time.sleep(2)
# Construct the voicebot URL
voicebot_url = f"http://{host}:{port}"
if host == "0.0.0.0":
# Try to get a better hostname
try:
hostname = socket.gethostname()
voicebot_url = f"http://{hostname}:{port}"
except Exception:
voicebot_url = f"http://localhost:{port}"
try:
asyncio.run(register_with_server(server_url, voicebot_url, insecure))
except Exception as e:
logger.error(f"Failed to register with server: {e}")
# Keep the main thread alive
try:
while True:
time.sleep(1)
except KeyboardInterrupt:
logger.info("Shutting down bot provider...")

View File

@@ -4,11 +4,11 @@ Lightweight agent descriptor; heavy model loading must be done by a controller
when the agent is actually used.
"""
from typing import Dict
from typing import Any
from typing import Dict, Any
import librosa
from logger import logger
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from aiortc import MediaStreamTrack
AGENT_NAME = "whisper"
@@ -19,11 +19,12 @@ def agent_info() -> Dict[str, str]:
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION}
def create_agent_tracks(session_name: str) -> dict:
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
"""Whisper is not a media source - return no local tracks."""
return {}
model_ids = {
def do_work():
model_ids = {
"Distil-Whisper": [
"distil-whisper/distil-large-v2",
"distil-whisper/distil-medium.en",
@@ -42,31 +43,32 @@ model_ids = {
"openai/whisper-base.en",
"openai/whisper-tiny.en",
]
}
}
model_type = model_ids["Distil-Whisper"]
model_type = model_ids["Distil-Whisper"]
logger.info(model_type)
model_id = model_type[0]
logger.info(model_type)
model_id = model_type[0]
processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
pt_model.eval() # type: ignore
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
pt_model.eval() # type: ignore
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
"""Extract input features from audio array and sampling rate."""
input_features = processor(
processor_output = processor( # type: ignore
audio_array,
sampling_rate=sampling_rate,
return_tensors="pt",
).input_features
return input_features
)
input_features: Any = processor_output.input_features # type: ignore
return input_features # type: ignore
def load_audio_file(file_path: str) -> tuple[Any, int]:
def load_audio_file(file_path: str) -> tuple[Any, int]:
"""Load audio file from disk and return audio array and sampling rate."""
# Whisper models expect 16kHz sample rate
target_sample_rate = 16000
@@ -87,11 +89,11 @@ def load_audio_file(file_path: str) -> tuple[Any, int]:
raise
# Example usage - replace with your audio file path
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"
# Example usage - replace with your audio file path
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"
# Load audio from file instead of dataset
try:
# Load audio from file instead of dataset
try:
audio_array, sampling_rate = load_audio_file(audio_file_path)
input_features = extract_input_features(audio_array, sampling_rate)
@@ -101,9 +103,11 @@ try:
print(f"Audio file: {audio_file_path}")
print(f"Transcription: {transcription[0]}")
except FileNotFoundError:
except FileNotFoundError:
logger.error(f"Audio file not found: {audio_file_path}")
print("Please update the audio_file_path variable with a valid path to your wav file")
except Exception as e:
except Exception as e:
logger.error(f"Error processing audio: {e}")
print(f"Error: {e}")

voicebot/client_app.py (new file, 127 lines added)
View File

@@ -0,0 +1,127 @@
"""
Client FastAPI application for voicebot.
This module provides the FastAPI application for client mode operations.
"""
import asyncio
import os
import fcntl
import sys
from contextlib import asynccontextmanager
from typing import Optional
# Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from fastapi import FastAPI
from logger import logger
# Import shared models
from shared.models import ClientStatusResponse
from voicebot.models import VoicebotArgs
# Global client arguments storage
_client_args: Optional[VoicebotArgs] = None
def create_client_app(args: VoicebotArgs) -> FastAPI:
"""Create a FastAPI app for client mode that uvicorn can import."""
global _client_args
_client_args = args
# Store the client task globally so we can manage it
client_task = None
lock_file = None
@asynccontextmanager
async def lifespan(app: FastAPI):
nonlocal client_task, lock_file
# Startup
# Use a file lock to prevent multiple instances from starting
lock_file_path = "/tmp/voicebot_client.lock"
try:
lock_file = open(lock_file_path, 'w')
# Try to acquire an exclusive lock (non-blocking)
fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
if _client_args is None:
logger.error("Client args not initialized")
if lock_file:
lock_file.close()
lock_file = None
yield
return
logger.info("Starting voicebot client...")
# Import here to avoid circular imports
from .client_main import main_with_args
client_task = asyncio.create_task(main_with_args(_client_args))
except (IOError, OSError):
# Another process already has the lock
logger.info("Another instance is already running - skipping client startup")
if lock_file:
lock_file.close()
lock_file = None
yield
# Shutdown
if client_task and not client_task.done():
logger.info("Shutting down voicebot client...")
client_task.cancel()
try:
await client_task
except asyncio.CancelledError:
pass
if lock_file:
try:
fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
lock_file.close()
os.unlink(lock_file_path)
except Exception:
pass
# Create the client FastAPI app
app = FastAPI(title="voicebot-client", lifespan=lifespan)
@app.get("/health")
async def health_check(): # type: ignore
"""Simple health check endpoint"""
return {"status": "running", "mode": "client"}
@app.get("/status", response_model=ClientStatusResponse)
async def client_status() -> ClientStatusResponse: # type: ignore
"""Get client status"""
return ClientStatusResponse(
client_running=client_task is not None and not client_task.done(),
session_name=_client_args.session_name if _client_args else 'unknown',
lobby=_client_args.lobby if _client_args else 'unknown',
server_url=_client_args.server_url if _client_args else 'unknown'
)
return app
def get_app() -> FastAPI:
"""Get the appropriate FastAPI app based on VOICEBOT_MODE environment variable."""
mode = os.getenv('VOICEBOT_MODE', 'provider')
if mode == 'client':
# For client mode, we need to create the client app with args from environment
args = VoicebotArgs.from_environment()
return create_client_app(args)
else:
# Provider mode - return the main bot orchestration app
from voicebot.bot_orchestrator import app
return app
# Create app instance for uvicorn import
uvicorn_app = get_app()
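How the entrypoint script actually launches `uvicorn_app` is not part of this diff; a plausible sketch, where the import string, host and port are assumptions:

```python
import os

import uvicorn

# get_app() selects client vs. provider mode from VOICEBOT_MODE at import time,
# so set it before uvicorn imports the app.
os.environ["VOICEBOT_MODE"] = "client"

# Using an import string (rather than the app object) lets uvicorn's reloader work.
uvicorn.run("voicebot.client_app:uvicorn_app", host="0.0.0.0", port=8788, reload=True)
```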

voicebot/client_main.py (new file, 144 lines added)
View File

@@ -0,0 +1,144 @@
"""
Main client logic for voicebot.
This module contains the main client functionality and entry points.
"""
import asyncio
import sys
import os
from logger import logger
# Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from voicebot.bots.synthetic_media import AnimatedVideoTrack
from voicebot.models import VoicebotArgs, Peer
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
from voicebot.webrtc_signaling import WebRTCSignalingClient
from voicebot.utils import ws_url
from aiortc import MediaStreamTrack
async def main_with_args(args: VoicebotArgs):
"""Main voicebot client logic that accepts arguments object."""
# Resolve session id (create if needed)
try:
session_id = create_or_get_session(
args.server_url, args.session_id, insecure=args.insecure
)
print(f"Using session id: {session_id}")
except Exception as e:
print(f"Failed to get/create session: {e}")
return
# Create or get lobby id
try:
lobby_id = create_or_get_lobby(
args.server_url,
session_id,
args.lobby,
args.private,
insecure=args.insecure,
)
print(f"Using lobby id: {lobby_id} (name={args.lobby})")
except Exception as e:
print(f"Failed to create/get lobby: {e}")
return
# Build websocket base URL (ws:// or wss://) from server_url and pass to client so
# it constructs the final websocket path (/ws/lobby/{lobby}/{session}) itself.
ws_base = ws_url(args.server_url)
client = WebRTCSignalingClient(
ws_base, lobby_id, session_id, args.session_name,
insecure=args.insecure,
registration_check_interval=args.registration_check_interval
)
# Set up event handlers
async def on_peer_added(peer: Peer):
print(f"Peer added: {peer.peer_name}")
async def on_peer_removed(peer: Peer):
print(f"Peer removed: {peer.peer_name}")
# Remove any video tracks from this peer from our synthetic video track
if "video" in client.local_tracks:
synthetic_video_track = client.local_tracks["video"]
if isinstance(synthetic_video_track, AnimatedVideoTrack):
# We need to identify and remove tracks from this specific peer
# Since we don't have a direct mapping, we'll need to track this differently
# For now, this is a placeholder - we might need to enhance the peer tracking
logger.info(
f"Peer {peer.peer_name} removed - may need to clean up video tracks"
)
async def on_track_received(peer: Peer, track: MediaStreamTrack):
print(f"Received {track.kind} track from {peer.peer_name}")
# If it's a video track, attach it to our synthetic video track for edge detection
if track.kind == "video" and "video" in client.local_tracks:
synthetic_video_track = client.local_tracks["video"]
if isinstance(synthetic_video_track, AnimatedVideoTrack):
synthetic_video_track.add_remote_video_track(track)
logger.info(
f"Attached remote video track from {peer.peer_name} to synthetic video track"
)
client.on_peer_added = on_peer_added
client.on_peer_removed = on_peer_removed
client.on_track_received = on_track_received
# Retry loop for connection resilience
max_retries = 5
retry_delay = 5.0 # seconds
retry_count = 0
while retry_count < max_retries:
try:
# If a password was provided on the CLI, store it on the client for use when setting name
if args.password:
client.name_password = args.password
print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
await client.connect()
except KeyboardInterrupt:
print("Shutting down...")
break
except Exception as e:
logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
retry_count += 1
if retry_count < max_retries:
print(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
# Exponential backoff with max delay of 60 seconds
retry_delay = min(retry_delay * 1.5, 60.0)
else:
print("Max retries exceeded. Giving up.")
break
finally:
try:
await client.disconnect()
except Exception as e:
logger.error(f"Error during disconnect: {e}")
print("Voicebot client stopped.")
def start_client_with_reload(args: VoicebotArgs):
"""Start the client with auto-reload functionality."""
logger.info("Creating client app for uvicorn...")
from voicebot.client_app import create_client_app
create_client_app(args)
# Note: This function is called when --reload is specified
# The actual uvicorn execution should be handled by the entrypoint script
logger.info("Client app created. Uvicorn should be started by entrypoint script.")
# Fall back to running client directly if not using uvicorn
asyncio.run(main_with_args(args))

File diff suppressed because it is too large

voicebot/models.py (new file, 121 lines added)
View File

@@ -0,0 +1,121 @@
"""
Data models and configuration for voicebot.
This module provides Pydantic models for configuration and data structures
used throughout the voicebot application.
"""
from __future__ import annotations
import argparse
from enum import Enum
from typing import Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from aiortc import RTCPeerConnection
class VoicebotMode(str, Enum):
"""Voicebot operation modes."""
CLIENT = "client"
PROVIDER = "provider"
class VoicebotArgs(BaseModel):
"""Pydantic model for voicebot CLI arguments and configuration."""
# Mode selection
mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)")
# Provider mode arguments
host: str = Field(default="0.0.0.0", description="Host for provider mode")
port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535)
reload: bool = Field(default=False, description="Enable auto-reload for development")
# Client mode arguments
server_url: str = Field(
default="http://localhost:8000/ai-voicebot",
description="AI-Voicebot lobby and signaling server base URL (http:// or https://)"
)
lobby: str = Field(default="default", description="Lobby name to create or join")
session_name: str = Field(default="Python Bot", description="Session (user) display name")
session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse")
password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
private: bool = Field(default=False, description="Create the lobby as private")
insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)
@classmethod
def from_environment(cls) -> 'VoicebotArgs':
"""Create VoicebotArgs from environment variables."""
import os
mode_str = os.getenv('VOICEBOT_MODE', 'client')
return cls(
mode=VoicebotMode(mode_str),
host=os.getenv('VOICEBOT_HOST', '0.0.0.0'),
port=int(os.getenv('VOICEBOT_PORT', '8788')),
reload=os.getenv('VOICEBOT_RELOAD', 'false').lower() == 'true',
server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'),
lobby=os.getenv('VOICEBOT_LOBBY', 'default'),
session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'),
session_id=os.getenv('VOICEBOT_SESSION_ID', None),
password=os.getenv('VOICEBOT_PASSWORD', None),
private=os.getenv('VOICEBOT_PRIVATE', 'false').lower() == 'true',
insecure=os.getenv('VOICEBOT_INSECURE', 'false').lower() == 'true',
registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0'))
)
@classmethod
def from_argparse(cls, args: argparse.Namespace) -> 'VoicebotArgs':
"""Create VoicebotArgs from argparse Namespace."""
mode_str = getattr(args, 'mode', 'client')
return cls(
mode=VoicebotMode(mode_str),
host=getattr(args, 'host', '0.0.0.0'),
port=getattr(args, 'port', 8788),
reload=getattr(args, 'reload', False),
server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'),
lobby=getattr(args, 'lobby', 'default'),
session_name=getattr(args, 'session_name', 'Python Bot'),
session_id=getattr(args, 'session_id', None),
password=getattr(args, 'password', None),
private=getattr(args, 'private', False),
insecure=getattr(args, 'insecure', False),
registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0))
)
class JoinRequest(BaseModel):
"""Request model for joining a lobby."""
lobby_id: str
session_id: str
nick: str
server_url: str
insecure: bool = False
def _default_attributes() -> Dict[str, object]:
"""Default factory for peer attributes."""
return {}
@dataclass
class Peer:
"""Represents a WebRTC peer in the session"""
session_id: str
peer_name: str
# Generic attributes bag. Values can be tracks or simple metadata.
attributes: Dict[str, object] = field(default_factory=_default_attributes)
muted: bool = False
video_on: bool = True
local: bool = False
dead: bool = False
connection: Optional['RTCPeerConnection'] = None
# Generic message payload type
MessageData = dict[str, object]
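`from_argparse()` only relies on attribute names, so the parser in `main.py` just needs matching options; a minimal sketch of such a parser (the exact flag spellings are assumptions inferred from the field names and the CLI examples above):

```python
import argparse

from voicebot.models import VoicebotArgs

parser = argparse.ArgumentParser(description="voicebot")
parser.add_argument("--mode", choices=["client", "provider"], default="client")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8788)
parser.add_argument("--reload", action="store_true")
parser.add_argument("--server-url", default="http://localhost:8000/ai-voicebot")
parser.add_argument("--lobby", default="default")
parser.add_argument("--session-name", default="Python Bot")
parser.add_argument("--session-id", default=None)
parser.add_argument("--password", default=None)
parser.add_argument("--private", action="store_true")
parser.add_argument("--insecure", action="store_true")
parser.add_argument("--registration-check-interval", type=float, default=30.0)

# argparse turns dashes into underscores, so the attributes line up with the fields.
args = VoicebotArgs.from_argparse(parser.parse_args())
```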

voicebot/session_manager.py (new file, 118 lines added)
View File

@@ -0,0 +1,118 @@
"""
Session and lobby management for voicebot.
This module handles session creation and lobby management functionality.
"""
import json
import ssl
import urllib.request
import urllib.error
import urllib.parse
import sys
import os
from pydantic import ValidationError
# Add the parent directory to sys.path to allow absolute imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import shared models
from shared.models import SessionModel, LobbyCreateResponse
from voicebot.utils import http_base_url
def create_or_get_session(
server_url: str, session_id: str | None = None, insecure: bool = False
) -> str:
"""Call GET /api/session to obtain a session_id (unless one was provided).
Uses urllib so no extra runtime deps are required.
"""
if session_id:
return session_id
http_base = http_base_url(server_url)
url = f"{http_base}/api/session"
req = urllib.request.Request(url, method="GET")
# Prepare SSL context if requested (accept self-signed certs)
ssl_ctx = None
if insecure:
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
try:
with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
body = resp.read()
data = json.loads(body)
# Validate response shape using Pydantic
try:
session = SessionModel.model_validate(data)
except ValidationError as e:
raise RuntimeError(f"Invalid session response from {url}: {e}")
sid = session.id
if not sid:
raise RuntimeError(f"No session id returned from {url}: {data}")
return sid
except urllib.error.HTTPError as e:
raise RuntimeError(f"HTTP error getting session: {e}")
except Exception as e:
raise RuntimeError(f"Error getting session: {e}")
def create_or_get_lobby(
server_url: str,
session_id: str,
lobby_name: str,
private: bool = False,
insecure: bool = False,
) -> str:
"""Call POST /api/lobby/{session_id} to create or lookup a lobby by name.
Returns the lobby id.
"""
http_base = http_base_url(server_url)
url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}"
payload = json.dumps(
{
"type": "lobby_create",
"data": {"name": lobby_name, "private": private},
}
).encode("utf-8")
req = urllib.request.Request(
url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
)
# Prepare SSL context if requested (accept self-signed certs)
ssl_ctx = None
if insecure:
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
try:
with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
body = resp.read()
data = json.loads(body)
# Expect shape: { "type": "lobby_created", "data": {"id":..., ...}}
try:
lobby_resp = LobbyCreateResponse.model_validate(data)
except ValidationError as e:
raise RuntimeError(f"Invalid lobby response from {url}: {e}")
lobby_id = lobby_resp.data.id
if not lobby_id:
raise RuntimeError(f"No lobby id returned from {url}: {data}")
return lobby_id
except urllib.error.HTTPError as e:
# Try to include response body for debugging
try:
body = e.read()
msg = body.decode("utf-8", errors="ignore")
except Exception:
msg = str(e)
raise RuntimeError(f"HTTP error creating lobby: {msg}")
except Exception as e:
raise RuntimeError(f"Error creating lobby: {e}")

voicebot/utils.py (new file, 57 lines added)
View File

@@ -0,0 +1,57 @@
"""
Utility functions for voicebot.
This module provides common utility functions used throughout the application.
"""
import ssl
def http_base_url(server_url: str) -> str:
"""Convert ws:// or wss:// to http(s) and ensure no trailing slash."""
if server_url.startswith("ws://"):
return "http://" + server_url[len("ws://") :].rstrip("/")
if server_url.startswith("wss://"):
return "https://" + server_url[len("wss://") :].rstrip("/")
return server_url.rstrip("/")
def ws_url(server_url: str) -> str:
"""Convert http(s) to ws(s) if needed."""
if server_url.startswith("http://"):
return "ws://" + server_url[len("http://") :].rstrip("/")
if server_url.startswith("https://"):
return "wss://" + server_url[len("https://") :].rstrip("/")
return server_url.rstrip("/")
def create_ssl_context(insecure: bool = False) -> ssl.SSLContext | None:
"""Create SSL context for connections."""
if not insecure:
return None
ssl_ctx = ssl.create_default_context()
ssl_ctx.check_hostname = False
ssl_ctx.verify_mode = ssl.CERT_NONE
return ssl_ctx
def log_network_info():
"""Log network information for debugging."""
from logger import logger
try:
import socket
import subprocess
hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname)
logger.info(f"Container hostname: {hostname}, local IP: {local_ip}")
# Get all network interfaces
result = subprocess.run(
["ip", "addr", "show"], capture_output=True, text=True
)
logger.info(f"Network interfaces:\n{result.stdout}")
except Exception as e:
logger.warning(f"Could not get network info: {e}")

View File

@@ -0,0 +1,894 @@
"""
WebRTC signaling client for voicebot.
This module provides WebRTC signaling server communication and peer connection management.
Synthetic audio/video track creation is handled by the bots.synthetic_media module.
"""
from __future__ import annotations
import asyncio
import json
import websockets
import time
import re
from typing import (
Dict,
Optional,
Callable,
Awaitable,
Protocol,
AsyncIterator,
cast,
)
# Add the parent directory to sys.path to allow absolute imports
from pydantic import ValidationError
from aiortc import (
RTCPeerConnection,
RTCSessionDescription,
RTCIceCandidate,
MediaStreamTrack,
)
from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer
from aiortc.sdp import candidate_from_sdp
# Import shared models
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from shared.models import (
WebSocketMessageModel,
JoinStatusModel,
UserJoinedModel,
LobbyStateModel,
UpdateNameModel,
AddPeerModel,
RemovePeerModel,
SessionDescriptionModel,
IceCandidateModel,
ICECandidateDictModel,
SessionDescriptionTypedModel,
)
from logger import logger
from voicebot.bots.synthetic_media import create_synthetic_tracks
from voicebot.models import Peer, MessageData
from voicebot.utils import create_ssl_context, log_network_info
class WebSocketProtocol(Protocol):
def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
def __aiter__(self) -> AsyncIterator[str]: ...
class WebRTCSignalingClient:
"""
WebRTC signaling client that communicates with the FastAPI signaling server.
Handles peer-to-peer connection establishment and media streaming.
"""
def __init__(
self,
server_url: str,
lobby_id: str,
session_id: str,
session_name: str,
insecure: bool = False,
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
registration_check_interval: float = 30.0,
):
self.server_url = server_url
self.lobby_id = lobby_id
self.session_id = session_id
self.session_name = session_name
self.insecure = insecure
# Optional factory to create local media tracks for this client (bot provided)
self.create_tracks = create_tracks
# WebSocket client protocol instance (typed as object to avoid Any)
self.websocket: Optional[object] = None
# Optional password to register or takeover a name
self.name_password: Optional[str] = None
self.peers: dict[str, Peer] = {}
self.peer_connections: dict[str, RTCPeerConnection] = {}
self.local_tracks: dict[str, MediaStreamTrack] = {}
# State management
self.is_negotiating: dict[str, bool] = {}
self.making_offer: dict[str, bool] = {}
self.initiated_offer: set[str] = set()
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
# Registration status tracking
self.is_registered: bool = False
self.last_registration_check: float = 0
self.registration_check_interval: float = registration_check_interval
self.registration_check_task: Optional[asyncio.Task[None]] = None
# Event callbacks
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
self.on_track_received: Optional[
Callable[[Peer, MediaStreamTrack], Awaitable[None]]
] = None
async def connect(self):
"""Connect to the signaling server"""
ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
logger.info(f"Connecting to signaling server: {ws_url}")
# Log network information for debugging
log_network_info()
try:
# If insecure (self-signed certs), create an SSL context for the websocket
ws_ssl = create_ssl_context(self.insecure)
logger.info(
f"Attempting websocket connection to {ws_url} with ssl={bool(ws_ssl)}"
)
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
logger.info("Connected to signaling server")
# Set up local media
await self._setup_local_media()
# Set name and join lobby
name_payload: MessageData = {"name": self.session_name}
if self.name_password:
name_payload["password"] = self.name_password
logger.info(f"Sending set_name: {name_payload}")
await self._send_message("set_name", name_payload)
logger.info("Sending join message")
await self._send_message("join", {})
# Mark as registered after successful join
self.is_registered = True
self.last_registration_check = time.time()
# Start periodic registration check
self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
# Start message handling
logger.info("Starting message handler loop")
try:
await self._handle_messages()
except Exception as e:
logger.error(f"Message handling stopped: {e}")
self.is_registered = False
raise
except Exception as e:
logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
raise
async def _periodic_registration_check(self):
"""Periodically check registration status and re-register if needed"""
while True:
try:
await asyncio.sleep(self.registration_check_interval)
current_time = time.time()
if current_time - self.last_registration_check < self.registration_check_interval:
continue
# Check if we're still connected and registered
if not await self._check_registration_status():
logger.warning("Registration check failed, attempting to re-register")
await self._re_register()
self.last_registration_check = current_time
except asyncio.CancelledError:
logger.info("Registration check task cancelled")
break
except Exception as e:
logger.error(f"Error in periodic registration check: {e}", exc_info=True)
# Continue checking even if one iteration fails
continue
async def _check_registration_status(self) -> bool:
"""Check if the voicebot is still registered with the server"""
try:
# First check if websocket is still connected
if not self.websocket:
logger.warning("WebSocket connection lost")
return False
# Try to send a ping/status check message to verify connection
# We'll use a simple status message to check connectivity
try:
await self._send_message("status_check", {"timestamp": time.time()})
logger.debug("Registration status check sent")
return True
except Exception as e:
logger.warning(f"Failed to send status check: {e}")
return False
except Exception as e:
logger.error(f"Error checking registration status: {e}")
return False
async def _re_register(self):
"""Attempt to re-register with the server"""
try:
logger.info("Attempting to re-register with server")
# Mark as not registered during re-registration attempt
self.is_registered = False
# Try to reconnect the websocket if it's lost
if not self.websocket:
logger.info("WebSocket lost, attempting to reconnect")
await self._reconnect_websocket()
# Re-send name and join messages
name_payload: MessageData = {"name": self.session_name}
if self.name_password:
name_payload["password"] = self.name_password
logger.info("Re-sending set_name message")
await self._send_message("set_name", name_payload)
logger.info("Re-sending join message")
await self._send_message("join", {})
# Mark as registered after successful re-join
self.is_registered = True
self.last_registration_check = time.time()
logger.info("Successfully re-registered with server")
except Exception as e:
logger.error(f"Failed to re-register with server: {e}", exc_info=True)
# Will try again on next check interval
async def _reconnect_websocket(self):
"""Reconnect the WebSocket connection"""
try:
# Close existing connection if any
if self.websocket:
try:
ws = cast(WebSocketProtocol, self.websocket)
await ws.close()
except Exception:
pass
self.websocket = None
# Reconnect
ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
# If insecure (self-signed certs), create an SSL context for the websocket
ws_ssl = create_ssl_context(self.insecure)
logger.info(f"Reconnecting to signaling server: {ws_url}")
self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
logger.info("Successfully reconnected to signaling server")
except Exception as e:
logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
raise
async def disconnect(self):
"""Disconnect from signaling server and cleanup"""
# Cancel the registration check task
if self.registration_check_task and not self.registration_check_task.done():
self.registration_check_task.cancel()
try:
await self.registration_check_task
except asyncio.CancelledError:
pass
self.registration_check_task = None
if self.websocket:
ws = cast(WebSocketProtocol, self.websocket)
await ws.close()
# Close all peer connections
for pc in self.peer_connections.values():
await pc.close()
# Stop local tracks
for track in self.local_tracks.values():
track.stop()
# Reset registration status
self.is_registered = False
logger.info("Disconnected from signaling server")
async def _setup_local_media(self):
"""Create local media tracks"""
# If a bot provided a create_tracks callable, use it to create tracks.
# Otherwise, use default synthetic tracks.
try:
if self.create_tracks:
tracks = self.create_tracks(self.session_name)
self.local_tracks.update(tracks)
else:
# Default fallback to synthetic tracks
tracks = create_synthetic_tracks(self.session_name)
self.local_tracks.update(tracks)
except Exception:
logger.exception("Failed to create local tracks using bot factory")
# Add local peer to peers dict
local_peer = Peer(
session_id=self.session_id,
peer_name=self.session_name,
local=True,
attributes={"tracks": self.local_tracks},
)
self.peers[self.session_id] = local_peer
logger.info("Local media tracks created")
async def _send_message(
self, message_type: str, data: Optional[MessageData] = None
):
"""Send message to signaling server"""
if not self.websocket:
logger.error("No websocket connection")
return
# Build message with explicit type to avoid type narrowing
message: dict[str, object] = {"type": message_type}
if data is not None:
message["data"] = data
ws = cast(WebSocketProtocol, self.websocket)
try:
logger.debug(f"_send_message: Sending {message_type} with data: {data}")
await ws.send(json.dumps(message))
logger.debug(f"_send_message: Sent message: {message_type}")
except Exception as e:
logger.error(
f"_send_message: Failed to send {message_type}: {e}", exc_info=True
)
async def _handle_messages(self):
"""Handle incoming messages from signaling server"""
try:
ws = cast(WebSocketProtocol, self.websocket)
async for message in ws:
logger.debug(f"_handle_messages: Received raw message: {message}")
try:
data = cast(MessageData, json.loads(message))
except Exception as e:
logger.error(
f"_handle_messages: Failed to parse message: {e}", exc_info=True
)
continue
await self._process_message(data)
except websockets.exceptions.ConnectionClosed as e:
logger.warning(f"WebSocket connection closed: {e}")
self.is_registered = False
# The periodic registration check will detect this and attempt reconnection
except Exception as e:
logger.error(f"Error handling messages: {e}", exc_info=True)
self.is_registered = False
async def _process_message(self, message: MessageData):
"""Process incoming signaling messages"""
try:
# Validate the base message structure first
validated_message = WebSocketMessageModel.model_validate(message)
msg_type = validated_message.type
data = validated_message.data
except ValidationError as e:
logger.error(f"Invalid message structure: {e}", exc_info=True)
return
logger.debug(
f"_process_message: Received message type: {msg_type} with data: {data}"
)
if msg_type == "addPeer":
try:
validated = AddPeerModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid addPeer payload: {e}", exc_info=True)
return
await self._handle_add_peer(validated)
elif msg_type == "removePeer":
try:
validated = RemovePeerModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid removePeer payload: {e}", exc_info=True)
return
await self._handle_remove_peer(validated)
elif msg_type == "sessionDescription":
try:
validated = SessionDescriptionModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid sessionDescription payload: {e}", exc_info=True)
return
await self._handle_session_description(validated)
elif msg_type == "iceCandidate":
try:
validated = IceCandidateModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid iceCandidate payload: {e}", exc_info=True)
return
await self._handle_ice_candidate(validated)
elif msg_type == "join_status":
try:
validated = JoinStatusModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid join_status payload: {e}", exc_info=True)
return
logger.info(f"Join status: {validated.status} - {validated.message}")
elif msg_type == "user_joined":
try:
validated = UserJoinedModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid user_joined payload: {e}", exc_info=True)
return
logger.info(
f"User joined: {validated.name} (session: {validated.session_id})"
)
elif msg_type == "lobby_state":
try:
validated = LobbyStateModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid lobby_state payload: {e}", exc_info=True)
return
participants = validated.participants
logger.info(f"Lobby state updated: {len(participants)} participants")
elif msg_type == "update_name":
try:
validated = UpdateNameModel.model_validate(data)
except ValidationError as e:
logger.error(f"Invalid update payload: {e}", exc_info=True)
return
logger.info(f"Received update message: {validated}")
else:
logger.info(f"Unhandled message type: {msg_type} with data: {data}")
async def _handle_add_peer(self, data: AddPeerModel):
"""Handle addPeer message - create new peer connection"""
peer_id = data.peer_id
peer_name = data.peer_name
should_create_offer = data.should_create_offer
logger.info(
f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
)
logger.debug(
f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
)
# Check if peer already exists
if peer_id in self.peer_connections:
pc = self.peer_connections[peer_id]
logger.debug(
f"_handle_add_peer: Existing connection state: {pc.connectionState}"
)
if pc.connectionState in ["new", "connected", "connecting"]:
logger.info(f"Peer connection already exists for {peer_name}")
return
else:
# Clean up stale connection
logger.debug(
f"_handle_add_peer: Closing stale connection for {peer_name}"
)
await pc.close()
del self.peer_connections[peer_id]
# Create new peer
peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
self.peers[peer_id] = peer
# Create RTCPeerConnection
config = RTCConfiguration(
iceServers=[
RTCIceServer(urls="stun:ketrenos.com:3478"),
RTCIceServer(
urls="turns:ketrenos.com:5349",
username="ketra",
credential="ketran",
),
# Add Google's public STUN server as fallback
RTCIceServer(urls="stun:stun.l.google.com:19302"),
],
)
logger.debug(
f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
)
pc = RTCPeerConnection(configuration=config)
# Add ICE gathering state change handler
def on_ice_gathering_state_change() -> None:
logger.info(f"ICE gathering state: {pc.iceGatheringState}")
if pc.iceGatheringState == "complete":
logger.info(
f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
)
pc.on("icegatheringstatechange")(on_ice_gathering_state_change)
# Add connection state change handler
def on_connection_state_change() -> None:
logger.info(f"Connection state: {pc.connectionState}")
pc.on("connectionstatechange")(on_connection_state_change)
self.peer_connections[peer_id] = pc
peer.connection = pc
# Set up event handlers
def on_track(track: MediaStreamTrack) -> None:
logger.info(f"Received {track.kind} track from {peer_name}")
logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
peer.attributes[f"{track.kind}_track"] = track
if self.on_track_received:
asyncio.ensure_future(self.on_track_received(peer, track))
pc.on("track")(on_track)
def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
logger.info(f"on_ice_candidate: {candidate}")
logger.info(
f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}"
)
if not candidate:
logger.info(
f"on_ice_candidate: End of candidates signal for {peer_name}"
)
return
# Raw SDP fragment for the candidate
raw = getattr(candidate, "candidate", None)
# Try to infer candidate type from the SDP string (host/srflx/relay/prflx)
def _parse_type(s: Optional[str]) -> str:
if not s:
return "eoc"
m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s)
return m.group(1) if m else "unknown"
cand_type = _parse_type(raw)
protocol = getattr(candidate, "protocol", "unknown")
logger.info(
f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
)
candidate_model = ICECandidateDictModel(
candidate=raw,
sdpMid=getattr(candidate, "sdpMid", None),
sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
)
payload_model = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
)
logger.info(
f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}"
)
asyncio.ensure_future(
self._send_message("relayICECandidate", payload_model.model_dump())
)
pc.on("icecandidate")(on_ice_candidate)
# Add local tracks
for track in self.local_tracks.values():
logger.debug(
f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
)
pc.addTrack(track)
# Create offer if needed
if should_create_offer:
await self._create_and_send_offer(peer_id, peer_name, pc)
if self.on_peer_added:
await self.on_peer_added(peer)
async def _create_and_send_offer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
"""Create and send an offer to a peer"""
self.initiated_offer.add(peer_id)
self.making_offer[peer_id] = True
self.is_negotiating[peer_id] = True
try:
logger.debug(f"_handle_add_peer: Creating offer for {peer_name}")
offer = await pc.createOffer()
logger.debug(
f"_handle_add_peer: Offer created for {peer_name}: {offer}"
)
await pc.setLocalDescription(offer)
logger.debug(f"_handle_add_peer: Local description set for {peer_name}")
# WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
# Use Method 2: Complete SDP approach to extract ICE candidates
logger.debug(
f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}"
)
while pc.iceGatheringState != "complete":
await asyncio.sleep(0.1)
logger.debug(
f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
)
await self._extract_and_send_candidates(peer_id, peer_name, pc)
session_desc_typed = SessionDescriptionTypedModel(
type=offer.type, sdp=offer.sdp
)
session_desc_model = SessionDescriptionModel(
peer_id=peer_id,
peer_name=peer_name,
session_description=session_desc_typed,
)
await self._send_message(
"relaySessionDescription",
session_desc_model.model_dump(),
)
logger.info(f"Offer sent to {peer_name}")
except Exception as e:
logger.error(
f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
)
finally:
self.making_offer[peer_id] = False
async def _extract_and_send_candidates(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
"""Extract ICE candidates from SDP and send them"""
# Parse ICE candidates from the local SDP
sdp_lines = pc.localDescription.sdp.split("\n")
candidate_lines = [
line for line in sdp_lines if line.startswith("a=candidate:")
]
# Track which media section we're in to determine sdpMid and sdpMLineIndex
current_media_index = -1
current_mid = None
for line in sdp_lines:
if line.startswith("m="): # Media section
current_media_index += 1
elif line.startswith("a=mid:"): # Media ID
current_mid = line.split(":", 1)[1].strip()
elif line.startswith("a=candidate:"):
candidate_sdp = line[2:] # Remove 'a=' prefix
candidate_model = ICECandidateDictModel(
candidate=candidate_sdp,
sdpMid=current_mid,
sdpMLineIndex=current_media_index,
)
payload_candidate = IceCandidateModel(
peer_id=peer_id,
peer_name=peer_name,
candidate=candidate_model,
)
logger.debug(
f"_extract_and_send_candidates: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..."
)
await self._send_message(
"relayICECandidate", payload_candidate.model_dump()
)
# Send end-of-candidates signal (empty candidate)
end_candidate_model = ICECandidateDictModel(
candidate="",
sdpMid=None,
sdpMLineIndex=None,
)
payload_end = IceCandidateModel(
peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
)
logger.debug(
f"_extract_and_send_candidates: Sending end-of-candidates signal for {peer_name}"
)
await self._send_message("relayICECandidate", payload_end.model_dump())
logger.debug(
f"_extract_and_send_candidates: Sent {len(candidate_lines)} ICE candidates to {peer_name}"
)
async def _handle_remove_peer(self, data: RemovePeerModel):
"""Handle removePeer message"""
peer_id = data.peer_id
peer_name = data.peer_name
logger.info(f"Removing peer: {peer_name}")
# Close peer connection
if peer_id in self.peer_connections:
pc = self.peer_connections[peer_id]
await pc.close()
del self.peer_connections[peer_id]
# Clean up state
self.is_negotiating.pop(peer_id, None)
self.making_offer.pop(peer_id, None)
self.initiated_offer.discard(peer_id)
self.pending_ice_candidates.pop(peer_id, None)
# Remove peer
peer = self.peers.pop(peer_id, None)
if peer and self.on_peer_removed:
await self.on_peer_removed(peer)
async def _handle_session_description(self, data: SessionDescriptionModel):
"""Handle sessionDescription message"""
peer_id = data.peer_id
peer_name = data.peer_name
session_description = data.session_description.model_dump()
logger.info(f"Received {session_description['type']} from {peer_name}")
pc = self.peer_connections.get(peer_id)
if not pc:
logger.error(f"No peer connection for {peer_name}")
return
desc = RTCSessionDescription(
sdp=session_description["sdp"], type=session_description["type"]
)
# Handle offer collision (polite peer pattern)
making_offer = self.making_offer.get(peer_id, False)
offer_collision = desc.type == "offer" and (
making_offer or pc.signalingState != "stable"
)
we_initiated = peer_id in self.initiated_offer
ignore_offer = we_initiated and offer_collision
if ignore_offer:
logger.info(f"Ignoring offer from {peer_name} due to collision")
return
try:
await pc.setRemoteDescription(desc)
self.is_negotiating[peer_id] = False
logger.info(f"Remote description set for {peer_name}")
# Process queued ICE candidates
pending_candidates = self.pending_ice_candidates.pop(peer_id, [])
for candidate_data in pending_candidates:
# candidate_data is an ICECandidateDictModel Pydantic model
cand = candidate_data.candidate
# handle end-of-candidates marker
if not cand:
await pc.addIceCandidate(None)
logger.info(f"Added queued end-of-candidates for {peer_name}")
continue
# cand may be the full "candidate:..." string or the inner SDP part
if cand and cand.startswith("candidate:"):
sdp_part = cand.split(":", 1)[1]
else:
sdp_part = cand
try:
rtc_candidate = candidate_from_sdp(sdp_part)
rtc_candidate.sdpMid = candidate_data.sdpMid
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
await pc.addIceCandidate(rtc_candidate)
logger.info(f"Added queued ICE candidate for {peer_name}")
except Exception as e:
logger.error(
f"Failed to add queued ICE candidate for {peer_name}: {e}"
)
except Exception as e:
logger.error(f"Failed to set remote description for {peer_name}: {e}")
return
# Create answer if this was an offer
if session_description["type"] == "offer":
await self._create_and_send_answer(peer_id, peer_name, pc)
async def _create_and_send_answer(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
"""Create and send an answer to a peer"""
try:
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
# WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
# Use Method 2: Complete SDP approach to extract ICE candidates
logger.debug(
f"_create_and_send_answer: Waiting for ICE gathering to complete for {peer_name} (answer)"
)
while pc.iceGatheringState != "complete":
await asyncio.sleep(0.1)
logger.debug(
f"_create_and_send_answer: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
)
await self._extract_and_send_candidates(peer_id, peer_name, pc)
session_desc_typed = SessionDescriptionTypedModel(
type=answer.type, sdp=answer.sdp
)
session_desc_model = SessionDescriptionModel(
peer_id=peer_id,
peer_name=peer_name,
session_description=session_desc_typed,
)
await self._send_message(
"relaySessionDescription",
session_desc_model.model_dump(),
)
logger.info(f"Answer sent to {peer_name}")
except Exception as e:
logger.error(f"Failed to create/send answer to {peer_name}: {e}")
async def _handle_ice_candidate(self, data: IceCandidateModel):
"""Handle iceCandidate message"""
peer_id = data.peer_id
peer_name = data.peer_name
candidate_data = data.candidate
logger.info(f"Received ICE candidate from {peer_name}")
pc = self.peer_connections.get(peer_id)
if not pc:
logger.error(f"No peer connection for {peer_name}")
return
# Queue candidate if remote description not set
if not pc.remoteDescription:
logger.info(
f"Remote description not set, queuing ICE candidate for {peer_name}"
)
if peer_id not in self.pending_ice_candidates:
self.pending_ice_candidates[peer_id] = []
# candidate_data is an ICECandidateDictModel Pydantic model
self.pending_ice_candidates[peer_id].append(candidate_data)
return
try:
cand = candidate_data.candidate
if not cand:
# end-of-candidates
await pc.addIceCandidate(None)
logger.info(f"End-of-candidates added for {peer_name}")
return
if cand and cand.startswith("candidate:"):
sdp_part = cand.split(":", 1)[1]
else:
sdp_part = cand
# Detect type for logging
try:
m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
cand_type = m.group(1) if m else "unknown"
except Exception:
cand_type = "unknown"
try:
rtc_candidate = candidate_from_sdp(sdp_part)
rtc_candidate.sdpMid = candidate_data.sdpMid
rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
# aiortc expects an object with attributes (RTCIceCandidate)
await pc.addIceCandidate(rtc_candidate)
logger.info(f"ICE candidate added for {peer_name}: type={cand_type}")
except Exception as e:
logger.error(
f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'",
exc_info=True,
)
except Exception as e:
logger.error(
f"Unexpected error handling ICE candidate for {peer_name}: {e}",
exc_info=True,
)
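Pulling the pieces together, wiring up the signaling client outside of `client_main.py` looks roughly like this; a sketch where the server URL, lobby and session ids are placeholders and the synthetic tracks stand in for a real bot's `create_tracks` factory:

```python
import asyncio

from aiortc import MediaStreamTrack

from voicebot.bots.synthetic_media import create_synthetic_tracks
from voicebot.models import Peer
from voicebot.webrtc_signaling import WebRTCSignalingClient


async def run() -> None:
    client = WebRTCSignalingClient(
        server_url="wss://example.com/ai-voicebot",  # ws(s) base URL, placeholder
        lobby_id="lobby-id",                         # placeholder ids
        session_id="session-id",
        session_name="Python Bot",
        insecure=True,
        create_tracks=create_synthetic_tracks,       # bot-provided track factory
    )

    async def on_track(peer: Peer, track: MediaStreamTrack) -> None:
        print(f"got {track.kind} track from {peer.peer_name}")

    client.on_track_received = on_track
    try:
        await client.connect()   # joins the lobby and runs the message loop
    finally:
        await client.disconnect()


asyncio.run(run())
```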