diff --git a/voicebot/bots/minimal.py b/voicebot/bots/minimal.py index 68d9070..2401568 100644 --- a/voicebot/bots/minimal.py +++ b/voicebot/bots/minimal.py @@ -11,10 +11,14 @@ import time from av.audio.frame import AudioFrame from av import VideoFrame from aiortc import MediaStreamTrack -from typing import Dict, Any, Optional, Tuple +from typing import Dict, Optional, Tuple, Any from shared.logger import logger +# Global registry to store active tracks by session +_active_tracks: Dict[str, Dict[str, Any]] = {} + + class MediaClock: """Shared clock for media synchronization.""" @@ -47,13 +51,68 @@ class ConfigurableVideoTrack(MediaStreamTrack): self.frame_count = 0 self._start_time = time.time() - # Mode-specific state - if self.mode == 'ball': - self.ball_x = self.width // 2 - self.ball_y = self.height // 2 - self.ball_dx = 2 - self.ball_dy = 2 - self.ball_radius = 20 + # Initialize ball attributes (always, regardless of mode) + self.ball_x = self.width // 2 + self.ball_y = self.height // 2 + self.ball_dx = 2 + self.ball_dy = 2 + self.ball_radius = 20 + + def update_config(self, config_updates: Dict[str, Any]) -> bool: + """Update the video track configuration dynamically. + + Args: + config_updates: Dictionary of configuration values to update + + Returns: + bool: True if update was successful + """ + try: + old_mode = self.mode + old_width = self.width + old_height = self.height + + # Update configuration + self.config.update(config_updates) + + # Update instance variables + if 'width' in config_updates: + self.width = config_updates['width'] + if 'height' in config_updates: + self.height = config_updates['height'] + if 'fps' in config_updates: + self.fps = config_updates['fps'] + if 'visualization' in config_updates: + self.mode = config_updates['visualization'] + + # Reinitialize mode-specific state if mode changed + if self.mode != old_mode: + self._initialize_mode_state() + logger.info(f"Video mode changed from {old_mode} to {self.mode}") + + # Reinitialize ball state if dimensions changed (for any mode, since ball is default) + if self.width != old_width or self.height != old_height: + self._initialize_ball_state() + + logger.info(f"Video track configuration updated: {config_updates}") + return True + + except Exception as e: + logger.error(f"Error updating video track configuration: {e}") + return False + + def _initialize_mode_state(self): + """Initialize state specific to the current visualization mode.""" + # Always initialize ball state since it's used as default + self._initialize_ball_state() + + def _initialize_ball_state(self): + """Initialize bouncing ball state.""" + self.ball_x = self.width // 2 + self.ball_y = self.height // 2 + self.ball_dx = 2 + self.ball_dy = 2 + self.ball_radius = min(self.width, self.height) * 0.06 async def next_timestamp(self) -> Tuple[int, float]: pts = int(self.frame_count * (90000 / self.fps)) @@ -71,7 +130,8 @@ class ConfigurableVideoTrack(MediaStreamTrack): elif self.mode == 'static': frame_array = self._generate_static_frame() else: - frame_array = self._generate_ball_frame() # default + # Default to ball mode if unknown mode + frame_array = self._generate_ball_frame() frame = VideoFrame.from_ndarray(frame_array, format="bgr24") # type: ignore frame.pts = pts @@ -144,7 +204,38 @@ class ConfigurableAudioTrack(MediaStreamTrack): self.mode = config.get('audio_mode', 'tone') self.frequency = config.get('frequency', 440.0) self.volume = config.get('volume', 0.5) - self._samples_generated = 0 + self._samples_generated: int = 0 + def update_config(self, config_updates: Dict[str, Any]) -> bool: + """Update the audio track configuration dynamically. + + Args: + config_updates: Dictionary of configuration values to update + + Returns: + bool: True if update was successful + """ + try: + # Update configuration + self.config.update(config_updates) + + # Update instance variables + if 'sample_rate' in config_updates: + self.sample_rate = config_updates['sample_rate'] + if 'samples_per_frame' in config_updates: + self.samples_per_frame = config_updates['samples_per_frame'] + if 'audio_mode' in config_updates: + self.mode = config_updates['audio_mode'] + if 'frequency' in config_updates: + self.frequency = config_updates['frequency'] + if 'volume' in config_updates: + self.volume = config_updates['volume'] + + logger.info(f"Audio track configuration updated: {config_updates}") + return True + + except Exception as e: + logger.error(f"Error updating audio track configuration: {e}") + return False async def next_timestamp(self) -> Tuple[int, float]: pts = self._samples_generated @@ -252,6 +343,9 @@ def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any] logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}") + # Store tracks in global registry for dynamic updates + _active_tracks[session_name] = {"video": video_track, "audio": audio_track} + return {"video": video_track, "audio": audio_track} @@ -393,4 +487,85 @@ def get_config_schema() -> Dict[str, Any]: "Advanced": ["static_color"] } ] - } \ No newline at end of file + } + + +def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]: + """Factory wrapper used by the FastAPI service to instantiate tracks for an agent.""" + return create_minimal_bot_tracks(session_name) + + +def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool: + """ + Handle runtime configuration updates for the minimal bot. + + Args: + lobby_id: ID of the lobby/bot instance + config_values: Dictionary of configuration values to update + + Returns: + bool: True if update was successful, False otherwise + """ + try: + # Validate configuration values + validated_config = {} + + # Parse and validate static_color if provided + if 'static_color' in config_values: + color_value = config_values['static_color'] + if isinstance(color_value, str): + try: + r, g, b = map(int, color_value.split(',')) # type: ignore + validated_config['static_color'] = (r, g, b) + except (ValueError, TypeError): + logger.warning(f"Invalid static_color format: {color_value}, ignoring update") + return False + elif isinstance(color_value, (tuple, list)) and len(color_value) == 3: # type: ignore + validated_config['static_color'] = tuple(color_value) # type: ignore + else: + logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update") # type: ignore + return False + + # Copy other valid configuration values + valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps', + 'sample_rate', 'frequency', 'volume'} + for key in valid_keys: + if key in config_values: + validated_config[key] = config_values[key] + + if validated_config: + logger.info(f"Configuration updated for {lobby_id}: {validated_config}") + + # Update running tracks if they exist in the registry + if lobby_id in _active_tracks: + tracks = _active_tracks[lobby_id] + video_track = tracks.get("video") + audio_track = tracks.get("audio") + + # Update video track configuration + if video_track and hasattr(video_track, 'update_config'): + video_updated = video_track.update_config(validated_config) + if video_updated: + logger.info(f"Video track configuration updated for {lobby_id}") + else: + logger.warning(f"Failed to update video track configuration for {lobby_id}") + + # Update audio track configuration + if audio_track and hasattr(audio_track, 'update_config'): + audio_updated = audio_track.update_config(validated_config) + if audio_updated: + logger.info(f"Audio track configuration updated for {lobby_id}") + else: + logger.warning(f"Failed to update audio track configuration for {lobby_id}") + + return True + else: + logger.warning(f"No active tracks found for session {lobby_id}") + return False + else: + logger.warning(f"No valid configuration values provided for {lobby_id}") + return False + + except Exception as e: + logger.error(f"Error updating configuration for {lobby_id}: {e}") + return False \ No newline at end of file diff --git a/voicebot/webrtc_signaling.py b/voicebot/webrtc_signaling.py index 6559cc1..f951305 100644 --- a/voicebot/webrtc_signaling.py +++ b/voicebot/webrtc_signaling.py @@ -473,7 +473,9 @@ class WebRTCSignalingClient: # Otherwise, use default synthetic tracks. try: if self.create_tracks: - tracks = self.create_tracks(self.session_name) + # Use lobby_id as the session identifier for track creation + # This ensures config updates can find the tracks using the same key + tracks = self.create_tracks(self.lobby_id) self.local_tracks.update(tracks) else: # Default fallback to synthetic tracks