Refactored voicebot/main.py
This commit is contained in:
parent
2e91a4eadb
commit
b916db243b
@ -24,6 +24,7 @@ RUN apt-get update \
|
||||
libgl1 \
|
||||
libglib2.0-0t64 \
|
||||
git \
|
||||
iproute2 \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/{apt,dpkg,cache,log}
|
||||
|
||||
|
@ -523,7 +523,7 @@ class Session:
|
||||
try:
|
||||
if os.path.exists(cls._save_file + ".tmp"):
|
||||
os.remove(cls._save_file + ".tmp")
|
||||
except:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@ -1971,7 +1971,7 @@ async def lobby_join(
|
||||
)
|
||||
try:
|
||||
await websocket.close()
|
||||
except:
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
|
82
voicebot/REFACTORING_SUMMARY.md
Normal file
82
voicebot/REFACTORING_SUMMARY.md
Normal file
@ -0,0 +1,82 @@
|
||||
# Voicebot Module Refactoring
|
||||
|
||||
The voicebot/main.py functionality has been broken down into individual Python files for better organization and maintainability:
|
||||
|
||||
## New File Structure
|
||||
|
||||
### Core Modules
|
||||
|
||||
1. **`models.py`** - Data models and configuration
|
||||
- `VoicebotArgs` - Pydantic model for CLI arguments and configuration
|
||||
- `VoicebotMode` - Enum for client/provider modes
|
||||
- `Peer` - WebRTC peer representation
|
||||
- `JoinRequest` - Request model for joining lobbies
|
||||
- `MessageData` - Type alias for message payloads
|
||||
|
||||
2. **`webrtc_signaling.py`** - WebRTC signaling client functionality
|
||||
- `WebRTCSignalingClient` - Main WebRTC signaling client class
|
||||
- Handles peer connection management, ICE candidates, session descriptions
|
||||
- Registration status tracking and reconnection logic
|
||||
- Message processing and event handling
|
||||
|
||||
3. **`session_manager.py`** - Session and lobby management
|
||||
- `create_or_get_session()` - Session creation/retrieval
|
||||
- `create_or_get_lobby()` - Lobby creation/retrieval
|
||||
- HTTP API communication utilities
|
||||
|
||||
4. **`bot_orchestrator.py`** - FastAPI bot orchestration service
|
||||
- Bot discovery and management
|
||||
- FastAPI endpoints for bot operations
|
||||
- Provider registration with main server
|
||||
- Bot instance lifecycle management
|
||||
|
||||
5. **`client_main.py`** - Main client logic
|
||||
- `main_with_args()` - Core client functionality
|
||||
- `start_client_with_reload()` - Development mode with reload
|
||||
- Event handlers for peer and track management
|
||||
|
||||
6. **`client_app.py`** - Client FastAPI application
|
||||
- `create_client_app()` - Creates FastAPI app for client mode
|
||||
- Health check and status endpoints
|
||||
- Process isolation and locking
|
||||
|
||||
7. **`utils.py`** - Utility functions
|
||||
- URL conversion utilities (`http_base_url`, `ws_url`)
|
||||
- SSL context creation
|
||||
- Network information logging
|
||||
|
||||
8. **`main.py`** - Main orchestration and entry point
|
||||
- Command-line argument parsing
|
||||
- Mode selection (client vs provider)
|
||||
- Entry points for both modes
|
||||
|
||||
### Key Improvements
|
||||
|
||||
- **Separation of Concerns**: Each file handles specific functionality
|
||||
- **Better Maintainability**: Smaller, focused modules are easier to understand and modify
|
||||
- **Reduced Coupling**: Dependencies between components are more explicit
|
||||
- **Type Safety**: Proper type hints and Pydantic models throughout
|
||||
- **Error Handling**: Centralized error handling and logging
|
||||
|
||||
### Usage
|
||||
|
||||
The refactored code maintains the same CLI interface:
|
||||
|
||||
```bash
|
||||
# Client mode
|
||||
python voicebot/main.py --mode client --server-url http://localhost:8000/ai-voicebot
|
||||
|
||||
# Provider mode
|
||||
python voicebot/main.py --mode provider --host 0.0.0.0 --port 8788
|
||||
```
|
||||
|
||||
### Import Structure
|
||||
|
||||
```python
|
||||
from voicebot import VoicebotArgs, VoicebotMode, WebRTCSignalingClient
|
||||
from voicebot.models import Peer, JoinRequest
|
||||
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
|
||||
from voicebot.client_main import main_with_args
|
||||
```
|
||||
|
||||
The original `main_old.py` contains the monolithic implementation for reference.
|
30
voicebot/__init__.py
Normal file
30
voicebot/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
Voicebot package.
|
||||
|
||||
This package provides WebRTC signaling client functionality and bot orchestration
|
||||
for AI voicebots.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from voicebot.models import VoicebotArgs, VoicebotMode, Peer, JoinRequest
|
||||
from voicebot.webrtc_signaling import WebRTCSignalingClient
|
||||
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
|
||||
from voicebot.client_main import main_with_args
|
||||
from voicebot.bot_orchestrator import app as bot_orchestrator_app
|
||||
|
||||
__all__ = [
|
||||
'VoicebotArgs',
|
||||
'VoicebotMode',
|
||||
'Peer',
|
||||
'JoinRequest',
|
||||
'WebRTCSignalingClient',
|
||||
'create_or_get_session',
|
||||
'create_or_get_lobby',
|
||||
'main_with_args',
|
||||
'bot_orchestrator_app',
|
||||
]
|
238
voicebot/bot_orchestrator.py
Normal file
238
voicebot/bot_orchestrator.py
Normal file
@ -0,0 +1,238 @@
|
||||
"""
|
||||
Bot orchestrator FastAPI service.
|
||||
|
||||
This module provides the FastAPI service for bot discovery and orchestration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid
|
||||
import importlib
|
||||
import pkgutil
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
from logger import logger
|
||||
from voicebot.models import JoinRequest
|
||||
from voicebot.webrtc_signaling import WebRTCSignalingClient
|
||||
|
||||
|
||||
app = FastAPI(title="voicebot-bot-orchestrator")
|
||||
|
||||
# Lightweight in-memory registry of running bot clients
|
||||
registry: Dict[str, WebRTCSignalingClient] = {}
|
||||
|
||||
|
||||
def discover_bots() -> Dict[str, Dict[str, Any]]:
    """Scan the ``voicebot.bots`` package and describe each bot module found.

    Every submodule of ``voicebot.bots`` is imported. A module is reported
    only when it defines a callable ``agent_info`` whose return value is
    truthy. A callable ``create_agent_tracks`` is recorded alongside it when
    present, otherwise ``None`` is stored.

    Returns:
        A mapping from bot name (``info["name"]``, falling back to the module
        name) to a dict with the keys ``"module"``, ``"info"`` and
        ``"create_tracks"``. Import or ``agent_info()`` failures are logged
        and the offending module is left out; the result is empty when the
        package itself cannot be imported.
    """
    discovered: Dict[str, Dict[str, Any]] = {}

    try:
        bots_package = importlib.import_module("voicebot.bots")
        search_path = bots_package.__path__
    except Exception:
        logger.exception("Failed to import voicebot.bots package")
        return discovered

    for _finder, module_name, _ispkg in pkgutil.iter_modules(search_path):
        try:
            module = importlib.import_module(f"voicebot.bots.{module_name}")
        except Exception:
            # One broken bot must not hide the others.
            logger.exception("Failed to import voicebot.bots.%s", module_name)
            continue

        descriptor = None
        info_fn = getattr(module, "agent_info", None)
        if callable(info_fn):
            try:
                descriptor = info_fn()
            except Exception:
                logger.exception("agent_info() failed for %s", module_name)

        tracks_factory = getattr(module, "create_agent_tracks", None)
        if not callable(tracks_factory):
            tracks_factory = None

        if not descriptor:
            continue

        discovered[descriptor.get("name", module_name)] = {
            "module": module_name,
            "info": descriptor,
            "create_tracks": tracks_factory,
        }

    return discovered
|
||||
|
||||
|
||||
@app.get("/bots")
def list_bots() -> Dict[str, Any]:
    """Return the ``info`` descriptor of every discovered bot, keyed by bot name."""
    listing: Dict[str, Any] = {}
    for bot_name, entry in discover_bots().items():
        listing[bot_name] = entry["info"]
    return listing
|
||||
|
||||
|
||||
# Strong references to in-flight bot client tasks, keyed by run id. The event
# loop only keeps weak references to tasks, so an unreferenced task could be
# garbage-collected before it finishes.
_run_tasks: Dict[str, "asyncio.Task[None]"] = {}


@app.post("/bots/{bot_name}/join")
async def bot_join(bot_name: str, req: JoinRequest):
    """Start the named bot and have it join the lobby described by ``req``.

    A ``WebRTCSignalingClient`` is built from the request and run as a
    background task on the event loop that is serving this request.

    Args:
        bot_name: Name of the bot as reported by ``discover_bots()``.
        req: Join parameters (server URL, lobby id, session id, nick,
            insecure flag).

    Returns:
        A dict with ``status``, ``bot`` and the generated ``run_id``.

    Raises:
        HTTPException: 404 when no bot with that name is discovered.
    """
    bot = discover_bots().get(bot_name)
    if not bot:
        raise HTTPException(status_code=404, detail="Bot not found")

    client = WebRTCSignalingClient(
        server_url=req.server_url,
        lobby_id=req.lobby_id,
        session_id=req.session_id,
        session_name=req.nick,
        insecure=req.insecure,
        create_tracks=bot.get("create_tracks"),
    )

    run_id = str(uuid.uuid4())
    # Register before scheduling so the run can be listed or stopped
    # as soon as this response has been sent.
    registry[run_id] = client

    async def run_client() -> None:
        try:
            await client.connect()
        except Exception:
            logger.exception("Bot client failed for run %s", run_id)
        finally:
            registry.pop(run_id, None)

    # The previous implementation called loop.run_until_complete() from a new
    # thread on the loop that is already running this handler; asyncio rejects
    # that with "This event loop is already running", so the client never
    # connected. Scheduling a task on the current loop is the supported way.
    task = asyncio.create_task(run_client())
    _run_tasks[run_id] = task
    task.add_done_callback(lambda _finished: _run_tasks.pop(run_id, None))

    return {"status": "started", "bot": bot_name, "run_id": run_id}
|
||||
|
||||
|
||||
@app.post("/bots/runs/{run_id}/stop")
async def stop_run(run_id: str):
    """Disconnect the bot client registered under ``run_id`` and forget it.

    Raises:
        HTTPException: 404 when the run id is not in the registry, 500 when
            ``disconnect()`` raises (the entry is then left in the registry).
    """
    running_client = registry.get(run_id)
    if not running_client:
        raise HTTPException(status_code=404, detail="Run not found")

    try:
        await running_client.disconnect()
    except Exception:
        logger.exception("Failed to stop run %s", run_id)
        raise HTTPException(status_code=500, detail="Failed to stop run")
    else:
        registry.pop(run_id, None)
        return {"status": "stopped", "run_id": run_id}
|
||||
|
||||
|
||||
@app.get("/bots/runs")
def list_runs() -> Dict[str, Any]:
    """Report every registered bot run with its session id and session name."""
    active_runs = []
    for run_id, client in registry.items():
        active_runs.append(
            {
                "run_id": run_id,
                "session_id": client.session_id,
                "session_name": client.session_name,
            }
        )
    return {"runs": active_runs}
|
||||
|
||||
|
||||
def start_bot_api(host: str = "0.0.0.0", port: int = 8788):
    """Run the bot orchestration FastAPI app under uvicorn.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
    """
    bind_options = {"host": host, "port": port}
    uvicorn.run(app, **bind_options)
|
||||
|
||||
|
||||
async def register_with_server(server_url: str, voicebot_url: str, insecure: bool = False) -> str:
    """Register this voicebot instance as a bot provider with the main server.

    Args:
        server_url: Base URL of the main server; a trailing slash is tolerated.
        voicebot_url: URL under which this provider is reachable; a trailing
            slash is stripped before it is sent.
        insecure: When true, TLS certificate verification is disabled.

    Returns:
        The ``provider_id`` field of the server's JSON response.
        NOTE(review): this is ``None`` if the server omits the field, despite
        the ``str`` annotation - confirm the server contract.

    Raises:
        RuntimeError: The server answered with a status other than 200.
        Exception: Any transport or decoding error is logged and re-raised.
    """
    try:
        # Import httpx locally to avoid dependency issues
        import httpx

        payload = {
            "base_url": voicebot_url.rstrip('/'),
            "name": "voicebot-provider",
            "description": "AI voicebot provider with speech recognition and synthetic media capabilities"
        }

        # Normalise like voicebot_url above so "http://host/" does not
        # produce "http://host//api/...".
        register_url = f"{server_url.rstrip('/')}/api/bots/providers/register"

        async with httpx.AsyncClient(verify=not insecure) as client:
            response = await client.post(register_url, json=payload, timeout=10.0)

        if response.status_code == 200:
            provider_id = response.json().get("provider_id")
            logger.info(f"Successfully registered with server as provider: {provider_id}")
            return provider_id

    except Exception as e:
        logger.error(f"Error registering with server: {e}")
        raise

    # Raised outside the try block so the rejection is logged exactly once.
    logger.error(f"Failed to register with server: HTTP {response.status_code}: {response.text}")
    raise RuntimeError(f"Registration failed: {response.status_code}")
|
||||
|
||||
|
||||
def start_bot_provider(
    host: str = "0.0.0.0",
    port: int = 8788,
    server_url: str | None = None,
    insecure: bool = False,
    reload: bool = False
):
    """Serve the bot provider API and optionally register it with the main server.

    The uvicorn server runs in a daemon thread. The calling thread then
    performs the optional registration and sleeps in a loop until
    ``KeyboardInterrupt``.

    Args:
        host: Interface for the API server to bind to.
        port: TCP port for the API server.
        server_url: Main server base URL; registration is skipped when falsy.
        insecure: Passed through to ``register_with_server``.
        reload: Ask uvicorn for auto-reload, watching ``/voicebot`` and ``/shared``.
    """
    import time
    import socket

    uvicorn_options: Dict[str, Any] = {"host": host, "port": port, "log_level": "info"}
    if reload:
        # NOTE(review): uvicorn documents that reload needs the app passed as
        # an import string, and here it also runs outside the main thread -
        # confirm this branch really reloads.
        uvicorn_options["reload"] = True
        uvicorn_options["reload_dirs"] = ["/voicebot", "/shared"]

    server_thread = threading.Thread(
        target=lambda: uvicorn.run(app, **uvicorn_options),
        daemon=True,
    )
    logger.info(f"Starting bot provider API server on {host}:{port}...")
    server_thread.start()

    if server_url:
        # Give the server a moment to start
        time.sleep(2)

        # A wildcard bind address is not something the main server can call
        # back, so advertise the machine's hostname instead.
        if host == "0.0.0.0":
            try:
                advertised_host = socket.gethostname()
            except Exception:
                advertised_host = "localhost"
        else:
            advertised_host = host
        voicebot_url = f"http://{advertised_host}:{port}"

        try:
            asyncio.run(register_with_server(server_url, voicebot_url, insecure))
        except Exception as e:
            logger.error(f"Failed to register with server: {e}")

    # Keep the main thread alive; the server thread is a daemon and would
    # die with it.
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        logger.info("Shutting down bot provider...")
|
@ -4,11 +4,11 @@ Lightweight agent descriptor; heavy model loading must be done by a controller
|
||||
when the agent is actually used.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from typing import Any
|
||||
from typing import Dict, Any
|
||||
import librosa
|
||||
from logger import logger
|
||||
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
||||
from aiortc import MediaStreamTrack
|
||||
|
||||
|
||||
AGENT_NAME = "whisper"
|
||||
@ -19,91 +19,95 @@ def agent_info() -> Dict[str, str]:
|
||||
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION}
|
||||
|
||||
|
||||
def create_agent_tracks(session_name: str) -> dict:
|
||||
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
||||
"""Whisper is not a media source - return no local tracks."""
|
||||
return {}
|
||||
|
||||
model_ids = {
|
||||
"Distil-Whisper": [
|
||||
"distil-whisper/distil-large-v2",
|
||||
"distil-whisper/distil-medium.en",
|
||||
"distil-whisper/distil-small.en"
|
||||
],
|
||||
"Whisper": [
|
||||
"openai/whisper-large-v3",
|
||||
"openai/whisper-large-v2",
|
||||
"openai/whisper-large",
|
||||
"openai/whisper-medium",
|
||||
"openai/whisper-small",
|
||||
"openai/whisper-base",
|
||||
"openai/whisper-tiny",
|
||||
"openai/whisper-medium.en",
|
||||
"openai/whisper-small.en",
|
||||
"openai/whisper-base.en",
|
||||
"openai/whisper-tiny.en",
|
||||
]
|
||||
}
|
||||
def do_work():
|
||||
model_ids = {
|
||||
"Distil-Whisper": [
|
||||
"distil-whisper/distil-large-v2",
|
||||
"distil-whisper/distil-medium.en",
|
||||
"distil-whisper/distil-small.en"
|
||||
],
|
||||
"Whisper": [
|
||||
"openai/whisper-large-v3",
|
||||
"openai/whisper-large-v2",
|
||||
"openai/whisper-large",
|
||||
"openai/whisper-medium",
|
||||
"openai/whisper-small",
|
||||
"openai/whisper-base",
|
||||
"openai/whisper-tiny",
|
||||
"openai/whisper-medium.en",
|
||||
"openai/whisper-small.en",
|
||||
"openai/whisper-base.en",
|
||||
"openai/whisper-tiny.en",
|
||||
]
|
||||
}
|
||||
|
||||
model_type = model_ids["Distil-Whisper"]
|
||||
model_type = model_ids["Distil-Whisper"]
|
||||
|
||||
logger.info(model_type)
|
||||
model_id = model_type[0]
|
||||
logger.info(model_type)
|
||||
model_id = model_type[0]
|
||||
|
||||
|
||||
processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
|
||||
processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
|
||||
|
||||
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
|
||||
pt_model.eval() # type: ignore
|
||||
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id) # type: ignore
|
||||
pt_model.eval() # type: ignore
|
||||
|
||||
|
||||
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
|
||||
"""Extract input features from audio array and sampling rate."""
|
||||
input_features = processor(
|
||||
audio_array,
|
||||
sampling_rate=sampling_rate,
|
||||
return_tensors="pt",
|
||||
).input_features
|
||||
return input_features
|
||||
def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
|
||||
"""Extract input features from audio array and sampling rate."""
|
||||
processor_output = processor( # type: ignore
|
||||
audio_array,
|
||||
sampling_rate=sampling_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_features: Any = processor_output.input_features # type: ignore
|
||||
return input_features # type: ignore
|
||||
|
||||
|
||||
def load_audio_file(file_path: str) -> tuple[Any, int]:
|
||||
"""Load audio file from disk and return audio array and sampling rate."""
|
||||
# Whisper models expect 16kHz sample rate
|
||||
target_sample_rate = 16000
|
||||
|
||||
try:
|
||||
# Load audio file using librosa and resample to target rate
|
||||
audio_array, original_sampling_rate = librosa.load(file_path, sr=None) # type: ignore
|
||||
logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz") # type: ignore
|
||||
|
||||
# Resample if necessary
|
||||
if original_sampling_rate != target_sample_rate:
|
||||
audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate) # type: ignore
|
||||
logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz")
|
||||
|
||||
return audio_array, target_sample_rate # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading audio file {file_path}: {e}")
|
||||
raise
|
||||
def load_audio_file(file_path: str) -> tuple[Any, int]:
|
||||
"""Load audio file from disk and return audio array and sampling rate."""
|
||||
# Whisper models expect 16kHz sample rate
|
||||
target_sample_rate = 16000
|
||||
|
||||
try:
|
||||
# Load audio file using librosa and resample to target rate
|
||||
audio_array, original_sampling_rate = librosa.load(file_path, sr=None) # type: ignore
|
||||
logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz") # type: ignore
|
||||
|
||||
# Resample if necessary
|
||||
if original_sampling_rate != target_sample_rate:
|
||||
audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate) # type: ignore
|
||||
logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz")
|
||||
|
||||
return audio_array, target_sample_rate # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading audio file {file_path}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# Example usage - replace with your audio file path
|
||||
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"
|
||||
# Example usage - replace with your audio file path
|
||||
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"
|
||||
|
||||
# Load audio from file instead of dataset
|
||||
try:
|
||||
audio_array, sampling_rate = load_audio_file(audio_file_path)
|
||||
input_features = extract_input_features(audio_array, sampling_rate)
|
||||
|
||||
predicted_ids = pt_model.generate(input_features) # type: ignore
|
||||
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) # type: ignore
|
||||
|
||||
print(f"Audio file: {audio_file_path}")
|
||||
print(f"Transcription: {transcription[0]}")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Audio file not found: {audio_file_path}")
|
||||
print("Please update the audio_file_path variable with a valid path to your wav file")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {e}")
|
||||
print(f"Error: {e}")
|
||||
# Load audio from file instead of dataset
|
||||
try:
|
||||
audio_array, sampling_rate = load_audio_file(audio_file_path)
|
||||
input_features = extract_input_features(audio_array, sampling_rate)
|
||||
|
||||
predicted_ids = pt_model.generate(input_features) # type: ignore
|
||||
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) # type: ignore
|
||||
|
||||
print(f"Audio file: {audio_file_path}")
|
||||
print(f"Transcription: {transcription[0]}")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Audio file not found: {audio_file_path}")
|
||||
print("Please update the audio_file_path variable with a valid path to your wav file")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing audio: {e}")
|
||||
print(f"Error: {e}")
|
||||
|
||||
|
127
voicebot/client_app.py
Normal file
127
voicebot/client_app.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""
|
||||
Client FastAPI application for voicebot.
|
||||
|
||||
This module provides the FastAPI application for client mode operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import fcntl
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from fastapi import FastAPI
|
||||
from logger import logger
|
||||
|
||||
# Import shared models
|
||||
from shared.models import ClientStatusResponse
|
||||
|
||||
from voicebot.models import VoicebotArgs
|
||||
|
||||
|
||||
# Global client arguments storage
|
||||
_client_args: Optional[VoicebotArgs] = None
|
||||
|
||||
|
||||
def create_client_app(args: VoicebotArgs) -> FastAPI:
    """Create a FastAPI app for client mode that uvicorn can import.

    ``args`` is stored in the module-level ``_client_args``. The app's
    lifespan takes an exclusive, non-blocking lock on
    ``/tmp/voicebot_client.lock`` and, if it gets the lock, runs
    ``main_with_args`` as a background task. On shutdown the task is
    cancelled and the lock file is released and removed.

    Args:
        args: Client configuration used to start the voicebot client.

    Returns:
        The FastAPI application exposing ``/health`` and ``/status``.
    """
    global _client_args
    _client_args = args

    # Shared between the lifespan handler and the status endpoint
    client_task = None
    lock_file = None

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        nonlocal client_task, lock_file
        # Use a file lock to prevent multiple instances from starting
        lock_file_path = "/tmp/voicebot_client.lock"

        # --- Startup ---
        # Only the lock acquisition sits inside the try block: the handler
        # below means "someone else holds the lock" and must not swallow
        # OSErrors raised while starting the client or while the app runs.
        try:
            lock_file = open(lock_file_path, 'w')
            # Try to acquire an exclusive lock (non-blocking)
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except (IOError, OSError):
            # Another process already has the lock
            logger.info("Another instance is already running - skipping client startup")
            if lock_file:
                lock_file.close()
                lock_file = None
        else:
            if _client_args is None:
                logger.error("Client args not initialized")
                lock_file.close()
                lock_file = None
            else:
                logger.info("Starting voicebot client...")
                # Import here to avoid circular imports
                from voicebot.client_main import main_with_args
                client_task = asyncio.create_task(main_with_args(_client_args))

        yield

        # --- Shutdown ---
        if client_task and not client_task.done():
            logger.info("Shutting down voicebot client...")
            client_task.cancel()
            try:
                await client_task
            except asyncio.CancelledError:
                pass

        if lock_file:
            # Best-effort cleanup; failures here must not block shutdown.
            try:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
                lock_file.close()
                os.unlink(lock_file_path)
            except Exception:
                pass

    # Create the client FastAPI app
    app = FastAPI(title="voicebot-client", lifespan=lifespan)

    @app.get("/health")
    async def health_check():  # type: ignore
        """Simple health check endpoint"""
        return {"status": "running", "mode": "client"}

    @app.get("/status", response_model=ClientStatusResponse)
    async def client_status() -> ClientStatusResponse:  # type: ignore
        """Get client status"""
        return ClientStatusResponse(
            client_running=client_task is not None and not client_task.done(),
            session_name=_client_args.session_name if _client_args else 'unknown',
            lobby=_client_args.lobby if _client_args else 'unknown',
            server_url=_client_args.server_url if _client_args else 'unknown'
        )

    return app
|
||||
|
||||
|
||||
def get_app() -> FastAPI:
    """Pick the FastAPI app according to the ``VOICEBOT_MODE`` environment variable.

    ``client`` builds a client app from environment-derived arguments. Any
    other value, including an unset variable, returns the bot orchestration
    (provider) app.
    """
    if os.getenv('VOICEBOT_MODE', 'provider') == 'client':
        # Client mode is configured from the environment
        return create_client_app(VoicebotArgs.from_environment())

    # Provider mode: imported only on this path
    from voicebot.bot_orchestrator import app
    return app
|
||||
|
||||
|
||||
# Create app instance for uvicorn import
|
||||
uvicorn_app = get_app()
|
144
voicebot/client_main.py
Normal file
144
voicebot/client_main.py
Normal file
@ -0,0 +1,144 @@
|
||||
"""
|
||||
Main client logic for voicebot.
|
||||
|
||||
This module contains the main client functionality and entry points.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from logger import logger
|
||||
from voicebot.bots.synthetic_media import AnimatedVideoTrack
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from voicebot.models import VoicebotArgs, Peer
|
||||
from voicebot.session_manager import create_or_get_session, create_or_get_lobby
|
||||
from voicebot.webrtc_signaling import WebRTCSignalingClient
|
||||
from voicebot.utils import ws_url
|
||||
from aiortc import MediaStreamTrack
|
||||
|
||||
|
||||
async def main_with_args(args: VoicebotArgs):
    """Run the voicebot client described by ``args``.

    Resolves the session and lobby ids, builds a ``WebRTCSignalingClient``,
    installs peer/track event handlers and then connects. A failed connection
    attempt is retried up to five times with a growing delay.

    Args:
        args: Client configuration (server URL, lobby, session name, ...).
    """
    # Resolve session id (create if needed)
    try:
        session_id = create_or_get_session(
            args.server_url, args.session_id, insecure=args.insecure
        )
        print(f"Using session id: {session_id}")
    except Exception as e:
        print(f"Failed to get/create session: {e}")
        return

    # Create or get lobby id
    try:
        lobby_id = create_or_get_lobby(
            args.server_url,
            session_id,
            args.lobby,
            args.private,
            insecure=args.insecure,
        )
        print(f"Using lobby id: {lobby_id} (name={args.lobby})")
    except Exception as e:
        print(f"Failed to create/get lobby: {e}")
        return

    # The client is given only the ws:// or wss:// base and builds the final
    # websocket path (/ws/lobby/{lobby}/{session}) itself.
    client = WebRTCSignalingClient(
        ws_url(args.server_url),
        lobby_id,
        session_id,
        args.session_name,
        insecure=args.insecure,
        registration_check_interval=args.registration_check_interval,
    )

    def _local_animated_track():
        """Return the local "video" track if it is an AnimatedVideoTrack, else None."""
        if "video" not in client.local_tracks:
            return None
        candidate = client.local_tracks["video"]
        return candidate if isinstance(candidate, AnimatedVideoTrack) else None

    async def on_peer_added(peer: Peer):
        print(f"Peer added: {peer.peer_name}")

    async def on_peer_removed(peer: Peer):
        print(f"Peer removed: {peer.peer_name}")

        # Placeholder: there is no peer-to-track mapping yet, so the peer's
        # video tracks cannot be detached from the synthetic track here.
        if _local_animated_track() is not None:
            logger.info(
                f"Peer {peer.peer_name} removed - may need to clean up video tracks"
            )

    async def on_track_received(peer: Peer, track: MediaStreamTrack):
        print(f"Received {track.kind} track from {peer.peer_name}")

        if track.kind != "video":
            return

        # Feed remote video into the synthetic video track for edge detection
        animated_track = _local_animated_track()
        if animated_track is not None:
            animated_track.add_remote_video_track(track)
            logger.info(
                f"Attached remote video track from {peer.peer_name} to synthetic video track"
            )

    client.on_peer_added = on_peer_added
    client.on_peer_removed = on_peer_removed
    client.on_track_received = on_track_received

    # Retry loop for connection resilience
    max_retries = 5
    retry_delay = 5.0  # seconds
    retry_count = 0

    # NOTE(review): when connect() returns without raising, the loop
    # reconnects immediately and the attempt is not counted - confirm
    # that is intended.
    while retry_count < max_retries:
        try:
            # A password given on the CLI is stored on the client for use
            # when setting the name.
            if args.password:
                client.name_password = args.password

            print(f"Attempting to connect (attempt {retry_count + 1}/{max_retries})")
            await client.connect()

        except KeyboardInterrupt:
            print("Shutting down...")
            break
        except Exception as e:
            logger.error(f"Connection failed (attempt {retry_count + 1}): {e}")
            retry_count += 1

            if retry_count >= max_retries:
                print("Max retries exceeded. Giving up.")
                break

            print(f"Retrying in {retry_delay} seconds...")
            await asyncio.sleep(retry_delay)
            # Delay grows by 1.5x per failure, capped at 60 seconds
            retry_delay = min(retry_delay * 1.5, 60.0)
        finally:
            # Always tear the connection down before the next attempt or exit
            try:
                await client.disconnect()
            except Exception as disconnect_error:
                logger.error(f"Error during disconnect: {disconnect_error}")

    print("Voicebot client stopped.")
