Minimal bot
This commit is contained in:
parent
6b2f5555a4
commit
8caef2c3de
@ -11,10 +11,14 @@ import time
|
|||||||
from av.audio.frame import AudioFrame
|
from av.audio.frame import AudioFrame
|
||||||
from av import VideoFrame
|
from av import VideoFrame
|
||||||
from aiortc import MediaStreamTrack
|
from aiortc import MediaStreamTrack
|
||||||
from typing import Dict, Any, Optional, Tuple
|
from typing import Dict, Optional, Tuple, Any
|
||||||
from shared.logger import logger
|
from shared.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry to store active tracks by session
|
||||||
|
_active_tracks: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
|
||||||
class MediaClock:
|
class MediaClock:
|
||||||
"""Shared clock for media synchronization."""
|
"""Shared clock for media synchronization."""
|
||||||
|
|
||||||
@ -47,14 +51,69 @@ class ConfigurableVideoTrack(MediaStreamTrack):
|
|||||||
self.frame_count = 0
|
self.frame_count = 0
|
||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
|
|
||||||
# Mode-specific state
|
# Initialize ball attributes (always, regardless of mode)
|
||||||
if self.mode == 'ball':
|
|
||||||
self.ball_x = self.width // 2
|
self.ball_x = self.width // 2
|
||||||
self.ball_y = self.height // 2
|
self.ball_y = self.height // 2
|
||||||
self.ball_dx = 2
|
self.ball_dx = 2
|
||||||
self.ball_dy = 2
|
self.ball_dy = 2
|
||||||
self.ball_radius = 20
|
self.ball_radius = 20
|
||||||
|
|
||||||
|
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
||||||
|
"""Update the video track configuration dynamically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_updates: Dictionary of configuration values to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if update was successful
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
old_mode = self.mode
|
||||||
|
old_width = self.width
|
||||||
|
old_height = self.height
|
||||||
|
|
||||||
|
# Update configuration
|
||||||
|
self.config.update(config_updates)
|
||||||
|
|
||||||
|
# Update instance variables
|
||||||
|
if 'width' in config_updates:
|
||||||
|
self.width = config_updates['width']
|
||||||
|
if 'height' in config_updates:
|
||||||
|
self.height = config_updates['height']
|
||||||
|
if 'fps' in config_updates:
|
||||||
|
self.fps = config_updates['fps']
|
||||||
|
if 'visualization' in config_updates:
|
||||||
|
self.mode = config_updates['visualization']
|
||||||
|
|
||||||
|
# Reinitialize mode-specific state if mode changed
|
||||||
|
if self.mode != old_mode:
|
||||||
|
self._initialize_mode_state()
|
||||||
|
logger.info(f"Video mode changed from {old_mode} to {self.mode}")
|
||||||
|
|
||||||
|
# Reinitialize ball state if dimensions changed (for any mode, since ball is default)
|
||||||
|
if self.width != old_width or self.height != old_height:
|
||||||
|
self._initialize_ball_state()
|
||||||
|
|
||||||
|
logger.info(f"Video track configuration updated: {config_updates}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating video track configuration: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _initialize_mode_state(self):
|
||||||
|
"""Initialize state specific to the current visualization mode."""
|
||||||
|
# Always initialize ball state since it's used as default
|
||||||
|
self._initialize_ball_state()
|
||||||
|
|
||||||
|
def _initialize_ball_state(self):
|
||||||
|
"""Initialize bouncing ball state."""
|
||||||
|
self.ball_x = self.width // 2
|
||||||
|
self.ball_y = self.height // 2
|
||||||
|
self.ball_dx = 2
|
||||||
|
self.ball_dy = 2
|
||||||
|
self.ball_radius = min(self.width, self.height) * 0.06
|
||||||
|
|
||||||
async def next_timestamp(self) -> Tuple[int, float]:
|
async def next_timestamp(self) -> Tuple[int, float]:
|
||||||
pts = int(self.frame_count * (90000 / self.fps))
|
pts = int(self.frame_count * (90000 / self.fps))
|
||||||
time_base = 1 / 90000
|
time_base = 1 / 90000
|
||||||
@ -71,7 +130,8 @@ class ConfigurableVideoTrack(MediaStreamTrack):
|
|||||||
elif self.mode == 'static':
|
elif self.mode == 'static':
|
||||||
frame_array = self._generate_static_frame()
|
frame_array = self._generate_static_frame()
|
||||||
else:
|
else:
|
||||||
frame_array = self._generate_ball_frame() # default
|
# Default to ball mode if unknown mode
|
||||||
|
frame_array = self._generate_ball_frame()
|
||||||
|
|
||||||
frame = VideoFrame.from_ndarray(frame_array, format="bgr24") # type: ignore
|
frame = VideoFrame.from_ndarray(frame_array, format="bgr24") # type: ignore
|
||||||
frame.pts = pts
|
frame.pts = pts
|
||||||
@ -144,7 +204,38 @@ class ConfigurableAudioTrack(MediaStreamTrack):
|
|||||||
self.mode = config.get('audio_mode', 'tone')
|
self.mode = config.get('audio_mode', 'tone')
|
||||||
self.frequency = config.get('frequency', 440.0)
|
self.frequency = config.get('frequency', 440.0)
|
||||||
self.volume = config.get('volume', 0.5)
|
self.volume = config.get('volume', 0.5)
|
||||||
self._samples_generated = 0
|
self._samples_generated: int = 0
|
||||||
|
def update_config(self, config_updates: Dict[str, Any]) -> bool:
|
||||||
|
"""Update the audio track configuration dynamically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_updates: Dictionary of configuration values to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if update was successful
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Update configuration
|
||||||
|
self.config.update(config_updates)
|
||||||
|
|
||||||
|
# Update instance variables
|
||||||
|
if 'sample_rate' in config_updates:
|
||||||
|
self.sample_rate = config_updates['sample_rate']
|
||||||
|
if 'samples_per_frame' in config_updates:
|
||||||
|
self.samples_per_frame = config_updates['samples_per_frame']
|
||||||
|
if 'audio_mode' in config_updates:
|
||||||
|
self.mode = config_updates['audio_mode']
|
||||||
|
if 'frequency' in config_updates:
|
||||||
|
self.frequency = config_updates['frequency']
|
||||||
|
if 'volume' in config_updates:
|
||||||
|
self.volume = config_updates['volume']
|
||||||
|
|
||||||
|
logger.info(f"Audio track configuration updated: {config_updates}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating audio track configuration: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def next_timestamp(self) -> Tuple[int, float]:
|
async def next_timestamp(self) -> Tuple[int, float]:
|
||||||
pts = self._samples_generated
|
pts = self._samples_generated
|
||||||
@ -252,6 +343,9 @@ def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any]
|
|||||||
|
|
||||||
logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}")
|
logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}")
|
||||||
|
|
||||||
|
# Store tracks in global registry for dynamic updates
|
||||||
|
_active_tracks[session_name] = {"video": video_track, "audio": audio_track}
|
||||||
|
|
||||||
return {"video": video_track, "audio": audio_track}
|
return {"video": video_track, "audio": audio_track}
|
||||||
|
|
||||||
|
|
||||||
@ -394,3 +488,84 @@ def get_config_schema() -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
|
||||||
|
"""Factory wrapper used by the FastAPI service to instantiate tracks for an agent."""
|
||||||
|
return create_minimal_bot_tracks(session_name)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
|
||||||
|
"""
|
||||||
|
Handle runtime configuration updates for the minimal bot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lobby_id: ID of the lobby/bot instance
|
||||||
|
config_values: Dictionary of configuration values to update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if update was successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Validate configuration values
|
||||||
|
validated_config = {}
|
||||||
|
|
||||||
|
# Parse and validate static_color if provided
|
||||||
|
if 'static_color' in config_values:
|
||||||
|
color_value = config_values['static_color']
|
||||||
|
if isinstance(color_value, str):
|
||||||
|
try:
|
||||||
|
r, g, b = map(int, color_value.split(',')) # type: ignore
|
||||||
|
validated_config['static_color'] = (r, g, b)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
|
||||||
|
return False
|
||||||
|
elif isinstance(color_value, (tuple, list)) and len(color_value) == 3: # type: ignore
|
||||||
|
validated_config['static_color'] = tuple(color_value) # type: ignore
|
||||||
|
else:
|
||||||
|
logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update") # type: ignore
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Copy other valid configuration values
|
||||||
|
valid_keys = {'visualization', 'audio_mode', 'width', 'height', 'fps',
|
||||||
|
'sample_rate', 'frequency', 'volume'}
|
||||||
|
for key in valid_keys:
|
||||||
|
if key in config_values:
|
||||||
|
validated_config[key] = config_values[key]
|
||||||
|
|
||||||
|
if validated_config:
|
||||||
|
logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
|
||||||
|
|
||||||
|
# Update running tracks if they exist in the registry
|
||||||
|
if lobby_id in _active_tracks:
|
||||||
|
tracks = _active_tracks[lobby_id]
|
||||||
|
video_track = tracks.get("video")
|
||||||
|
audio_track = tracks.get("audio")
|
||||||
|
|
||||||
|
# Update video track configuration
|
||||||
|
if video_track and hasattr(video_track, 'update_config'):
|
||||||
|
video_updated = video_track.update_config(validated_config)
|
||||||
|
if video_updated:
|
||||||
|
logger.info(f"Video track configuration updated for {lobby_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to update video track configuration for {lobby_id}")
|
||||||
|
|
||||||
|
# Update audio track configuration
|
||||||
|
if audio_track and hasattr(audio_track, 'update_config'):
|
||||||
|
audio_updated = audio_track.update_config(validated_config)
|
||||||
|
if audio_updated:
|
||||||
|
logger.info(f"Audio track configuration updated for {lobby_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Failed to update audio track configuration for {lobby_id}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"No active tracks found for session {lobby_id}")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
logger.warning(f"No valid configuration values provided for {lobby_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error updating configuration for {lobby_id}: {e}")
|
||||||
|
return False
|
@ -473,7 +473,9 @@ class WebRTCSignalingClient:
|
|||||||
# Otherwise, use default synthetic tracks.
|
# Otherwise, use default synthetic tracks.
|
||||||
try:
|
try:
|
||||||
if self.create_tracks:
|
if self.create_tracks:
|
||||||
tracks = self.create_tracks(self.session_name)
|
# Use lobby_id as the session identifier for track creation
|
||||||
|
# This ensures config updates can find the tracks using the same key
|
||||||
|
tracks = self.create_tracks(self.lobby_id)
|
||||||
self.local_tracks.update(tracks)
|
self.local_tracks.update(tracks)
|
||||||
else:
|
else:
|
||||||
# Default fallback to synthetic tracks
|
# Default fallback to synthetic tracks
|
||||||
|
Loading…
x
Reference in New Issue
Block a user