Minimal bot

This commit is contained in:
James Ketr 2025-09-16 14:32:29 -07:00
parent 6b2f5555a4
commit 8caef2c3de
2 changed files with 189 additions and 12 deletions

View File

@ -11,10 +11,14 @@ import time
from av.audio.frame import AudioFrame
from av import VideoFrame
from aiortc import MediaStreamTrack
from typing import Dict, Any, Optional, Tuple
from typing import Dict, Optional, Tuple, Any
from shared.logger import logger
# Global registry to store active tracks by session
_active_tracks: Dict[str, Dict[str, Any]] = {}
class MediaClock:
"""Shared clock for media synchronization."""
@ -47,13 +51,68 @@ class ConfigurableVideoTrack(MediaStreamTrack):
self.frame_count = 0
self._start_time = time.time()
# Mode-specific state
if self.mode == 'ball':
self.ball_x = self.width // 2
self.ball_y = self.height // 2
self.ball_dx = 2
self.ball_dy = 2
self.ball_radius = 20
# Initialize ball attributes (always, regardless of mode)
self.ball_x = self.width // 2
self.ball_y = self.height // 2
self.ball_dx = 2
self.ball_dy = 2
self.ball_radius = 20
def update_config(self, config_updates: Dict[str, Any]) -> bool:
"""Update the video track configuration dynamically.
Args:
config_updates: Dictionary of configuration values to update
Returns:
bool: True if update was successful
"""
try:
old_mode = self.mode
old_width = self.width
old_height = self.height
# Update configuration
self.config.update(config_updates)
# Update instance variables
if 'width' in config_updates:
self.width = config_updates['width']
if 'height' in config_updates:
self.height = config_updates['height']
if 'fps' in config_updates:
self.fps = config_updates['fps']
if 'visualization' in config_updates:
self.mode = config_updates['visualization']
# Reinitialize mode-specific state if mode changed
if self.mode != old_mode:
self._initialize_mode_state()
logger.info(f"Video mode changed from {old_mode} to {self.mode}")
# Reinitialize ball state if dimensions changed (for any mode, since ball is default)
if self.width != old_width or self.height != old_height:
self._initialize_ball_state()
logger.info(f"Video track configuration updated: {config_updates}")
return True
except Exception as e:
logger.error(f"Error updating video track configuration: {e}")
return False
def _initialize_mode_state(self):
"""Initialize state specific to the current visualization mode."""
# Always initialize ball state since it's used as default
self._initialize_ball_state()
def _initialize_ball_state(self):
"""Initialize bouncing ball state."""
self.ball_x = self.width // 2
self.ball_y = self.height // 2
self.ball_dx = 2
self.ball_dy = 2
self.ball_radius = min(self.width, self.height) * 0.06
async def next_timestamp(self) -> Tuple[int, float]:
pts = int(self.frame_count * (90000 / self.fps))
@ -71,7 +130,8 @@ class ConfigurableVideoTrack(MediaStreamTrack):
elif self.mode == 'static':
frame_array = self._generate_static_frame()
else:
frame_array = self._generate_ball_frame() # default
# Default to ball mode if unknown mode
frame_array = self._generate_ball_frame()
frame = VideoFrame.from_ndarray(frame_array, format="bgr24") # type: ignore
frame.pts = pts
@ -144,7 +204,38 @@ class ConfigurableAudioTrack(MediaStreamTrack):
self.mode = config.get('audio_mode', 'tone')
self.frequency = config.get('frequency', 440.0)
self.volume = config.get('volume', 0.5)
self._samples_generated = 0
self._samples_generated: int = 0
def update_config(self, config_updates: Dict[str, Any]) -> bool:
"""Update the audio track configuration dynamically.
Args:
config_updates: Dictionary of configuration values to update
Returns:
bool: True if update was successful
"""
try:
# Update configuration
self.config.update(config_updates)
# Update instance variables
if 'sample_rate' in config_updates:
self.sample_rate = config_updates['sample_rate']
if 'samples_per_frame' in config_updates:
self.samples_per_frame = config_updates['samples_per_frame']
if 'audio_mode' in config_updates:
self.mode = config_updates['audio_mode']
if 'frequency' in config_updates:
self.frequency = config_updates['frequency']
if 'volume' in config_updates:
self.volume = config_updates['volume']
logger.info(f"Audio track configuration updated: {config_updates}")
return True
except Exception as e:
logger.error(f"Error updating audio track configuration: {e}")
return False
async def next_timestamp(self) -> Tuple[int, float]:
pts = self._samples_generated
@ -252,6 +343,9 @@ def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any]
logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}")
# Store tracks in global registry for dynamic updates
_active_tracks[session_name] = {"video": video_track, "audio": audio_track}
return {"video": video_track, "audio": audio_track}
@ -393,4 +487,85 @@ def get_config_schema() -> Dict[str, Any]:
"Advanced": ["static_color"]
}
]
}
}
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
return create_minimal_bot_tracks(session_name)
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
"""
Handle runtime configuration updates for the minimal bot.
Args:
lobby_id: ID of the lobby/bot instance
config_values: Dictionary of configuration values to update
Returns:
bool: True if update was successful, False otherwise
"""
try:
# Validate configuration values
validated_config = {}
# Parse and validate static_color if provided
if 'static_color' in config_values:
color_value = config_values['static_color']
if isinstance(color_value, str):
try:
r, g, b = map(int, color_value.split(',')) # type: ignore
validated_config['static_color'] = (r, g, b)
except (ValueError, TypeError):
logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
return False
elif isinstance(color_value, (tuple, list)) and len(color_value) == 3: # type: ignore
validated_config['static_color'] = tuple(color_value) # type: ignore
else:
logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update") # type: ignore
return False
# Copy other valid configuration values
valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
'sample_rate', 'frequency', 'volume'}
for key in valid_keys:
if key in config_values:
validated_config[key] = config_values[key]
if validated_config:
logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
# Update running tracks if they exist in the registry
if lobby_id in _active_tracks:
tracks = _active_tracks[lobby_id]
video_track = tracks.get("video")
audio_track = tracks.get("audio")
# Update video track configuration
if video_track and hasattr(video_track, 'update_config'):
video_updated = video_track.update_config(validated_config)
if video_updated:
logger.info(f"Video track configuration updated for {lobby_id}")
else:
logger.warning(f"Failed to update video track configuration for {lobby_id}")
# Update audio track configuration
if audio_track and hasattr(audio_track, 'update_config'):
audio_updated = audio_track.update_config(validated_config)
if audio_updated:
logger.info(f"Audio track configuration updated for {lobby_id}")
else:
logger.warning(f"Failed to update audio track configuration for {lobby_id}")
return True
else:
logger.warning(f"No active tracks found for session {lobby_id}")
return False
else:
logger.warning(f"No valid configuration values provided for {lobby_id}")
return False
except Exception as e:
logger.error(f"Error updating configuration for {lobby_id}: {e}")
return False

View File

@ -473,7 +473,9 @@ class WebRTCSignalingClient:
# Otherwise, use default synthetic tracks.
try:
if self.create_tracks:
tracks = self.create_tracks(self.session_name)
# Use lobby_id as the session identifier for track creation
# This ensures config updates can find the tracks using the same key
tracks = self.create_tracks(self.lobby_id)
self.local_tracks.update(tracks)
else:
# Default fallback to synthetic tracks