|
||||
|
||||
|
||||
def start_client_with_reload(args: VoicebotArgs):
    """Prepare the client app for uvicorn, then run the client directly.

    Called when ``--reload`` is specified. The app returned by
    ``create_client_app`` is not kept; uvicorn itself is expected to be
    launched by the entrypoint script. The client is then run in-process
    as a fallback.

    Args:
        args: Client configuration.
    """
    logger.info("Creating client app for uvicorn...")
    from voicebot.client_app import create_client_app

    create_client_app(args)
    logger.info("Client app created. Uvicorn should be started by entrypoint script.")

    # Fall back to running client directly if not using uvicorn
    client_coroutine = main_with_args(args)
    asyncio.run(client_coroutine)
|
1697
voicebot/main.py
1697
voicebot/main.py
File diff suppressed because it is too large
Load Diff
121
voicebot/models.py
Normal file
121
voicebot/models.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""
|
||||
Data models and configuration for voicebot.
|
||||
|
||||
This module provides Pydantic models for configuration and data structures
|
||||
used throughout the voicebot application.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from dataclasses import dataclass, field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiortc import RTCPeerConnection
|
||||
|
||||
|
||||
class VoicebotMode(str, Enum):
    """How the voicebot process runs.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Connect to a lobby as a participant
    CLIENT = "client"
    # Serve bots to the main server
    PROVIDER = "provider"
|
||||
|
||||
|
||||
class VoicebotArgs(BaseModel):
    """Pydantic model for voicebot CLI arguments and configuration.

    Instances can be built directly, from environment variables
    (:meth:`from_environment`) or from an ``argparse.Namespace``
    (:meth:`from_argparse`).
    """

    # Mode selection
    mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)")

    # Provider mode arguments
    host: str = Field(default="0.0.0.0", description="Host for provider mode")
    port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535)
    reload: bool = Field(default=False, description="Enable auto-reload for development")

    # Client mode arguments
    server_url: str = Field(
        default="http://localhost:8000/ai-voicebot",
        description="AI-Voicebot lobby and signaling server base URL (http:// or https://)"
    )
    lobby: str = Field(default="default", description="Lobby name to create or join")
    session_name: str = Field(default="Python Bot", description="Session (user) display name")
    session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse")
    password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
    private: bool = Field(default=False, description="Create the lobby as private")
    insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
    registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)

    @classmethod
    def from_environment(cls) -> 'VoicebotArgs':
        """Create VoicebotArgs from ``VOICEBOT_*`` environment variables.

        Unset variables fall back to the same defaults as the model fields.
        Boolean variables are true for ``1``, ``true``, ``yes`` or ``on``
        (case-insensitive, surrounding whitespace ignored); any other value
        is false. ``VOICEBOT_MODE`` is trimmed and lower-cased before being
        matched.

        Raises:
            ValueError: If ``VOICEBOT_MODE`` is not a known mode, or the port
                or interval variables are not numeric.
            pydantic.ValidationError: If a value violates a field constraint.
        """
        import os

        def _flag(name: str) -> bool:
            # Accept the common truthy spellings, not only the literal "true",
            # so that e.g. VOICEBOT_INSECURE=1 is not silently ignored.
            return os.getenv(name, 'false').strip().lower() in ('1', 'true', 'yes', 'on')

        mode_str = os.getenv('VOICEBOT_MODE', 'client').strip().lower()
        return cls(
            mode=VoicebotMode(mode_str),
            host=os.getenv('VOICEBOT_HOST', '0.0.0.0'),
            port=int(os.getenv('VOICEBOT_PORT', '8788')),
            reload=_flag('VOICEBOT_RELOAD'),
            server_url=os.getenv('VOICEBOT_SERVER_URL', 'http://localhost:8000/ai-voicebot'),
            lobby=os.getenv('VOICEBOT_LOBBY', 'default'),
            session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'),
            session_id=os.getenv('VOICEBOT_SESSION_ID', None),
            password=os.getenv('VOICEBOT_PASSWORD', None),
            private=_flag('VOICEBOT_PRIVATE'),
            insecure=_flag('VOICEBOT_INSECURE'),
            registration_check_interval=float(os.getenv('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0'))
        )

    @classmethod
    def from_argparse(cls, args: argparse.Namespace) -> 'VoicebotArgs':
        """Create VoicebotArgs from an argparse Namespace.

        Attributes missing from ``args`` fall back to the same defaults as
        the model fields.

        Raises:
            ValueError: If ``args.mode`` is not a known mode.
            pydantic.ValidationError: If a value violates a field constraint.
        """
        mode_str = getattr(args, 'mode', 'client')
        return cls(
            mode=VoicebotMode(mode_str),
            host=getattr(args, 'host', '0.0.0.0'),
            port=getattr(args, 'port', 8788),
            reload=getattr(args, 'reload', False),
            server_url=getattr(args, 'server_url', 'http://localhost:8000/ai-voicebot'),
            lobby=getattr(args, 'lobby', 'default'),
            session_name=getattr(args, 'session_name', 'Python Bot'),
            session_id=getattr(args, 'session_id', None),
            password=getattr(args, 'password', None),
            private=getattr(args, 'private', False),
            insecure=getattr(args, 'insecure', False),
            registration_check_interval=float(getattr(args, 'registration_check_interval', 30.0))
        )
|
||||
|
||||
|
||||
class JoinRequest(BaseModel):
    """Request model for joining a lobby.

    All fields except ``insecure`` are required.
    """

    # Lobby to join.
    lobby_id: str
    # Session to join as.
    session_id: str
    # NOTE(review): presumably the display name to use in the lobby — confirm against the caller.
    nick: str
    # Base URL of the server to connect to.
    server_url: str
    # NOTE(review): presumably relaxes TLS verification like VoicebotArgs.insecure — confirm.
    insecure: bool = False
|
||||
|
||||
|
||||
def _default_attributes() -> Dict[str, object]:
|
||||
"""Default factory for peer attributes."""
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
class Peer:
    """Represents a WebRTC peer in the session.

    Used both for remote peers and for the local client's own entry
    (``local=True``).
    """

    # Session id of the peer.
    session_id: str
    # Display name of the peer.
    peer_name: str
    # Generic attributes bag. Values can be tracks or simple metadata.
    attributes: Dict[str, object] = field(default_factory=_default_attributes)
    # NOTE(review): presumably mute / camera state flags — confirm where they are read.
    muted: bool = False
    video_on: bool = True
    # True for the entry that represents this client itself.
    local: bool = False
    # NOTE(review): presumably marks a peer that is gone — confirm where it is set.
    dead: bool = False
    # Peer connection associated with this peer, if one has been created.
    connection: Optional['RTCPeerConnection'] = None
|
||||
|
||||
|
||||
# Generic message payload type: a string-keyed dict with arbitrary values.
MessageData = dict[str, object]
|
118
voicebot/session_manager.py
Normal file
118
voicebot/session_manager.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""
|
||||
Session and lobby management for voicebot.
|
||||
|
||||
This module handles session creation and lobby management functionality.
|
||||
"""
|
||||
|
||||
import json
|
||||
import ssl
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import sys
|
||||
import os
|
||||
from pydantic import ValidationError
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Import shared models
|
||||
from shared.models import SessionModel, LobbyCreateResponse
|
||||
|
||||
from voicebot.utils import http_base_url
|
||||
|
||||
|
||||
def create_or_get_session(
|
||||
server_url: str, session_id: str | None = None, insecure: bool = False
|
||||
) -> str:
|
||||
"""Call GET /api/session to obtain a session_id (unless one was provided).
|
||||
|
||||
Uses urllib so no extra runtime deps are required.
|
||||
"""
|
||||
if session_id:
|
||||
return session_id
|
||||
|
||||
http_base = http_base_url(server_url)
|
||||
url = f"{http_base}/api/session"
|
||||
req = urllib.request.Request(url, method="GET")
|
||||
# Prepare SSL context if requested (accept self-signed certs)
|
||||
ssl_ctx = None
|
||||
if insecure:
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
|
||||
body = resp.read()
|
||||
data = json.loads(body)
|
||||
|
||||
# Validate response shape using Pydantic
|
||||
try:
|
||||
session = SessionModel.model_validate(data)
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Invalid session response from {url}: {e}")
|
||||
|
||||
sid = session.id
|
||||
if not sid:
|
||||
raise RuntimeError(f"No session id returned from {url}: {data}")
|
||||
return sid
|
||||
except urllib.error.HTTPError as e:
|
||||
raise RuntimeError(f"HTTP error getting session: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error getting session: {e}")
|
||||
|
||||
|
||||
def create_or_get_lobby(
    server_url: str,
    session_id: str,
    lobby_name: str,
    private: bool = False,
    insecure: bool = False,
) -> str:
    """Call POST /api/lobby/{session_id} to create or lookup a lobby by name.

    Args:
        server_url: Server base URL; ws:// and wss:// are mapped to http(s).
        session_id: Session id, URL-quoted into the request path.
        lobby_name: Name of the lobby to create or look up.
        private: Sent to the server as the lobby's ``private`` flag.
        insecure: If true, TLS certificate and hostname verification are
            disabled (for self-signed certificates).

    Returns:
        The lobby id.

    Raises:
        RuntimeError: On an HTTP or transport error, an unparsable body, a
            response that fails validation, or a response without an id. The
            underlying exception, when there is one, is chained as
            ``__cause__``.
    """
    http_base = http_base_url(server_url)
    url = f"{http_base}/api/lobby/{urllib.parse.quote(session_id)}"
    payload = json.dumps(
        {
            "type": "lobby_create",
            "data": {"name": lobby_name, "private": private},
        }
    ).encode("utf-8")
    req = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
    )

    # Prepare SSL context if requested (accept self-signed certs)
    ssl_ctx = None
    if insecure:
        ssl_ctx = ssl.create_default_context()
        ssl_ctx.check_hostname = False
        ssl_ctx.verify_mode = ssl.CERT_NONE

    # Only the network call and JSON decoding sit inside this try. Validation
    # errors are raised afterwards so they are not caught by the blanket
    # handler and wrapped a second time.
    try:
        with urllib.request.urlopen(req, timeout=10, context=ssl_ctx) as resp:
            body = resp.read()
            data = json.loads(body)
    except urllib.error.HTTPError as e:
        # Try to include response body for debugging; fall back to the
        # exception text when the body is unreadable or empty.
        try:
            msg = e.read().decode("utf-8", errors="ignore") or str(e)
        except Exception:
            msg = str(e)
        raise RuntimeError(f"HTTP error creating lobby: {msg}") from e
    except Exception as e:
        raise RuntimeError(f"Error creating lobby: {e}") from e

    # Expect shape: { "type": "lobby_created", "data": {"id":..., ...}}
    try:
        lobby_resp = LobbyCreateResponse.model_validate(data)
    except ValidationError as e:
        raise RuntimeError(f"Invalid lobby response from {url}: {e}") from e

    lobby_id = lobby_resp.data.id
    if not lobby_id:
        raise RuntimeError(f"No lobby id returned from {url}: {data}")
    return lobby_id
