Minimal bot
This commit is contained in:
parent
6b2f5555a4
commit
8caef2c3de
@ -11,10 +11,14 @@ import time
|
||||
from av.audio.frame import AudioFrame
|
||||
from av import VideoFrame
|
||||
from aiortc import MediaStreamTrack
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
from shared.logger import logger
|
||||
|
||||
|
||||
# Global registry to store active tracks by session
|
||||
_active_tracks: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class MediaClock:
|
||||
"""Shared clock for media synchronization."""
|
||||
|
||||
@ -47,13 +51,68 @@ class ConfigurableVideoTrack(MediaStreamTrack):
|
||||
self.frame_count = 0
|
||||
self._start_time = time.time()
|
||||
|
||||
# Mode-specific state
|
||||
if self.mode == 'ball':
|
||||
self.ball_x = self.width // 2
|
||||
self.ball_y = self.height // 2
|
||||
self.ball_dx = 2
|
||||
self.ball_dy = 2
|
||||
self.ball_radius = 20
|
||||
# Initialize ball attributes (always, regardless of mode)
|
||||
self.ball_x = self.width // 2
|
||||
self.ball_y = self.height // 2
|
||||
self.ball_dx = 2
|
||||
self.ball_dy = 2
|
||||
self.ball_radius = 20
|
||||
|
||||
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
||||
"""Update the video track configuration dynamically.
|
||||
|
||||
Args:
|
||||
config_updates: Dictionary of configuration values to update
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
"""
|
||||
try:
|
||||
old_mode = self.mode
|
||||
old_width = self.width
|
||||
old_height = self.height
|
||||
|
||||
# Update configuration
|
||||
self.config.update(config_updates)
|
||||
|
||||
# Update instance variables
|
||||
if 'width' in config_updates:
|
||||
self.width = config_updates['width']
|
||||
if 'height' in config_updates:
|
||||
self.height = config_updates['height']
|
||||
if 'fps' in config_updates:
|
||||
self.fps = config_updates['fps']
|
||||
if 'visualization' in config_updates:
|
||||
self.mode = config_updates['visualization']
|
||||
|
||||
# Reinitialize mode-specific state if mode changed
|
||||
if self.mode != old_mode:
|
||||
self._initialize_mode_state()
|
||||
logger.info(f"Video mode changed from {old_mode} to {self.mode}")
|
||||
|
||||
# Reinitialize ball state if dimensions changed (for any mode, since ball is default)
|
||||
if self.width != old_width or self.height != old_height:
|
||||
self._initialize_ball_state()
|
||||
|
||||
logger.info(f"Video track configuration updated: {config_updates}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating video track configuration: {e}")
|
||||
return False
|
||||
|
||||
def _initialize_mode_state(self):
|
||||
"""Initialize state specific to the current visualization mode."""
|
||||
# Always initialize ball state since it's used as default
|
||||
self._initialize_ball_state()
|
||||
|
||||
def _initialize_ball_state(self):
|
||||
"""Initialize bouncing ball state."""
|
||||
self.ball_x = self.width // 2
|
||||
self.ball_y = self.height // 2
|
||||
self.ball_dx = 2
|
||||
self.ball_dy = 2
|
||||
self.ball_radius = min(self.width, self.height) * 0.06
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, float]:
|
||||
pts = int(self.frame_count * (90000 / self.fps))
|
||||
@ -71,7 +130,8 @@ class ConfigurableVideoTrack(MediaStreamTrack):
|
||||
elif self.mode == 'static':
|
||||
frame_array = self._generate_static_frame()
|
||||
else:
|
||||
frame_array = self._generate_ball_frame() # default
|
||||
# Default to ball mode if unknown mode
|
||||
frame_array = self._generate_ball_frame()
|
||||
|
||||
frame = VideoFrame.from_ndarray(frame_array, format="bgr24") # type: ignore
|
||||
frame.pts = pts
|
||||
@ -144,7 +204,38 @@ class ConfigurableAudioTrack(MediaStreamTrack):
|
||||
self.mode = config.get('audio_mode', 'tone')
|
||||
self.frequency = config.get('frequency', 440.0)
|
||||
self.volume = config.get('volume', 0.5)
|
||||
self._samples_generated = 0
|
||||
self._samples_generated: int = 0
|
||||
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
||||
"""Update the audio track configuration dynamically.
|
||||
|
||||
Args:
|
||||
config_updates: Dictionary of configuration values to update
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful
|
||||
"""
|
||||
try:
|
||||
# Update configuration
|
||||
self.config.update(config_updates)
|
||||
|
||||
# Update instance variables
|
||||
if 'sample_rate' in config_updates:
|
||||
self.sample_rate = config_updates['sample_rate']
|
||||
if 'samples_per_frame' in config_updates:
|
||||
self.samples_per_frame = config_updates['samples_per_frame']
|
||||
if 'audio_mode' in config_updates:
|
||||
self.mode = config_updates['audio_mode']
|
||||
if 'frequency' in config_updates:
|
||||
self.frequency = config_updates['frequency']
|
||||
if 'volume' in config_updates:
|
||||
self.volume = config_updates['volume']
|
||||
|
||||
logger.info(f"Audio track configuration updated: {config_updates}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating audio track configuration: {e}")
|
||||
return False
|
||||
|
||||
async def next_timestamp(self) -> Tuple[int, float]:
|
||||
pts = self._samples_generated
|
||||
@ -252,6 +343,9 @@ def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any]
|
||||
|
||||
logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}")
|
||||
|
||||
# Store tracks in global registry for dynamic updates
|
||||
_active_tracks[session_name] = {"video": video_track, "audio": audio_track}
|
||||
|
||||
return {"video": video_track, "audio": audio_track}
|
||||
|
||||
|
||||
@ -393,4 +487,85 @@ def get_config_schema() -> Dict[str, Any]:
|
||||
"Advanced": ["static_color"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
||||
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
|
||||
return create_minimal_bot_tracks(session_name)
|
||||
|
||||
|
||||
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Handle runtime configuration updates for the minimal bot.
|
||||
|
||||
Args:
|
||||
lobby_id: ID of the lobby/bot instance
|
||||
config_values: Dictionary of configuration values to update
|
||||
|
||||
Returns:
|
||||
bool: True if update was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Validate configuration values
|
||||
validated_config = {}
|
||||
|
||||
# Parse and validate static_color if provided
|
||||
if 'static_color' in config_values:
|
||||
color_value = config_values['static_color']
|
||||
if isinstance(color_value, str):
|
||||
try:
|
||||
r, g, b = map(int, color_value.split(',')) # type: ignore
|
||||
validated_config['static_color'] = (r, g, b)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
|
||||
return False
|
||||
elif isinstance(color_value, (tuple, list)) and len(color_value) == 3: # type: ignore
|
||||
validated_config['static_color'] = tuple(color_value) # type: ignore
|
||||
else:
|
||||
logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update") # type: ignore
|
||||
return False
|
||||
|
||||
# Copy other valid configuration values
|
||||
valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
|
||||
'sample_rate', 'frequency', 'volume'}
|
||||
for key in valid_keys:
|
||||
if key in config_values:
|
||||
validated_config[key] = config_values[key]
|
||||
|
||||
if validated_config:
|
||||
logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
|
||||
|
||||
# Update running tracks if they exist in the registry
|
||||
if lobby_id in _active_tracks:
|
||||
tracks = _active_tracks[lobby_id]
|
||||
video_track = tracks.get("video")
|
||||
audio_track = tracks.get("audio")
|
||||
|
||||
# Update video track configuration
|
||||
if video_track and hasattr(video_track, 'update_config'):
|
||||
video_updated = video_track.update_config(validated_config)
|
||||
if video_updated:
|
||||
logger.info(f"Video track configuration updated for {lobby_id}")
|
||||
else:
|
||||
logger.warning(f"Failed to update video track configuration for {lobby_id}")
|
||||
|
||||
# Update audio track configuration
|
||||
if audio_track and hasattr(audio_track, 'update_config'):
|
||||
audio_updated = audio_track.update_config(validated_config)
|
||||
if audio_updated:
|
||||
logger.info(f"Audio track configuration updated for {lobby_id}")
|
||||
else:
|
||||
logger.warning(f"Failed to update audio track configuration for {lobby_id}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"No active tracks found for session {lobby_id}")
|
||||
return False
|
||||
else:
|
||||
logger.warning(f"No valid configuration values provided for {lobby_id}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating configuration for {lobby_id}: {e}")
|
||||
return False
|
@ -473,7 +473,9 @@ class WebRTCSignalingClient:
|
||||
# Otherwise, use default synthetic tracks.
|
||||
try:
|
||||
if self.create_tracks:
|
||||
tracks = self.create_tracks(self.session_name)
|
||||
# Use lobby_id as the session identifier for track creation
|
||||
# This ensures config updates can find the tracks using the same key
|
||||
tracks = self.create_tracks(self.lobby_id)
|
||||
self.local_tracks.update(tracks)
|
||||
else:
|
||||
# Default fallback to synthetic tracks
|
||||
|
Loading…
x
Reference in New Issue
Block a user