2025-09-17 15:05:16 -07:00

581 lines
21 KiB
Python

"""
Minimal Bot - Reference Example
A minimal bot that consumes and generates audio/video with configurable visualizations and audio creation.
"""
import asyncio
import fractions
import time
from typing import Awaitable, Callable, Dict, Optional, Tuple, Any, Union

import cv2
import numpy as np
from aiortc import MediaStreamTrack
from av import VideoFrame
from av.audio.frame import AudioFrame

from shared.logger import logger
from shared.models import ChatMessageModel
# Global registry to store active tracks by session.
# Each value maps "video" / "audio" to the track object created for that
# session, so a later configuration update can reach tracks that are already
# streaming.
# NOTE(review): entries are written under `session_name` but looked up with
# `lobby_id` in handle_config_update -- confirm both are the same identifier.
# Nothing in this part of the file removes entries; verify cleanup elsewhere.
_active_tracks: Dict[str, Dict[str, Any]] = {}
class MediaClock:
    """Stopwatch that media tracks share as a common time origin."""

    def __init__(self):
        # perf_counter() is monotonic, so elapsed readings never go backwards.
        self.t0 = time.perf_counter()

    def now(self) -> float:
        """Return the number of seconds elapsed since the clock was created."""
        elapsed = time.perf_counter() - self.t0
        return elapsed