|
57
voicebot/utils.py
Normal file
57
voicebot/utils.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""
|
||||
Utility functions for voicebot.
|
||||
|
||||
This module provides common utility functions used throughout the application.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
|
||||
|
||||
def http_base_url(server_url: str) -> str:
    """Return *server_url* as an HTTP(S) base URL without trailing slashes.

    ``ws://`` becomes ``http://`` and ``wss://`` becomes ``https://``; any
    other URL only has its trailing slashes removed.
    """
    scheme_map = (("ws://", "http://"), ("wss://", "https://"))
    for ws_scheme, http_scheme in scheme_map:
        if server_url.startswith(ws_scheme):
            remainder = server_url[len(ws_scheme):]
            return http_scheme + remainder.rstrip("/")
    return server_url.rstrip("/")
|
||||
|
||||
|
||||
def ws_url(server_url: str) -> str:
    """Return *server_url* as a WebSocket URL without trailing slashes.

    ``http://`` becomes ``ws://`` and ``https://`` becomes ``wss://``; any
    other URL only has its trailing slashes removed.
    """
    scheme_map = (("http://", "ws://"), ("https://", "wss://"))
    for http_scheme, ws_scheme in scheme_map:
        if server_url.startswith(http_scheme):
            remainder = server_url[len(http_scheme):]
            return ws_scheme + remainder.rstrip("/")
    return server_url.rstrip("/")
