"""
Data models and configuration for voicebot.

This module provides Pydantic models for configuration and data structures
used throughout the voicebot application.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
from enum import Enum
|
|
from typing import Dict, Optional, Any, TYPE_CHECKING
|
|
from dataclasses import dataclass, field
|
|
from pydantic import BaseModel, Field
|
|
|
|
if TYPE_CHECKING:
|
|
from aiortc import RTCPeerConnection
|
|
|
|
|
|
class VoicebotMode(str, Enum):
    """How the voicebot process runs.

    Members subclass ``str``, so each compares equal to its plain string
    value (e.g. ``VoicebotMode.CLIENT == "client"``).
    """

    # Join a lobby as a participant.
    CLIENT = "client"
    # Serve bots to others.
    PROVIDER = "provider"
|
|
|
|
|
|
class VoicebotArgs(BaseModel):
    """Pydantic model for voicebot CLI arguments and configuration.

    Build instances with :meth:`from_environment` (``VOICEBOT_*`` variables)
    or :meth:`from_argparse` (a parsed command line). Pydantic enforces the
    field constraints (``port`` range, ``registration_check_interval``
    bounds) at construction time.
    """

    # Mode selection
    mode: VoicebotMode = Field(default=VoicebotMode.CLIENT, description="Run as client (connect to lobby) or provider (serve bots)")

    # Provider mode arguments
    host: str = Field(default="0.0.0.0", description="Host for provider mode")
    port: int = Field(default=8788, description="Port for provider mode", ge=1, le=65535)
    reload: bool = Field(default=False, description="Enable auto-reload for development")

    # Client mode arguments
    server_url: str = Field(
        default="http://localhost:8000/ai-voicebot",
        description="AI-Voicebot lobby and signaling server base URL (http:// or https://)"
    )
    lobby: str = Field(default="default", description="Lobby name to create or join")
    session_name: str = Field(default="Python Bot", description="Session (user) display name")
    session_id: Optional[str] = Field(default=None, description="Optional existing session id to reuse")
    password: Optional[str] = Field(default=None, description="Optional password to register or takeover a name")
    private: bool = Field(default=False, description="Create the lobby as private")
    insecure: bool = Field(default=False, description="Allow insecure server connections when using SSL")
    registration_check_interval: float = Field(default=30.0, description="Interval in seconds for checking registration status", ge=5.0, le=300.0)

    @classmethod
    def from_environment(cls) -> 'VoicebotArgs':
        """Create VoicebotArgs from ``VOICEBOT_*`` environment variables.

        Parsing rules:

        * Boolean flags (``VOICEBOT_RELOAD``, ``VOICEBOT_PRIVATE``,
          ``VOICEBOT_SERVER_INSECURE``) are true when the value is ``1``,
          ``true``, ``yes`` or ``on``. Matching ignores case and surrounding
          whitespace. Any other value, or an unset variable, is false.
        * ``VOICEBOT_MODE`` is matched case-insensitively against the
          :class:`VoicebotMode` values. Unset or empty means ``client``.
        * ``VOICEBOT_PORT`` and ``VOICEBOT_REGISTRATION_CHECK_INTERVAL`` fall
          back to their defaults when unset *or empty*.
        * ``VOICEBOT_SESSION_ID`` and ``VOICEBOT_PASSWORD`` become ``None``
          when unset *or empty*. This prevents a blank ``VOICEBOT_PASSWORD=``
          line from being treated as a real password.
        * Other string settings are used exactly as given.

        Raises:
            ValueError: if the mode is not a known :class:`VoicebotMode`, or
                the port or interval is not numeric. Out-of-range values are
                rejected by Pydantic field validation.
        """
        import os

        truthy = {'1', 'true', 'yes', 'on'}

        def env_flag(name: str) -> bool:
            # Unset behaves like 'false', as before.
            return os.getenv(name, 'false').strip().lower() in truthy

        def env_or_default(name: str, default: str) -> str:
            # Treat an empty value like an unset one. Otherwise int('') or
            # float('') would raise for e.g. `VOICEBOT_PORT=`.
            return os.getenv(name) or default

        def env_optional(name: str) -> Optional[str]:
            return os.getenv(name) or None

        mode_str = env_or_default('VOICEBOT_MODE', 'client').strip().lower()
        return cls(
            mode=VoicebotMode(mode_str),
            host=os.getenv('VOICEBOT_HOST', '0.0.0.0'),
            port=int(env_or_default('VOICEBOT_PORT', '8788')),
            reload=env_flag('VOICEBOT_RELOAD'),
            # NOTE(review): this default intentionally differs from the field
            # default ('http://localhost:8000/ai-voicebot'). It presumably
            # targets a host named 'server' in a container deployment; confirm.
            server_url=os.getenv('VOICEBOT_SERVER_URL', 'https://server:8000/ai-voicebot'),
            lobby=os.getenv('VOICEBOT_LOBBY', 'default'),
            session_name=os.getenv('VOICEBOT_SESSION_NAME', 'Python Bot'),
            session_id=env_optional('VOICEBOT_SESSION_ID'),
            password=env_optional('VOICEBOT_PASSWORD'),
            private=env_flag('VOICEBOT_PRIVATE'),
            insecure=env_flag('VOICEBOT_SERVER_INSECURE'),
            registration_check_interval=float(env_or_default('VOICEBOT_REGISTRATION_CHECK_INTERVAL', '30.0')),
        )

    @classmethod
    def from_argparse(cls, args: argparse.Namespace) -> 'VoicebotArgs':
        """Create VoicebotArgs from an argparse Namespace.

        A setting falls back to its default when its attribute is missing
        from *args*. It also falls back when the attribute is present but
        ``None``, which argparse stores for options declared without
        ``default=``. Passing that ``None`` through would fail validation
        of ``str`` fields and make ``float(None)`` raise ``TypeError``.

        ``session_id`` and ``password`` are optional, so ``None`` is kept
        for them.
        """
        def pick(name: str, default: Any) -> Any:
            value = getattr(args, name, None)
            return default if value is None else value

        return cls(
            mode=VoicebotMode(pick('mode', 'client')),
            host=pick('host', '0.0.0.0'),
            port=pick('port', 8788),
            reload=pick('reload', False),
            server_url=pick('server_url', 'http://localhost:8000/ai-voicebot'),
            lobby=pick('lobby', 'default'),
            session_name=pick('session_name', 'Python Bot'),
            session_id=getattr(args, 'session_id', None),
            password=getattr(args, 'password', None),
            private=pick('private', False),
            insecure=pick('insecure', False),
            registration_check_interval=float(pick('registration_check_interval', 30.0)),
        )
|
|
|
|
|
|
class JoinRequest(BaseModel):
    """Payload describing a request to join a lobby."""

    # Which lobby to join, and the session joining it.
    lobby_id: str
    session_id: str
    # Nickname for the joining session (presumably its display name).
    nick: str
    # Signaling server to use, and whether to allow insecure connections.
    server_url: str
    insecure: bool = Field(default=False)
    # Optional free-form configuration; no schema is enforced here.
    config_values: Optional[Dict[str, Any]] = Field(default=None)
|
|
|
|
|
|
def _default_attributes() -> Dict[str, object]:
|
|
"""Default factory for peer attributes."""
|
|
return {}
|
|
|
|
|
|
@dataclass
class Peer:
    """A WebRTC peer participating in the session.

    The declaration order below is also the positional order of the
    generated ``__init__``.
    """

    # Identity of the peer.
    session_id: str
    peer_name: str
    # Free-form per-peer data (tracks or simple metadata). The factory gives
    # every instance its own dict.
    attributes: Dict[str, object] = field(default_factory=_default_attributes)
    # State flags.
    muted: bool = False
    video_on: bool = True
    local: bool = False
    dead: bool = False
    # Associated aiortc peer connection; None unless assigned.
    connection: Optional["RTCPeerConnection"] = None
|
|
|
|
|
|
# Alias for a generic message payload: a str-keyed dict whose values may be
# of any type (callers must narrow values before use).
MessageData = dict[str, object]
|