class ConfigurableVideoTrack(MediaStreamTrack):
    """Synthetic video track whose picture is selected by configuration.

    Supported ``visualization`` modes are ``'ball'`` (bouncing ball),
    ``'waveform'`` (animated sine wave) and ``'static'`` (solid colour); any
    other value falls back to the bouncing ball.

    Frames are BGR arrays wrapped in ``VideoFrame`` objects and stamped on a
    90 kHz clock. ``recv`` waits on the shared ``MediaClock`` until each frame
    is due, so a consumer that calls it in a tight loop receives frames at
    ``fps`` instead of as fast as they can be rendered.
    """

    kind = "video"

    # Ticks per second of the presentation-timestamp clock.
    CLOCK_RATE = 90000

    def __init__(self, clock: MediaClock, config: Dict[str, Any]):
        """Initialize the configurable video track.

        Args:
            clock: Media clock for synchronization
            config: Configuration dictionary for video settings
        """
        super().__init__()
        self.clock = clock
        self.config = config
        self.width = config.get('width', 320)
        self.height = config.get('height', 240)
        self.fps = config.get('fps', 15)
        self.mode = config.get('visualization', 'ball')
        self.frame_count = 0
        self._start_time = time.time()

        # Presentation time of the next frame, in CLOCK_RATE ticks. It is
        # accumulated per frame (rather than derived from frame_count) so the
        # timestamps stay monotonic when fps is changed mid-stream.
        self._pts_ticks = 0.0
        # Clock reading that corresponds to media time zero; fixed on the
        # first recv() so a late-starting consumer does not get a burst of
        # catch-up frames.
        self._stream_epoch: Optional[float] = None

        # Initialize ball attributes (always, regardless of mode)
        self.ball_x = self.width // 2
        self.ball_y = self.height // 2
        self.ball_dx = 2
        self.ball_dy = 2
        self.ball_radius = 20

    @staticmethod
    def _positive_int(name: str, value: Any) -> int:
        """Return ``value`` as an int, raising ValueError unless it is a positive whole number."""
        if isinstance(value, (bool, str, bytes)):
            raise ValueError(f"{name} must be a positive whole number, got {value!r}")
        number = int(value)
        if number != value or number <= 0:
            raise ValueError(f"{name} must be a positive whole number, got {value!r}")
        return number

    @staticmethod
    def _positive_rate(name: str, value: Any) -> Any:
        """Return ``value`` unchanged, raising ValueError unless it is a finite number > 0."""
        if isinstance(value, (bool, str, bytes)):
            raise ValueError(f"{name} must be a positive number, got {value!r}")
        if not 0 < float(value) < float('inf'):
            raise ValueError(f"{name} must be a positive number, got {value!r}")
        return value

    def update_config(self, config_updates: Dict[str, Any]) -> bool:
        """Update the video track configuration dynamically.

        All values are validated before any state is touched, so a rejected
        update leaves the track exactly as it was.

        Args:
            config_updates: Dictionary of configuration values to update

        Returns:
            bool: True if update was successful
        """
        try:
            new_width = self.width
            new_height = self.height
            new_fps = self.fps
            if 'width' in config_updates:
                new_width = self._positive_int('width', config_updates['width'])
            if 'height' in config_updates:
                new_height = self._positive_int('height', config_updates['height'])
            if 'fps' in config_updates:
                # fps of zero would divide by zero when timestamps advance.
                new_fps = self._positive_rate('fps', config_updates['fps'])

            old_mode = self.mode
            old_width = self.width
            old_height = self.height

            # Update configuration, keeping the normalized sizes in the dict.
            self.config.update(config_updates)
            if 'width' in config_updates:
                self.config['width'] = new_width
            if 'height' in config_updates:
                self.config['height'] = new_height

            # Update instance variables
            self.width = new_width
            self.height = new_height
            self.fps = new_fps
            if 'visualization' in config_updates:
                self.mode = config_updates['visualization']

            # Reinitialize mode-specific state if mode changed
            if self.mode != old_mode:
                self._initialize_mode_state()
                logger.info(f"Video mode changed from {old_mode} to {self.mode}")

            # Reinitialize ball state if dimensions changed (for any mode, since ball is default)
            if self.width != old_width or self.height != old_height:
                self._initialize_ball_state()

            logger.info(f"Video track configuration updated: {config_updates}")
            return True
        except Exception as e:
            logger.error(f"Error updating video track configuration: {e}")
            return False

    def _initialize_mode_state(self):
        """Initialize state specific to the current visualization mode."""
        # Always initialize ball state since it's used as default
        self._initialize_ball_state()

    def _initialize_ball_state(self):
        """Re-centre the ball and size it relative to the frame."""
        self.ball_x = self.width // 2
        self.ball_y = self.height // 2
        self.ball_dx = 2
        self.ball_dy = 2
        # cv2.circle only accepts an integer radius; keep at least one pixel.
        self.ball_radius = max(1, int(min(self.width, self.height) * 0.06))

    async def next_timestamp(self) -> Tuple[int, float]:
        """Return ``(pts, time_base)`` for the frame about to be produced."""
        pts = int(self._pts_ticks)
        time_base = 1 / self.CLOCK_RATE
        return pts, time_base

    async def _wait_until(self, media_time: float) -> None:
        """Sleep until ``media_time`` seconds of stream time are due on the shared clock."""
        now = self.clock.now()
        if self._stream_epoch is None:
            self._stream_epoch = now - media_time
        delay = self._stream_epoch + media_time - now
        # A consumer that is already running late is never delayed further.
        if delay > 0:
            await asyncio.sleep(delay)

    async def recv(self) -> VideoFrame:
        """Produce the next frame for the current mode once it is due."""
        pts, time_base = await self.next_timestamp()
        await self._wait_until(pts * time_base)

        # Create frame based on mode
        if self.mode == 'ball':
            frame_array = self._generate_ball_frame()
        elif self.mode == 'waveform':
            frame_array = self._generate_waveform_frame()
        elif self.mode == 'static':
            frame_array = self._generate_static_frame()
        else:
            # Default to ball mode if unknown mode
            frame_array = self._generate_ball_frame()

        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")  # type: ignore
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        self.frame_count += 1
        # Advance by one frame interval at the *current* fps.
        self._pts_ticks += self.CLOCK_RATE / self.fps
        return frame

    def _generate_ball_frame(self) -> Any:
        """Generate bouncing ball visualization"""
        frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Update ball position
        self.ball_x += self.ball_dx
        self.ball_y += self.ball_dy

        # Bounce off walls
        if self.ball_x <= self.ball_radius or self.ball_x >= self.width - self.ball_radius:
            self.ball_dx = -self.ball_dx
        if self.ball_y <= self.ball_radius or self.ball_y >= self.height - self.ball_radius:
            self.ball_dy = -self.ball_dy

        # Draw ball (centre and radius must be integers for OpenCV)
        centre = (int(self.ball_x), int(self.ball_y))
        cv2.circle(frame, centre, int(self.ball_radius), (0, 255, 0), -1)

        # Add timestamp
        timestamp = f"Frame: {self.frame_count}"
        cv2.putText(frame, timestamp, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        return frame

    def _generate_waveform_frame(self) -> Any:
        """Generate waveform visualization"""
        frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Two full sine periods across the frame; the phase advances each frame.
        x = np.linspace(0, 4 * np.pi, self.width)
        y = np.sin(x + self.frame_count * 0.1) * self.height // 4 + self.height // 2

        # Draw waveform as connected one-pixel-wide segments
        for i in range(1, len(y)):
            cv2.line(frame, (i - 1, int(y[i - 1])), (i, int(y[i])), (255, 255, 255), 2)
        return frame

    def _generate_static_frame(self) -> Any:
        """Generate static color frame from the RGB ``static_color`` setting."""
        color = self.config.get('static_color', (128, 128, 128))
        red, green, blue = (max(0, min(255, int(channel))) for channel in color)
        # The frame is emitted as bgr24, so the RGB setting must be reversed.
        return np.full((self.height, self.width, 3), (blue, green, red), dtype=np.uint8)
class ConfigurableAudioTrack(MediaStreamTrack):
    """Synthetic audio track whose signal is selected by configuration.

    Supported ``audio_mode`` values are ``'tone'`` (sine wave), ``'noise'``
    (uniform white noise) and ``'silence'``; any other value falls back to
    the tone. Each frame is 16-bit stereo with the same signal on both
    channels.

    ``recv`` waits on the shared ``MediaClock`` until each frame is due, so a
    consumer that calls it in a tight loop receives audio in real time.
    """

    kind = "audio"

    def __init__(self, clock: MediaClock, config: Dict[str, Any]):
        """Initialize the configurable audio track.

        Args:
            clock: Media clock for synchronization
            config: Configuration dictionary for audio settings
        """
        super().__init__()
        self.clock = clock
        self.config = config
        self.sample_rate = config.get('sample_rate', 48000)
        self.samples_per_frame = config.get('samples_per_frame', 960)
        self.mode = config.get('audio_mode', 'tone')
        self.frequency = config.get('frequency', 440.0)
        self.volume = config.get('volume', 0.5)
        self._samples_generated: int = 0

        # Seconds of audio produced so far. Tracked separately from the
        # sample count so pacing stays continuous if sample_rate changes.
        self._media_time = 0.0
        # Clock reading that corresponds to media time zero; fixed on the
        # first recv().
        self._stream_epoch: Optional[float] = None

    @staticmethod
    def _positive_int(name: str, value: Any) -> int:
        """Return ``value`` as an int, raising ValueError unless it is a positive whole number."""
        if isinstance(value, (bool, str, bytes)):
            raise ValueError(f"{name} must be a positive whole number, got {value!r}")
        number = int(value)
        if number != value or number <= 0:
            raise ValueError(f"{name} must be a positive whole number, got {value!r}")
        return number

    @staticmethod
    def _finite_number(name: str, value: Any) -> Any:
        """Return ``value`` unchanged, raising ValueError unless it is a finite number."""
        if isinstance(value, (bool, str, bytes)):
            raise ValueError(f"{name} must be a number, got {value!r}")
        as_float = float(value)
        if as_float != as_float or as_float in (float('inf'), float('-inf')):
            raise ValueError(f"{name} must be a finite number, got {value!r}")
        return value

    def update_config(self, config_updates: Dict[str, Any]) -> bool:
        """Update the audio track configuration dynamically.

        All values are validated before any state is touched, so a rejected
        update leaves the track exactly as it was.

        Args:
            config_updates: Dictionary of configuration values to update

        Returns:
            bool: True if update was successful
        """
        try:
            changes: Dict[str, Any] = {}
            if 'sample_rate' in config_updates:
                # Zero would divide by zero when timestamps are computed.
                changes['sample_rate'] = self._positive_int('sample_rate', config_updates['sample_rate'])
            if 'samples_per_frame' in config_updates:
                changes['samples_per_frame'] = self._positive_int(
                    'samples_per_frame', config_updates['samples_per_frame'])
            if 'frequency' in config_updates:
                changes['frequency'] = self._finite_number('frequency', config_updates['frequency'])
            if 'volume' in config_updates:
                changes['volume'] = self._finite_number('volume', config_updates['volume'])

            # Update configuration, keeping the normalized values in the dict.
            self.config.update(config_updates)
            self.config.update(changes)

            # Update instance variables
            self.sample_rate = changes.get('sample_rate', self.sample_rate)
            self.samples_per_frame = changes.get('samples_per_frame', self.samples_per_frame)
            self.frequency = changes.get('frequency', self.frequency)
            self.volume = changes.get('volume', self.volume)
            if 'audio_mode' in config_updates:
                self.mode = config_updates['audio_mode']

            logger.info(f"Audio track configuration updated: {config_updates}")
            return True
        except Exception as e:
            logger.error(f"Error updating audio track configuration: {e}")
            return False

    async def next_timestamp(self) -> Tuple[int, float]:
        """Return ``(pts, time_base)``; pts counts samples already produced."""
        pts = self._samples_generated
        time_base = 1 / self.sample_rate
        return pts, time_base

    async def _wait_until(self, media_time: float) -> None:
        """Sleep until ``media_time`` seconds of stream time are due on the shared clock."""
        now = self.clock.now()
        if self._stream_epoch is None:
            self._stream_epoch = now - media_time
        delay = self._stream_epoch + media_time - now
        # A consumer that is already running late is never delayed further.
        if delay > 0:
            await asyncio.sleep(delay)

    async def recv(self) -> AudioFrame:
        """Produce the next stereo frame for the current mode once it is due."""
        pts, time_base = await self.next_timestamp()
        await self._wait_until(self._media_time)

        # Generate audio based on mode
        if self.mode == 'tone':
            samples = self._generate_tone()
        elif self.mode == 'noise':
            samples = self._generate_noise()
        elif self.mode == 'silence':
            samples = self._generate_silence()
        else:
            samples = self._generate_tone()  # default

        # Scale to 16-bit PCM. Clipping first means a volume above 1.0
        # saturates instead of wrapping around in the int16 cast.
        scaled = np.clip(samples * self.volume * 32767, -32768, 32767)
        mono = scaled.astype(np.int16)

        # Duplicating every sample yields interleaved L/R with identical channels.
        stereo = np.repeat(mono, 2).reshape(1, -1)  # type: ignore

        frame = AudioFrame.from_ndarray(stereo, format="s16", layout="stereo")  # type: ignore
        frame.sample_rate = self.sample_rate
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        self._samples_generated += self.samples_per_frame
        self._media_time += self.samples_per_frame / self.sample_rate
        return frame

    def _generate_tone(self) -> Any:
        """Generate sine wave tone"""
        # Offsetting by the running sample count keeps the phase continuous
        # from one frame to the next.
        t = (np.arange(self.samples_per_frame) + self._samples_generated) / self.sample_rate  # type: ignore
        return np.sin(2 * np.pi * self.frequency * t).astype(np.float32)  # type: ignore

    def _generate_noise(self) -> Any:
        """Generate white noise"""
        return np.random.uniform(-1, 1, self.samples_per_frame).astype(np.float32)  # type: ignore

    def _generate_silence(self) -> Any:
        """Generate silence"""
        return np.zeros(self.samples_per_frame, dtype=np.float32)  # type: ignore
def _parse_static_color(value: Any) -> Tuple[int, int, int]:
    """Turn an ``'r,g,b'`` string or a 3-item sequence into an ``(r, g, b)`` tuple.

    Raises:
        ValueError: if there are not exactly three components or one of them
            is outside 0-255.
        TypeError: if ``value`` is not a string or an iterable of numbers.
    """
    parts = value.split(',') if isinstance(value, str) else list(value)
    if len(parts) != 3:
        raise ValueError(f"expected 3 colour components, got {len(parts)}")
    red, green, blue = (int(part) for part in parts)
    if not all(0 <= channel <= 255 for channel in (red, green, blue)):
        raise ValueError("colour components must be within 0-255")
    return red, green, blue


def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, MediaStreamTrack]:
    """
    Create minimal bot tracks with configurable audio/video generation.

    The caller's ``config`` is not modified; it is overlaid on a fresh dict of
    defaults, and that merged dict is shared by both tracks.

    Args:
        session_name: Name for the session
        config: Configuration dictionary with options:
            - visualization: 'ball', 'waveform', 'static'
            - audio_mode: 'tone', 'noise', 'silence'
            - width: video width (default 320)
            - height: video height (default 240)
            - fps: frames per second (default 15)
            - sample_rate: audio sample rate (default 48000)
            - frequency: tone frequency in Hz (default 440)
            - volume: audio volume 0-1 (default 0.5)
            - static_color: RGB color tuple or string (default (128, 128, 128))

    Returns:
        Dictionary containing 'video' and 'audio' tracks
    """
    # Set defaults, then overlay whatever the caller supplied.
    default_config: Dict[str, Any] = {
        'visualization': 'ball',
        'audio_mode': 'tone',
        'width': 320,
        'height': 240,
        'fps': 15,
        'sample_rate': 48000,
        'samples_per_frame': 960,
        'frequency': 440.0,
        'volume': 0.5,
        'static_color': (128, 128, 128),
    }
    if config:
        default_config.update(config)

    # Normalize static_color whatever form it arrived in. A malformed or
    # out-of-range value would otherwise only fail once frames are rendered.
    try:
        default_config['static_color'] = _parse_static_color(default_config['static_color'])
    except (ValueError, TypeError):
        logger.warning(f"Invalid static_color format: {default_config.get('static_color')}, using default")
        default_config['static_color'] = (128, 128, 128)

    media_clock = MediaClock()
    video_track = ConfigurableVideoTrack(media_clock, default_config)
    audio_track = ConfigurableAudioTrack(media_clock, default_config)
    logger.info(f"Created minimal bot tracks for {session_name} with config: {default_config}")

    # Store tracks in global registry for dynamic updates
    _active_tracks[session_name] = {"video": video_track, "audio": audio_track}
    return {"video": video_track, "audio": audio_track}