|
||||
|
||||
|
||||
def create_ssl_context(insecure: bool = False) -> ssl.SSLContext | None:
|
||||
"""Create SSL context for connections."""
|
||||
if not insecure:
|
||||
return None
|
||||
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
ssl_ctx.check_hostname = False
|
||||
ssl_ctx.verify_mode = ssl.CERT_NONE
|
||||
return ssl_ctx
|
||||
|
||||
|
||||
def log_network_info():
    """Log network information for debugging.

    Best-effort: logs the hostname with its resolved IP, then the output of
    ``ip addr show``. The two steps are independent, so a failed hostname
    lookup does not suppress the interface listing. Failures are logged as
    warnings and never raised.
    """
    from logger import logger

    import socket
    import subprocess

    try:
        hostname = socket.gethostname()
        local_ip = socket.gethostbyname(hostname)
        logger.info(f"Container hostname: {hostname}, local IP: {local_ip}")
    except Exception as e:
        logger.warning(f"Could not get network info: {e}")

    try:
        # Get all network interfaces. The timeout bounds how long a stuck
        # child process can block the caller.
        result = subprocess.run(
            ["ip", "addr", "show"], capture_output=True, text=True, timeout=5
        )
        if result.returncode == 0:
            logger.info(f"Network interfaces:\n{result.stdout}")
        else:
            logger.warning(
                f"Could not get network info: 'ip addr show' exited with "
                f"{result.returncode}: {result.stderr.strip()}"
            )
    except Exception as e:
        # Covers a missing `ip` binary (FileNotFoundError) and timeouts.
        logger.warning(f"Could not get network info: {e}")
|
894
voicebot/webrtc_signaling.py
Normal file
894
voicebot/webrtc_signaling.py
Normal file
@ -0,0 +1,894 @@
|
||||
"""
|
||||
WebRTC signaling client for voicebot.
|
||||
|
||||
This module provides WebRTC signaling server communication and peer connection management.
|
||||
Synthetic audio/video track creation is handled by the bots.synthetic_media module.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import time
|
||||
import re
|
||||
from typing import (
|
||||
Dict,
|
||||
Optional,
|
||||
Callable,
|
||||
Awaitable,
|
||||
Protocol,
|
||||
AsyncIterator,
|
||||
cast,
|
||||
)
|
||||
|
||||
# Add the parent directory to sys.path to allow absolute imports
|
||||
|
||||
from pydantic import ValidationError
|
||||
from aiortc import (
|
||||
RTCPeerConnection,
|
||||
RTCSessionDescription,
|
||||
RTCIceCandidate,
|
||||
MediaStreamTrack,
|
||||
)
|
||||
from aiortc.rtcconfiguration import RTCConfiguration, RTCIceServer
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
|
||||
# Import shared models
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from shared.models import (
|
||||
WebSocketMessageModel,
|
||||
JoinStatusModel,
|
||||
UserJoinedModel,
|
||||
LobbyStateModel,
|
||||
UpdateNameModel,
|
||||
AddPeerModel,
|
||||
RemovePeerModel,
|
||||
SessionDescriptionModel,
|
||||
IceCandidateModel,
|
||||
ICECandidateDictModel,
|
||||
SessionDescriptionTypedModel,
|
||||
)
|
||||
|
||||
from logger import logger
|
||||
from voicebot.bots.synthetic_media import create_synthetic_tracks
|
||||
from voicebot.models import Peer, MessageData
|
||||
from voicebot.utils import create_ssl_context, log_network_info
|
||||
|
||||
|
||||
class WebSocketProtocol(Protocol):
    """Structural type for the websocket operations this module relies on.

    The connection object is stored untyped and cast to this protocol before
    use; only sending, closing and async iteration are required.
    """

    def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
    def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
    def __aiter__(self) -> AsyncIterator[str]: ...
|
||||
|
||||
|
||||
class WebRTCSignalingClient:
|
||||
"""
|
||||
WebRTC signaling client that communicates with the FastAPI signaling server.
|
||||
Handles peer-to-peer connection establishment and media streaming.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
lobby_id: str,
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
insecure: bool = False,
|
||||
create_tracks: Optional[Callable[[str], Dict[str, MediaStreamTrack]]] = None,
|
||||
registration_check_interval: float = 30.0,
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.lobby_id = lobby_id
|
||||
self.session_id = session_id
|
||||
self.session_name = session_name
|
||||
self.insecure = insecure
|
||||
|
||||
# Optional factory to create local media tracks for this client (bot provided)
|
||||
self.create_tracks = create_tracks
|
||||
|
||||
# WebSocket client protocol instance (typed as object to avoid Any)
|
||||
self.websocket: Optional[object] = None
|
||||
|
||||
# Optional password to register or takeover a name
|
||||
self.name_password: Optional[str] = None
|
||||
|
||||
self.peers: dict[str, Peer] = {}
|
||||
self.peer_connections: dict[str, RTCPeerConnection] = {}
|
||||
self.local_tracks: dict[str, MediaStreamTrack] = {}
|
||||
|
||||
# State management
|
||||
self.is_negotiating: dict[str, bool] = {}
|
||||
self.making_offer: dict[str, bool] = {}
|
||||
self.initiated_offer: set[str] = set()
|
||||
self.pending_ice_candidates: dict[str, list[ICECandidateDictModel]] = {}
|
||||
|
||||
# Registration status tracking
|
||||
self.is_registered: bool = False
|
||||
self.last_registration_check: float = 0
|
||||
self.registration_check_interval: float = registration_check_interval
|
||||
self.registration_check_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
# Event callbacks
|
||||
self.on_peer_added: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||
self.on_peer_removed: Optional[Callable[[Peer], Awaitable[None]]] = None
|
||||
self.on_track_received: Optional[
|
||||
Callable[[Peer, MediaStreamTrack], Awaitable[None]]
|
||||
] = None
|
||||
|
||||
async def connect(self):
    """Connect to the signaling server, register, and run the message loop.

    Opens the websocket, creates local media, sends ``set_name`` and
    ``join``, starts the periodic registration check, and then awaits the
    message handler. The coroutine therefore does not return until the
    message loop ends.

    Raises:
        Exception: Any error raised while connecting or registering is
            logged and re-raised.
    """
    ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"
    logger.info(f"Connecting to signaling server: {ws_url}")

    # Log network information for debugging
    log_network_info()

    try:
        # If insecure (self-signed certs), create an SSL context for the websocket
        ws_ssl = create_ssl_context(self.insecure)
        # Only pass `ssl` when there is a custom context AND the URL is wss://.
        # An explicit ssl=None would override the library's TLS default for
        # wss://, and an SSL context is invalid for a plain ws:// URL.
        use_custom_ssl = ws_ssl is not None and ws_url.startswith("wss://")

        logger.info(
            f"Attempting websocket connection to {ws_url} with ssl={use_custom_ssl}"
        )
        if use_custom_ssl:
            self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
        else:
            self.websocket = await websockets.connect(ws_url)
        logger.info("Connected to signaling server")

        # Set up local media
        await self._setup_local_media()

        # Set name and join lobby
        name_payload: MessageData = {"name": self.session_name}
        if self.name_password:
            name_payload["password"] = self.name_password
        # Never write the password to the log.
        loggable_payload = (
            {**name_payload, "password": "***"}
            if "password" in name_payload
            else name_payload
        )
        logger.info(f"Sending set_name: {loggable_payload}")
        await self._send_message("set_name", name_payload)
        logger.info("Sending join message")
        await self._send_message("join", {})

        # Mark as registered after successful join
        self.is_registered = True
        self.last_registration_check = time.time()

        # Start periodic registration check, replacing any task left over
        # from an earlier connect() so it cannot leak.
        previous_task = self.registration_check_task
        if previous_task and not previous_task.done():
            previous_task.cancel()
        self.registration_check_task = asyncio.create_task(self._periodic_registration_check())
    except Exception as e:
        logger.error(f"Failed to connect to signaling server: {e}", exc_info=True)
        raise

    # Start message handling. This runs outside the try above so that a
    # failure here is not additionally reported as a connection failure.
    logger.info("Starting message handler loop")
    try:
        await self._handle_messages()
    except Exception as e:
        logger.error(f"Message handling stopped: {e}")
        self.is_registered = False
        raise