# Agent descriptor exported for dynamic discovery by the FastAPI service.
# AGENT_NAME is reported by agent_info() and as "bot_name" in
# get_config_schema(); AGENT_DESCRIPTION is reported by agent_info().
AGENT_NAME = "Minimal Configurable Bot"
AGENT_DESCRIPTION = "Minimal bot with configurable audio/video generation modes"
def agent_info() -> Dict[str, str]:
    """Describe this agent for discovery.

    Returns:
        A mapping with the agent's name and description plus the string flags
        ``has_media`` and ``configurable``, both set to ``"true"``.
    """
    info: Dict[str, str] = {}
    info["name"] = AGENT_NAME
    info["description"] = AGENT_DESCRIPTION
    # Flags are strings, not booleans, to match the declared Dict[str, str].
    info["has_media"] = "true"
    info["configurable"] = "true"
    return info
def get_config_schema() -> Dict[str, Any]:
    """Describe every configurable parameter of the Minimal Configurable Bot.

    The returned schema lists each parameter (type, label, default, limits)
    and groups parameter names into display categories, so a frontend can
    build a configuration form without hard-coding the fields.
    """

    def select_param(name, label, description, default, choices):
        # Required drop-down; `choices` is a sequence of (value, label) pairs.
        return {
            "name": name,
            "type": "select",
            "label": label,
            "description": description,
            "default_value": default,
            "required": True,
            "options": [{"value": value, "label": text} for value, text in choices],
        }

    def numeric_param(name, label, description, default, bounds, step, kind="number"):
        # Optional bounded numeric field; `bounds` is (min_value, max_value).
        lowest, highest = bounds
        return {
            "name": name,
            "type": kind,
            "label": label,
            "description": description,
            "default_value": default,
            "required": False,
            "min_value": lowest,
            "max_value": highest,
            "step": step,
        }

    parameters = [
        select_param(
            "visualization",
            "Video Visualization Mode",
            "Choose the type of video visualization to display",
            "ball",
            [
                ("ball", "Bouncing Ball"),
                ("waveform", "Sine Wave Animation"),
                ("static", "Static Color Frame"),
            ],
        ),
        select_param(
            "audio_mode",
            "Audio Generation Mode",
            "Choose the type of audio to generate",
            "tone",
            [
                ("tone", "Sine Wave Tone"),
                ("noise", "White Noise"),
                ("silence", "Silence"),
            ],
        ),
        numeric_param("width", "Video Width",
                      "Width of the video frame in pixels", 320, (160, 1920), 1),
        numeric_param("height", "Video Height",
                      "Height of the video frame in pixels", 240, (120, 1080), 1),
        numeric_param("fps", "Frames Per Second",
                      "Video frame rate", 15, (1, 60), 1),
        numeric_param("sample_rate", "Audio Sample Rate",
                      "Audio sample rate in Hz", 48000, (8000, 96000), 1000),
        numeric_param("frequency", "Audio Frequency (Hz)",
                      "Frequency of the generated tone in Hz", 440.0, (20.0, 20000.0), 1.0),
        numeric_param("volume", "Audio Volume",
                      "Volume level (0.0 to 1.0)", 0.5, (0.0, 1.0), 0.1, kind="range"),
        {
            "name": "static_color",
            "type": "string",
            "label": "Static Color (RGB)",
            "description": "RGB color tuple for static mode (e.g., '128,128,128')",
            "default_value": "128,128,128",
            "required": False,
            "pattern": r"^\d{1,3},\d{1,3},\d{1,3}$",
            "max_length": 11,
        },
    ]

    # Each category is a single-entry mapping of its title to parameter names.
    categories = [
        {"Video Settings": ["visualization", "width", "height", "fps"]},
        {"Audio Settings": ["audio_mode", "sample_rate", "frequency", "volume"]},
        {"Advanced": ["static_color"]},
    ]

    return {
        "bot_name": AGENT_NAME,
        "version": "1.0",
        "parameters": parameters,
        "categories": categories,
    }
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
    """Build this agent's media tracks for the FastAPI service.

    Delegates to create_minimal_bot_tracks, passing only the session name.

    Args:
        session_name: Name of the session the tracks belong to

    Returns:
        Whatever create_minimal_bot_tracks returns for that session
    """
    tracks = create_minimal_bot_tracks(session_name)
    return tracks