|
||||
|
||||
async def _periodic_registration_check(self):
    """Background loop: verify registration every interval, re-register on failure.

    Runs until cancelled. An error in one iteration is logged and the loop
    carries on with the next one.
    """
    while True:
        try:
            await asyncio.sleep(self.registration_check_interval)

            now = time.time()
            elapsed = now - self.last_registration_check
            if elapsed >= self.registration_check_interval:
                # Check if we're still connected and registered
                still_registered = await self._check_registration_status()
                if not still_registered:
                    logger.warning("Registration check failed, attempting to re-register")
                    await self._re_register()

                self.last_registration_check = now
        except asyncio.CancelledError:
            logger.info("Registration check task cancelled")
            return
        except Exception as e:
            # Continue checking even if one iteration fails
            logger.error(f"Error in periodic registration check: {e}", exc_info=True)
|
||||
|
||||
async def _check_registration_status(self) -> bool:
    """Check if the voicebot is still registered with the server.

    Returns:
        ``False`` when there is no websocket, when the registration flag has
        been cleared (the message handler clears it once the connection
        closes or fails), or when sending the status check raises;
        ``True`` otherwise.
    """
    try:
        # First check if websocket is still connected
        if not self.websocket:
            logger.warning("WebSocket connection lost")
            return False

        # _send_message logs and swallows send errors, so a failed send
        # cannot be observed below. The flag is the only signal that the
        # connection has dropped.
        if not self.is_registered:
            logger.warning("Registration flag cleared; connection is no longer active")
            return False

        # Try to send a ping/status check message to verify connection
        # We'll use a simple status message to check connectivity
        try:
            await self._send_message("status_check", {"timestamp": time.time()})
            logger.debug("Registration status check sent")
            return True
        except Exception as e:
            logger.warning(f"Failed to send status check: {e}")
            return False

    except Exception as e:
        logger.error(f"Error checking registration status: {e}")
        return False
|
||||
|
||||
async def _re_register(self):
    """Re-send ``set_name`` and ``join``, reconnecting first if there is no websocket.

    Failures are logged and not raised; the periodic check retries on its
    next interval.
    """
    try:
        logger.info("Attempting to re-register with server")

        # Mark as not registered during re-registration attempt
        self.is_registered = False

        # Try to reconnect the websocket if it's lost
        websocket_missing = not self.websocket
        if websocket_missing:
            logger.info("WebSocket lost, attempting to reconnect")
            await self._reconnect_websocket()

        # Re-send name and join messages
        registration: MessageData = {"name": self.session_name}
        password = self.name_password
        if password:
            registration["password"] = password

        logger.info("Re-sending set_name message")
        await self._send_message("set_name", registration)

        logger.info("Re-sending join message")
        await self._send_message("join", {})

        # Mark as registered after successful re-join
        self.last_registration_check = time.time()
        self.is_registered = True

        logger.info("Successfully re-registered with server")
    except Exception as e:
        # Will try again on next check interval
        logger.error(f"Failed to re-register with server: {e}", exc_info=True)
|
||||
|
||||
async def _reconnect_websocket(self):
    """Close any existing websocket and open a new one to the lobby endpoint.

    Only the transport is re-established; name and join messages are not
    sent here.

    Raises:
        Exception: Any error while connecting is logged and re-raised.
    """
    try:
        # Close existing connection if any
        if self.websocket:
            try:
                ws = cast(WebSocketProtocol, self.websocket)
                await ws.close()
            except Exception as close_error:
                # Best-effort: the old connection is being discarded anyway.
                logger.debug(f"Ignoring error while closing old websocket: {close_error}")
            self.websocket = None

        # Reconnect
        ws_url = f"{self.server_url}/ws/lobby/{self.lobby_id}/{self.session_id}"

        # If insecure (self-signed certs), create an SSL context for the websocket
        ws_ssl = create_ssl_context(self.insecure)

        logger.info(f"Reconnecting to signaling server: {ws_url}")
        # Only pass `ssl` when there is a custom context AND the URL is wss://.
        # An explicit ssl=None would override the library's TLS default for
        # wss://, and an SSL context is invalid for a plain ws:// URL.
        if ws_ssl is not None and ws_url.startswith("wss://"):
            self.websocket = await websockets.connect(ws_url, ssl=ws_ssl)
        else:
            self.websocket = await websockets.connect(ws_url)
        logger.info("Successfully reconnected to signaling server")

    except Exception as e:
        logger.error(f"Failed to reconnect websocket: {e}", exc_info=True)
        raise
|
||||
|
||||
async def disconnect(self):
    """Disconnect from signaling server and cleanup.

    Cancels the registration check task, closes the websocket and every
    peer connection, and stops the local tracks. Each step is best-effort:
    a failure is logged and the remaining steps still run.
    """
    # Cancel the registration check task
    task = self.registration_check_task
    self.registration_check_task = None
    if task and not task.done():
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.warning(f"Registration check task ended with error: {e}")

    # Drop the reference first so later sends see "no connection" instead
    # of writing to a closed socket.
    websocket = self.websocket
    self.websocket = None
    if websocket:
        try:
            ws = cast(WebSocketProtocol, websocket)
            await ws.close()
        except Exception as e:
            logger.warning(f"Error closing websocket: {e}")

    # Close all peer connections. Iterate over a snapshot so the dict can be
    # cleared afterwards regardless of individual failures.
    for peer_id, pc in list(self.peer_connections.items()):
        try:
            await pc.close()
        except Exception as e:
            logger.warning(f"Error closing peer connection {peer_id}: {e}")
    self.peer_connections.clear()

    # Stop local tracks
    for track in self.local_tracks.values():
        try:
            track.stop()
        except Exception as e:
            logger.warning(f"Error stopping local track: {e}")

    # Reset registration status
    self.is_registered = False

    logger.info("Disconnected from signaling server")
|
||||
|
||||
async def _setup_local_media(self):
    """Create the local media tracks and register the local peer.

    Uses the bot-provided ``create_tracks`` factory when one is set and the
    default synthetic tracks otherwise. A failure during track creation is
    logged and does not stop the local peer from being registered.
    """
    try:
        if self.create_tracks:
            # A bot provided a create_tracks callable: use it.
            new_tracks = self.create_tracks(self.session_name)
        else:
            # Default fallback to synthetic tracks
            new_tracks = create_synthetic_tracks(self.session_name)
        self.local_tracks.update(new_tracks)
    except Exception:
        logger.exception("Failed to create local tracks using bot factory")

    # Add local peer to peers dict
    self.peers[self.session_id] = Peer(
        session_id=self.session_id,
        peer_name=self.session_name,
        local=True,
        attributes={"tracks": self.local_tracks},
    )

    logger.info("Local media tracks created")
|
||||
|
||||
async def _send_message(
    self, message_type: str, data: Optional[MessageData] = None
):
    """Send message to signaling server.

    The message is JSON-encoded as ``{"type": ..., "data": ...}``; ``data``
    is omitted when it is ``None``. Errors are logged and not raised, so
    callers cannot tell whether the send succeeded.

    Args:
        message_type: Value of the message's ``type`` field.
        data: Optional payload for the message's ``data`` field.
    """
    if not self.websocket:
        logger.error("No websocket connection")
        return
    # Build message with explicit type to avoid type narrowing
    message: dict[str, object] = {"type": message_type}
    if data is not None:
        message["data"] = data

    # Payloads such as set_name can carry a password; mask it in the log.
    loggable = data
    if isinstance(data, dict) and "password" in data:
        loggable = {**data, "password": "***"}

    ws = cast(WebSocketProtocol, self.websocket)
    try:
        logger.debug(f"_send_message: Sending {message_type} with data: {loggable}")
        await ws.send(json.dumps(message))
        logger.debug(f"_send_message: Sent message: {message_type}")
    except Exception as e:
        logger.error(
            f"_send_message: Failed to send {message_type}: {e}", exc_info=True
        )
|
||||
|
||||
async def _handle_messages(self):
    """Receive messages from the websocket and process each one in turn.

    A message that cannot be parsed as JSON is logged and skipped. When the
    connection closes or the loop fails, the registration flag is cleared
    and the coroutine returns without raising.
    """
    try:
        connection = cast(WebSocketProtocol, self.websocket)
        async for raw in connection:
            logger.debug(f"_handle_messages: Received raw message: {raw}")
            try:
                parsed = cast(MessageData, json.loads(raw))
            except Exception as e:
                logger.error(
                    f"_handle_messages: Failed to parse message: {e}", exc_info=True
                )
            else:
                await self._process_message(parsed)
    except websockets.exceptions.ConnectionClosed as e:
        # The periodic registration check will detect this and attempt reconnection
        logger.warning(f"WebSocket connection closed: {e}")
        self.is_registered = False
    except Exception as e:
        logger.error(f"Error handling messages: {e}", exc_info=True)
        self.is_registered = False
|
||||
|
||||
async def _process_message(self, message: MessageData):
    """Validate one signaling message and dispatch it by type.

    The envelope is validated first, then the payload against the model for
    its message type. Invalid envelopes or payloads are logged and dropped.
    Peer and negotiation messages go to their handlers; informational
    messages are only logged; unknown types are logged as unhandled.
    """
    try:
        # Validate the base message structure first
        envelope = WebSocketMessageModel.model_validate(message)
    except ValidationError as e:
        logger.error(f"Invalid message structure: {e}", exc_info=True)
        return

    msg_type = envelope.type
    data = envelope.data

    logger.debug(
        f"_process_message: Received message type: {msg_type} with data: {data}"
    )

    # message type -> (payload model, label used in the validation error log)
    payload_schemas = {
        "addPeer": (AddPeerModel, "addPeer"),
        "removePeer": (RemovePeerModel, "removePeer"),
        "sessionDescription": (SessionDescriptionModel, "sessionDescription"),
        "iceCandidate": (IceCandidateModel, "iceCandidate"),
        "join_status": (JoinStatusModel, "join_status"),
        "user_joined": (UserJoinedModel, "user_joined"),
        "lobby_state": (LobbyStateModel, "lobby_state"),
        "update_name": (UpdateNameModel, "update"),
    }

    schema = payload_schemas.get(msg_type)
    if schema is None:
        logger.info(f"Unhandled message type: {msg_type} with data: {data}")
        return

    model_cls, label = schema
    try:
        validated = model_cls.model_validate(data)
    except ValidationError as e:
        logger.error(f"Invalid {label} payload: {e}", exc_info=True)
        return

    if msg_type == "addPeer":
        await self._handle_add_peer(validated)
    elif msg_type == "removePeer":
        await self._handle_remove_peer(validated)
    elif msg_type == "sessionDescription":
        await self._handle_session_description(validated)
    elif msg_type == "iceCandidate":
        await self._handle_ice_candidate(validated)
    elif msg_type == "join_status":
        logger.info(f"Join status: {validated.status} - {validated.message}")
    elif msg_type == "user_joined":
        logger.info(
            f"User joined: {validated.name} (session: {validated.session_id})"
        )
    elif msg_type == "lobby_state":
        participants = validated.participants
        logger.info(f"Lobby state updated: {len(participants)} participants")
    else:
        # Only "update_name" remains.
        logger.info(f"Received update message: {validated}")