async def handle_chat_message(
    chat_message: ChatMessageModel,
    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
) -> Optional[str]:
    """Log an incoming chat message without replying.

    Args:
        chat_message: The message that was received
        send_message_func: Callback for sending a message; not used here

    Returns:
        Always None.
    """
    log_line = f"Received chat message: {chat_message.message}"
    logger.info(log_line)
    return None
def _normalize_config_value(key: str, value: Any) -> Any:
    """Check one non-colour setting and return the value to hand to the tracks.

    Mode names are passed through untouched because the tracks fall back to a
    default for unknown modes. Numeric settings are checked for the conditions
    that would otherwise break frame generation later.

    Raises:
        ValueError: if the value is not usable for ``key``.
    """
    if key in ('visualization', 'audio_mode'):
        return value
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError("expected a number")
    if value != value or value in (float('inf'), float('-inf')):
        raise ValueError("expected a finite number")
    if key in ('width', 'height', 'sample_rate'):
        # Used as array dimensions / sample rates, so they must be integers.
        if value <= 0 or value != int(value):
            raise ValueError("expected a positive whole number")
        return int(value)
    if key == 'fps' and value <= 0:
        # The video track divides by fps to space its timestamps.
        raise ValueError("expected a value greater than zero")
    if key == 'frequency' and value < 0:
        raise ValueError("expected a non-negative value")
    if key == 'volume' and not 0 <= value <= 1:
        raise ValueError("expected a value between 0 and 1")
    return value