|
||||
|
||||
# Continue with more methods in the next part...
|
||||
|
||||
async def _handle_add_peer(self, data: AddPeerModel):
    """Handle addPeer message - create new peer connection.

    Skips the peer if a connection in state new/connected/connecting already
    exists; otherwise closes the stale one. Then registers a ``Peer``,
    creates an ``RTCPeerConnection`` with the configured ICE servers, wires
    its event handlers, attaches the local tracks, optionally creates and
    sends an offer, and finally invokes ``on_peer_added`` if set.

    Args:
        data: Validated addPeer payload (peer id, peer name and whether this
            side should create the offer).
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    should_create_offer = data.should_create_offer

    logger.info(
        f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
    )
    logger.debug(
        f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
    )

    # Check if peer already exists
    if peer_id in self.peer_connections:
        pc = self.peer_connections[peer_id]
        logger.debug(
            f"_handle_add_peer: Existing connection state: {pc.connectionState}"
        )
        if pc.connectionState in ["new", "connected", "connecting"]:
            logger.info(f"Peer connection already exists for {peer_name}")
            return
        else:
            # Clean up stale connection
            logger.debug(
                f"_handle_add_peer: Closing stale connection for {peer_name}"
            )
            await pc.close()
            del self.peer_connections[peer_id]

    # Create new peer (replaces any existing entry for this id)
    peer = Peer(session_id=peer_id, peer_name=peer_name, local=False)
    self.peers[peer_id] = peer

    # Create RTCPeerConnection
    # NOTE(review): the TURN username/credential are hard-coded in source —
    # consider loading them from configuration instead.
    config = RTCConfiguration(
        iceServers=[
            RTCIceServer(urls="stun:ketrenos.com:3478"),
            RTCIceServer(
                urls="turns:ketrenos.com:5349",
                username="ketra",
                credential="ketran",
            ),
            # Add Google's public STUN server as fallback
            RTCIceServer(urls="stun:stun.l.google.com:19302"),
        ],
    )
    logger.debug(
        f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
    )
    pc = RTCPeerConnection(configuration=config)

    # Add ICE gathering state change handler (logging only)
    def on_ice_gathering_state_change() -> None:
        logger.info(f"ICE gathering state: {pc.iceGatheringState}")
        if pc.iceGatheringState == "complete":
            logger.info(
                f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
            )

    pc.on("icegatheringstatechange")(on_ice_gathering_state_change)

    # Add connection state change handler (logging only)
    def on_connection_state_change() -> None:
        logger.info(f"Connection state: {pc.connectionState}")

    pc.on("connectionstatechange")(on_connection_state_change)

    self.peer_connections[peer_id] = pc
    peer.connection = pc

    # Set up event handlers
    def on_track(track: MediaStreamTrack) -> None:
        logger.info(f"Received {track.kind} track from {peer_name}")
        logger.info(f"on_track: {track.kind} from {peer_name}, track={track}")
        # Remember the track on the peer under "<kind>_track".
        peer.attributes[f"{track.kind}_track"] = track
        if self.on_track_received:
            # The handler is synchronous, so the async callback is scheduled
            # rather than awaited.
            asyncio.ensure_future(self.on_track_received(peer, track))

    pc.on("track")(on_track)

    def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
        logger.info(f"on_ice_candidate: {candidate}")
        logger.info(
            f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}"
        )
        if not candidate:
            logger.info(
                f"on_ice_candidate: End of candidates signal for {peer_name}"
            )
            return

        # Raw SDP fragment for the candidate
        raw = getattr(candidate, "candidate", None)

        # Try to infer candidate type from the SDP string (host/srflx/relay/prflx)
        def _parse_type(s: Optional[str]) -> str:
            if not s:
                return "eoc"
            m = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", s)
            return m.group(1) if m else "unknown"

        cand_type = _parse_type(raw)
        protocol = getattr(candidate, "protocol", "unknown")
        logger.info(
            f"ICE candidate outgoing for {peer_name}: type={cand_type} protocol={protocol} sdp={raw}"
        )

        candidate_model = ICECandidateDictModel(
            candidate=raw,
            sdpMid=getattr(candidate, "sdpMid", None),
            sdpMLineIndex=getattr(candidate, "sdpMLineIndex", None),
        )
        payload_model = IceCandidateModel(
            peer_id=peer_id, peer_name=peer_name, candidate=candidate_model
        )
        logger.info(
            f"on_ice_candidate: Sending relayICECandidate for {peer_name}: {candidate_model}"
        )
        # Scheduled rather than awaited because this handler is synchronous.
        asyncio.ensure_future(
            self._send_message("relayICECandidate", payload_model.model_dump())
        )

    pc.on("icecandidate")(on_ice_candidate)

    # Add local tracks
    for track in self.local_tracks.values():
        logger.debug(
            f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
        )
        pc.addTrack(track)

    # Create offer if needed
    if should_create_offer:
        await self._create_and_send_offer(peer_id, peer_name, pc)

    if self.on_peer_added:
        await self.on_peer_added(peer)
|
||||
|
||||
async def _create_and_send_offer(
    self,
    peer_id: str,
    peer_name: str,
    pc: RTCPeerConnection,
    *,
    ice_gathering_timeout: float = 10.0,
):
    """Create an SDP offer for a peer and relay it over the signaling channel.

    Records this side as the offer initiator for ``peer_id``, sets the local
    description, waits for ICE gathering, relays the candidates found in the
    local SDP, and finally sends the offer as a ``relaySessionDescription``
    message. Errors are logged and not re-raised.

    Args:
        peer_id: Identifier of the remote peer.
        peer_name: Remote peer name, used for logging and in the payload.
        pc: Peer connection the offer is created on.
        ice_gathering_timeout: Maximum number of seconds to wait for
            ``pc.iceGatheringState`` to become ``"complete"`` before
            continuing with whatever candidates are available.
    """
    self.initiated_offer.add(peer_id)
    self.making_offer[peer_id] = True
    self.is_negotiating[peer_id] = True

    try:
        logger.debug(f"_create_and_send_offer: Creating offer for {peer_name}")
        offer = await pc.createOffer()
        logger.debug(
            f"_create_and_send_offer: Offer created for {peer_name}: {offer}"
        )
        await pc.setLocalDescription(offer)
        logger.debug(
            f"_create_and_send_offer: Local description set for {peer_name}"
        )

        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.debug(
            f"_create_and_send_offer: Waiting for ICE gathering to complete for {peer_name}"
        )
        loop = asyncio.get_running_loop()
        deadline = loop.time() + ice_gathering_timeout
        gathering_complete = True
        while pc.iceGatheringState != "complete":
            # Bound the wait so a connection that never finishes gathering
            # cannot block this coroutine forever.
            if loop.time() >= deadline:
                gathering_complete = False
                break
            await asyncio.sleep(0.1)

        if gathering_complete:
            logger.debug(
                f"_create_and_send_offer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
            )
        else:
            logger.warning(
                f"_create_and_send_offer: ICE gathering not complete after {ice_gathering_timeout}s for {peer_name}; sending candidates gathered so far"
            )

        await self._extract_and_send_candidates(peer_id, peer_name, pc)

        session_desc_typed = SessionDescriptionTypedModel(
            type=offer.type, sdp=offer.sdp
        )
        session_desc_model = SessionDescriptionModel(
            peer_id=peer_id,
            peer_name=peer_name,
            session_description=session_desc_typed,
        )
        await self._send_message(
            "relaySessionDescription",
            session_desc_model.model_dump(),
        )

        logger.info(f"Offer sent to {peer_name}")
    except Exception as e:
        # No answer will arrive for an offer that was never delivered, so the
        # negotiating flag must not stay set.
        self.is_negotiating[peer_id] = False
        logger.error(
            f"Failed to create/send offer to {peer_name}: {e}", exc_info=True
        )
    finally:
        self.making_offer[peer_id] = False
|
||||
|
||||
async def _extract_and_send_candidates(self, peer_id: str, peer_name: str, pc: RTCPeerConnection):
    """Relay every ICE candidate found in the local SDP, then end-of-candidates.

    Walks ``pc.localDescription.sdp`` line by line, tracking the current media
    section so each ``a=candidate:`` line is sent with the matching ``sdpMid``
    and ``sdpMLineIndex``. After the last candidate an empty candidate is sent
    as the end-of-candidates marker.

    Args:
        peer_id: Identifier of the remote peer.
        peer_name: Remote peer name, used for logging and in the payload.
        pc: Peer connection whose local description is parsed.
    """
    # Track which media section we're in to determine sdpMid and sdpMLineIndex
    current_media_index = -1
    current_mid = None
    sent_count = 0

    # splitlines() handles the CRLF line endings used by SDP, so no stray
    # "\r" ends up inside the relayed candidate strings.
    for line in pc.localDescription.sdp.splitlines():
        if line.startswith("m="):  # Media section
            current_media_index += 1
            # A new section must not inherit the previous section's mid.
            current_mid = None
        elif line.startswith("a=mid:"):  # Media ID
            current_mid = line.split(":", 1)[1].strip()
        elif line.startswith("a=candidate:"):
            candidate_sdp = line[2:]  # Remove 'a=' prefix

            candidate_model = ICECandidateDictModel(
                candidate=candidate_sdp,
                sdpMid=current_mid,
                sdpMLineIndex=current_media_index,
            )
            payload_candidate = IceCandidateModel(
                peer_id=peer_id,
                peer_name=peer_name,
                candidate=candidate_model,
            )

            logger.debug(
                f"_extract_and_send_candidates: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..."
            )
            await self._send_message(
                "relayICECandidate", payload_candidate.model_dump()
            )
            sent_count += 1

    # Send end-of-candidates signal (empty candidate)
    end_candidate_model = ICECandidateDictModel(
        candidate="",
        sdpMid=None,
        sdpMLineIndex=None,
    )
    payload_end = IceCandidateModel(
        peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
    )
    logger.debug(
        f"_extract_and_send_candidates: Sending end-of-candidates signal for {peer_name}"
    )
    await self._send_message("relayICECandidate", payload_end.model_dump())

    logger.debug(
        f"_extract_and_send_candidates: Sent {sent_count} ICE candidates to {peer_name}"
    )
|
||||
|
||||
async def _handle_remove_peer(self, data: RemovePeerModel):
    """Handle removePeer message.

    Closes the peer's connection (if any), drops all per-peer negotiation
    state, removes the peer from ``self.peers`` and invokes
    ``on_peer_removed`` when a peer was actually removed.

    Args:
        data: Message carrying the ``peer_id`` and ``peer_name`` to remove.
    """
    peer_id = data.peer_id
    peer_name = data.peer_name

    logger.info(f"Removing peer: {peer_name}")

    # Detach the connection before awaiting close() so a concurrent removal
    # cannot hit a missing key, and so a failing close() cannot leave the
    # entry behind.
    pc = self.peer_connections.pop(peer_id, None)
    if pc is not None:
        try:
            await pc.close()
        except Exception as e:
            # Best effort: the state cleanup below must still run.
            logger.error(
                f"Failed to close peer connection for {peer_name}: {e}",
                exc_info=True,
            )

    # Clean up state
    self.is_negotiating.pop(peer_id, None)
    self.making_offer.pop(peer_id, None)
    self.initiated_offer.discard(peer_id)
    self.pending_ice_candidates.pop(peer_id, None)

    # Remove peer
    peer = self.peers.pop(peer_id, None)
    if peer and self.on_peer_removed:
        await self.on_peer_removed(peer)
|
||||
|
||||
async def _handle_session_description(self, data: SessionDescriptionModel):
    """Handle sessionDescription message.

    Applies the remote offer/answer to the peer's connection, replays any ICE
    candidates that were queued while no remote description was set, and
    answers when the description was an offer. An incoming offer is ignored
    when it collides with an offer this side initiated.

    Args:
        data: Message carrying the sender's ``peer_id``/``peer_name`` and the
            session description (``type`` and ``sdp``).
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    session_description = data.session_description.model_dump()

    logger.info(f"Received {session_description['type']} from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if pc is None:
        logger.error(f"No peer connection for {peer_name}")
        return

    desc = RTCSessionDescription(
        sdp=session_description["sdp"], type=session_description["type"]
    )

    # Handle offer collision (polite peer pattern)
    making_offer = self.making_offer.get(peer_id, False)
    offer_collision = desc.type == "offer" and (
        making_offer or pc.signalingState != "stable"
    )
    we_initiated = peer_id in self.initiated_offer
    if we_initiated and offer_collision:
        logger.info(f"Ignoring offer from {peer_name} due to collision")
        return

    # Only the call that can fail with this message is inside the try, so a
    # later candidate error is not misreported as a description failure.
    try:
        await pc.setRemoteDescription(desc)
    except Exception as e:
        logger.error(
            f"Failed to set remote description for {peer_name}: {e}",
            exc_info=True,
        )
        return

    self.is_negotiating[peer_id] = False
    logger.info(f"Remote description set for {peer_name}")

    # Process queued ICE candidates; each one is handled independently so a
    # single bad entry does not discard the rest or block the answer.
    for candidate_data in self.pending_ice_candidates.pop(peer_id, []):
        # candidate_data is an ICECandidateDictModel Pydantic model
        cand = candidate_data.candidate
        try:
            # handle end-of-candidates marker
            if not cand:
                await pc.addIceCandidate(None)
                logger.info(f"Added queued end-of-candidates for {peer_name}")
                continue

            # cand may be the full "candidate:..." string or the inner SDP part
            if cand.startswith("candidate:"):
                sdp_part = cand.split(":", 1)[1]
            else:
                sdp_part = cand

            rtc_candidate = candidate_from_sdp(sdp_part)
            rtc_candidate.sdpMid = candidate_data.sdpMid
            rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex
            await pc.addIceCandidate(rtc_candidate)
            logger.info(f"Added queued ICE candidate for {peer_name}")
        except Exception as e:
            logger.error(
                f"Failed to add queued ICE candidate for {peer_name}: {e}"
            )

    # Create answer if this was an offer
    if session_description["type"] == "offer":
        await self._create_and_send_answer(peer_id, peer_name, pc)
|
||||
|
||||
async def _create_and_send_answer(
    self,
    peer_id: str,
    peer_name: str,
    pc: RTCPeerConnection,
    *,
    ice_gathering_timeout: float = 10.0,
):
    """Create an SDP answer for a peer and relay it over the signaling channel.

    Sets the local description, waits for ICE gathering, relays the
    candidates found in the local SDP, then sends the answer as a
    ``relaySessionDescription`` message. Errors are logged and not re-raised.

    Args:
        peer_id: Identifier of the remote peer.
        peer_name: Remote peer name, used for logging and in the payload.
        pc: Peer connection the answer is created on.
        ice_gathering_timeout: Maximum number of seconds to wait for
            ``pc.iceGatheringState`` to become ``"complete"`` before
            continuing with whatever candidates are available.
    """
    try:
        answer = await pc.createAnswer()
        await pc.setLocalDescription(answer)

        # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
        # Use Method 2: Complete SDP approach to extract ICE candidates
        logger.debug(
            f"_create_and_send_answer: Waiting for ICE gathering to complete for {peer_name} (answer)"
        )
        loop = asyncio.get_running_loop()
        deadline = loop.time() + ice_gathering_timeout
        gathering_complete = True
        while pc.iceGatheringState != "complete":
            # Bound the wait so a connection that never finishes gathering
            # cannot block this coroutine forever.
            if loop.time() >= deadline:
                gathering_complete = False
                break
            await asyncio.sleep(0.1)

        if gathering_complete:
            logger.debug(
                f"_create_and_send_answer: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
            )
        else:
            logger.warning(
                f"_create_and_send_answer: ICE gathering not complete after {ice_gathering_timeout}s for {peer_name}; sending candidates gathered so far"
            )

        await self._extract_and_send_candidates(peer_id, peer_name, pc)

        session_desc_typed = SessionDescriptionTypedModel(
            type=answer.type, sdp=answer.sdp
        )
        session_desc_model = SessionDescriptionModel(
            peer_id=peer_id,
            peer_name=peer_name,
            session_description=session_desc_typed,
        )
        await self._send_message(
            "relaySessionDescription",
            session_desc_model.model_dump(),
        )

        logger.info(f"Answer sent to {peer_name}")
    except Exception as e:
        # Include the traceback, matching the offer path's error logging.
        logger.error(
            f"Failed to create/send answer to {peer_name}: {e}", exc_info=True
        )
|
||||
|
||||
async def _handle_ice_candidate(self, data: IceCandidateModel):
    """Handle iceCandidate message.

    Adds a remote ICE candidate to the sender's peer connection. While the
    connection has no remote description the candidate is queued in
    ``self.pending_ice_candidates`` instead. An empty candidate string is
    treated as the end-of-candidates marker. Errors are logged, not raised.

    Args:
        data: Message carrying the sender's ``peer_id``/``peer_name`` and the
            candidate (``candidate``, ``sdpMid``, ``sdpMLineIndex``).
    """
    peer_id = data.peer_id
    peer_name = data.peer_name
    candidate_data = data.candidate

    logger.info(f"Received ICE candidate from {peer_name}")

    pc = self.peer_connections.get(peer_id)
    if not pc:
        logger.error(f"No peer connection for {peer_name}")
        return

    # Queue candidate if remote description not set
    if not pc.remoteDescription:
        logger.info(
            f"Remote description not set, queuing ICE candidate for {peer_name}"
        )
        # candidate_data is an ICECandidateDictModel Pydantic model
        self.pending_ice_candidates.setdefault(peer_id, []).append(candidate_data)
        return

    try:
        cand = candidate_data.candidate
        if not cand:
            # end-of-candidates
            await pc.addIceCandidate(None)
            logger.info(f"End-of-candidates added for {peer_name}")
            return

        # cand may be the full "candidate:..." string or the inner SDP part
        if cand.startswith("candidate:"):
            sdp_part = cand.split(":", 1)[1]
        else:
            sdp_part = cand

        # Detect type for logging
        type_match = re.search(r"\btyp\s+(host|srflx|relay|prflx)\b", sdp_part)
        cand_type = type_match.group(1) if type_match else "unknown"

        try:
            rtc_candidate = candidate_from_sdp(sdp_part)
            rtc_candidate.sdpMid = candidate_data.sdpMid
            rtc_candidate.sdpMLineIndex = candidate_data.sdpMLineIndex

            # aiortc expects an object with attributes (RTCIceCandidate)
            await pc.addIceCandidate(rtc_candidate)
            logger.info(f"ICE candidate added for {peer_name}: type={cand_type}")
        except Exception as e:
            logger.error(
                f"Failed to add ICE candidate for {peer_name}: type={cand_type} error={e} sdp='{sdp_part}'",
                exc_info=True,
            )
    except Exception as e:
        logger.error(
            f"Unexpected error handling ICE candidate for {peer_name}: {e}",
            exc_info=True,
        )
|
Loading…
x
Reference in New Issue
Block a user