def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
    """
    Handle runtime configuration updates for the minimal bot.

    Unknown keys are ignored. If any recognised value is invalid, nothing is
    applied.

    Args:
        lobby_id: ID of the lobby/bot instance
        config_values: Dictionary of configuration values to update

    Returns:
        bool: True only if the session has registered tracks and every one of
        them accepted the update; False otherwise
    """
    try:
        validated_config: Dict[str, Any] = {}

        # Parse and validate static_color if provided
        if 'static_color' in config_values:
            color_value = config_values['static_color']
            if isinstance(color_value, str):
                parts = color_value.split(',')
            elif isinstance(color_value, (tuple, list)) and len(color_value) == 3:
                parts = list(color_value)
            else:
                logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update")
                return False
            try:
                # Unpacking also enforces exactly three components.
                red, green, blue = (int(part) for part in parts)
            except (ValueError, TypeError):
                logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
                return False
            if not all(0 <= channel <= 255 for channel in (red, green, blue)):
                logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
                return False
            validated_config['static_color'] = (red, green, blue)

        # Copy other valid configuration values
        valid_keys = ('visualization', 'audio_mode', 'width', 'height', 'fps',
                      'sample_rate', 'frequency', 'volume')
        for key in valid_keys:
            if key not in config_values:
                continue
            try:
                validated_config[key] = _normalize_config_value(key, config_values[key])
            except ValueError as problem:
                logger.warning(
                    f"Invalid value for {key}: {config_values[key]!r} ({problem}), ignoring update")
                return False

        if not validated_config:
            logger.warning(f"No valid configuration values provided for {lobby_id}")
            return False

        # Update running tracks if they exist in the registry
        tracks = _active_tracks.get(lobby_id)
        if tracks is None:
            logger.warning(f"No active tracks found for session {lobby_id}")
            return False

        all_applied = True
        for track_key, label in (("video", "Video"), ("audio", "Audio")):
            track = tracks.get(track_key)
            if not track or not hasattr(track, 'update_config'):
                continue
            if track.update_config(validated_config):
                logger.info(f"{label} track configuration updated for {lobby_id}")
            else:
                logger.warning(f"Failed to update {label.lower()} track configuration for {lobby_id}")
                all_applied = False

        # Only report the update as done once the tracks have accepted it.
        if all_applied:
            logger.info(f"Configuration updated for {lobby_id}: {validated_config}")
        return all_applied
    except Exception as e:
        logger.error(f"Error updating configuration for {lobby_id}: {e}")
        return False