581 lines
21 KiB
Python
581 lines
21 KiB
Python
"""
|
|
Minimal Bot - Reference Example
|
|
|
|
A minimal bot that consumes and generates audio/video with configurable visualizations and audio creation.
|
|
"""
|
|
|
|
import numpy as np
|
|
import cv2
|
|
import fractions
|
|
import time
|
|
from av.audio.frame import AudioFrame
|
|
from av import VideoFrame
|
|
from aiortc import MediaStreamTrack
|
|
from typing import Awaitable, Callable, Dict, Optional, Tuple, Any, Union
|
|
from shared.logger import logger
|
|
from shared.models import ChatMessageModel
|
|
|
|
|
|
# Global registry to store active tracks by session.
# In the code visible here it is written by create_minimal_bot_tracks (keyed by
# its session_name argument) and read by handle_config_update (looked up by its
# lobby_id argument); each value is a {"video": ..., "audio": ...} mapping.
# NOTE(review): nothing in this file removes entries — confirm cleanup happens
# elsewhere, and that callers pass the same identifier to both functions.
_active_tracks: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
|
class MediaClock:
    """Stopwatch giving several tracks one common time origin."""

    def __init__(self):
        # Origin is captured once, from the high-resolution performance counter.
        self.t0 = time.perf_counter()

    def now(self) -> float:
        """Return the seconds elapsed since this clock was created."""
        current = time.perf_counter()
        return current - self.t0
|
|
|
|
|
|
class ConfigurableVideoTrack(MediaStreamTrack):
    """Configurable video track with different visualization modes.

    Each call to ``recv`` renders one BGR ``uint8`` frame according to
    ``self.mode`` ('ball', 'waveform' or 'static'; anything else is drawn as
    'ball') and stamps it with a 90 kHz presentation timestamp.

    NOTE(review): ``recv`` does not wait between frames and ``self.clock`` is
    stored but never read in this class — presumably pacing is handled by the
    consumer; confirm against the caller.
    """

    kind = "video"

    def __init__(self, clock: MediaClock, config: Dict[str, Any]):
        """Initialize the configurable video track.

        Args:
            clock: Media clock for synchronization
            config: Configuration dictionary for video settings
        """
        super().__init__()
        self.clock = clock
        self.config = config
        self.width = config.get('width', 320)
        self.height = config.get('height', 240)
        self.fps = config.get('fps', 15)
        self.mode = config.get('visualization', 'ball')
        self.frame_count = 0
        self._start_time = time.time()

        # Timestamp anchor: pts of frame number `_pts_anchor_frame`. It is moved
        # whenever fps changes so that timestamps never run backwards.
        self._pts_anchor = 0
        self._pts_anchor_frame = 0

        # Initialize ball attributes (always, regardless of mode).
        # The initial radius is a fixed 20 px; later resets scale it to the frame.
        self.ball_x = self.width // 2
        self.ball_y = self.height // 2
        self.ball_dx = 2
        self.ball_dy = 2
        self.ball_radius = 20

    @staticmethod
    def _coerce_dimension(name: str, value: Any) -> int:
        """Return *value* as a positive int or raise ValueError/TypeError."""
        number = int(value)
        if number <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
        return number

    @staticmethod
    def _coerce_fps(value: Any) -> Union[int, float]:
        """Return *value* as a positive, finite frame rate or raise."""
        number = float(value)
        # The chained comparison also rejects NaN and infinity.
        if not (0 < number < float('inf')):
            raise ValueError(f"fps must be a positive number, got {value!r}")
        return int(number) if number.is_integer() else number

    def _pts_for_next_frame(self) -> int:
        """Return the 90 kHz timestamp of the frame about to be produced."""
        frames_since_anchor = self.frame_count - self._pts_anchor_frame
        return self._pts_anchor + int(frames_since_anchor * (90000 / self.fps))

    def update_config(self, config_updates: Dict[str, Any]) -> bool:
        """Update the video track configuration dynamically.

        Numeric settings are validated before anything is changed, so a
        rejected update leaves the track exactly as it was.

        Args:
            config_updates: Dictionary of configuration values to update

        Returns:
            bool: True if update was successful
        """
        try:
            # Work on a copy: the caller's dict is shared with the audio track.
            updates = dict(config_updates)
            for key in ('width', 'height'):
                if key in updates:
                    updates[key] = self._coerce_dimension(key, updates[key])
            if 'fps' in updates:
                updates['fps'] = self._coerce_fps(updates['fps'])

            old_mode = self.mode
            old_width = self.width
            old_height = self.height

            if 'fps' in updates and updates['fps'] != self.fps:
                # Re-anchor using the OLD rate so the next frame continues from
                # where the stream is, instead of being recomputed from zero.
                try:
                    self._pts_anchor = self._pts_for_next_frame()
                except (ZeroDivisionError, TypeError):
                    # The previous fps was unusable; keep the last good anchor.
                    pass
                self._pts_anchor_frame = self.frame_count

            # Update configuration
            self.config.update(updates)

            # Update instance variables
            self.width = updates.get('width', self.width)
            self.height = updates.get('height', self.height)
            self.fps = updates.get('fps', self.fps)
            self.mode = updates.get('visualization', self.mode)

            # Reinitialize mode-specific state if mode changed
            if self.mode != old_mode:
                self._initialize_mode_state()
                logger.info(f"Video mode changed from {old_mode} to {self.mode}")

            # Reinitialize ball state if dimensions changed (for any mode, since ball is default)
            if self.width != old_width or self.height != old_height:
                self._initialize_ball_state()

            logger.info(f"Video track configuration updated: {config_updates}")
            return True

        except Exception as e:
            logger.error(f"Error updating video track configuration: {e}")
            return False

    def _initialize_mode_state(self):
        """Initialize state specific to the current visualization mode."""
        # Always initialize ball state since it's used as default
        self._initialize_ball_state()

    def _initialize_ball_state(self):
        """Initialize bouncing ball state."""
        self.ball_x = self.width // 2
        self.ball_y = self.height // 2
        self.ball_dx = 2
        self.ball_dy = 2
        # cv2.circle only accepts an integer radius, so truncate and keep >= 1.
        self.ball_radius = max(1, int(min(self.width, self.height) * 0.06))

    async def next_timestamp(self) -> Tuple[int, float]:
        """Return ``(pts, time_base)`` for the next frame on a 90 kHz clock."""
        pts = self._pts_for_next_frame()
        time_base = 1 / 90000
        return pts, time_base

    async def recv(self) -> VideoFrame:
        """Render the next frame for the current mode and timestamp it."""
        pts, time_base = await self.next_timestamp()

        # Create frame based on mode
        if self.mode == 'waveform':
            frame_array = self._generate_waveform_frame()
        elif self.mode == 'static':
            frame_array = self._generate_static_frame()
        else:
            # 'ball' and any unknown mode
            frame_array = self._generate_ball_frame()

        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")  # type: ignore
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        self.frame_count += 1
        return frame

    def _generate_ball_frame(self) -> Any:
        """Generate bouncing ball visualization"""
        frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Update ball position
        self.ball_x += self.ball_dx
        self.ball_y += self.ball_dy

        # Bounce off walls
        if self.ball_x <= self.ball_radius or self.ball_x >= self.width - self.ball_radius:
            self.ball_dx = -self.ball_dx
        if self.ball_y <= self.ball_radius or self.ball_y >= self.height - self.ball_radius:
            self.ball_dy = -self.ball_dy

        # Draw ball (centre and radius must be plain ints for OpenCV)
        cv2.circle(
            frame,
            (int(self.ball_x), int(self.ball_y)),
            int(self.ball_radius),
            (0, 255, 0),
            -1,
        )

        # Add timestamp
        timestamp = f"Frame: {self.frame_count}"
        cv2.putText(frame, timestamp, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return frame

    def _generate_waveform_frame(self) -> Any:
        """Generate waveform visualization"""
        frame = np.zeros((self.height, self.width, 3), dtype=np.uint8)

        # Two sine periods across the frame; the phase advances with frame_count.
        x = np.linspace(0, 4 * np.pi, self.width)
        y = np.sin(x + self.frame_count * 0.1) * self.height // 4 + self.height // 2

        # Draw the whole curve in one call instead of one cv2.line per column.
        if len(y) > 1:
            points = np.column_stack((np.arange(len(y)), y)).astype(np.int32)
            cv2.polylines(frame, [points.reshape(-1, 1, 2)], False, (255, 255, 255), 2)

        return frame

    def _generate_static_frame(self) -> Any:
        """Generate static color frame"""
        color = self.config.get('static_color', (128, 128, 128))
        # static_color is given as (R, G, B) but recv labels the array "bgr24",
        # so reverse the channel order; clip so out-of-range values cannot
        # overflow uint8.
        rgb = np.clip(np.array(color, dtype=np.int64), 0, 255).astype(np.uint8)
        frame = np.full((self.height, self.width, 3), rgb[::-1], dtype=np.uint8)
        return frame
|
|
|
|
|
|
class ConfigurableAudioTrack(MediaStreamTrack):
    """Configurable audio track with different audio generation modes.

    Each call to ``recv`` produces ``samples_per_frame`` samples of 16-bit
    stereo audio (both channels identical) for the current mode ('tone',
    'noise' or 'silence'; anything else is generated as 'tone').

    NOTE(review): ``recv`` does not wait between frames and ``self.clock`` is
    stored but never read in this class — presumably pacing is handled by the
    consumer; confirm against the caller.
    """

    kind = "audio"

    def __init__(self, clock: MediaClock, config: Dict[str, Any]):
        """Initialize the configurable audio track.

        Args:
            clock: Media clock for synchronization
            config: Configuration dictionary for audio settings
        """
        super().__init__()
        self.clock = clock
        self.config = config
        self.sample_rate = config.get('sample_rate', 48000)
        self.samples_per_frame = config.get('samples_per_frame', 960)
        self.mode = config.get('audio_mode', 'tone')
        self.frequency = config.get('frequency', 440.0)
        self.volume = config.get('volume', 0.5)
        # Running sample counter; doubles as the pts of the next frame.
        self._samples_generated: int = 0

    @staticmethod
    def _coerce_positive_int(name: str, value: Any) -> int:
        """Return *value* as a positive int or raise ValueError/TypeError."""
        number = int(value)
        if number <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
        return number

    @staticmethod
    def _coerce_finite_float(name: str, value: Any) -> float:
        """Return *value* as a finite float or raise ValueError/TypeError."""
        number = float(value)
        if not np.isfinite(number):
            raise ValueError(f"{name} must be a finite number, got {value!r}")
        return number

    def update_config(self, config_updates: Dict[str, Any]) -> bool:
        """Update the audio track configuration dynamically.

        Numeric settings are validated before anything is changed, so a
        rejected update leaves the track exactly as it was.

        Args:
            config_updates: Dictionary of configuration values to update

        Returns:
            bool: True if update was successful
        """
        try:
            # Work on a copy so the caller's dict is never modified.
            updates = dict(config_updates)
            for key in ('sample_rate', 'samples_per_frame'):
                if key in updates:
                    updates[key] = self._coerce_positive_int(key, updates[key])
            for key in ('frequency', 'volume'):
                if key in updates:
                    updates[key] = self._coerce_finite_float(key, updates[key])

            # Update configuration
            self.config.update(updates)

            # Update instance variables
            self.sample_rate = updates.get('sample_rate', self.sample_rate)
            self.samples_per_frame = updates.get('samples_per_frame', self.samples_per_frame)
            self.mode = updates.get('audio_mode', self.mode)
            self.frequency = updates.get('frequency', self.frequency)
            self.volume = updates.get('volume', self.volume)

            logger.info(f"Audio track configuration updated: {config_updates}")
            return True

        except Exception as e:
            logger.error(f"Error updating audio track configuration: {e}")
            return False

    async def next_timestamp(self) -> Tuple[int, float]:
        """Return ``(pts, time_base)``: pts counts samples, time_base is 1/sample_rate."""
        pts = self._samples_generated
        time_base = 1 / self.sample_rate
        return pts, time_base

    async def recv(self) -> AudioFrame:
        """Generate the next audio frame for the current mode and timestamp it."""
        pts, time_base = await self.next_timestamp()

        # Generate audio based on mode
        if self.mode == 'noise':
            samples = self._generate_noise()
        elif self.mode == 'silence':
            samples = self._generate_silence()
        else:
            # 'tone' and any unknown mode
            samples = self._generate_tone()

        # Apply volume, then clamp to [-1, 1]: without the clamp a volume above
        # 1.0 would wrap around when cast to int16 instead of saturating.
        scaled = np.clip(samples * self.volume, -1.0, 1.0)
        mono = (scaled * 32767).astype(np.int16)

        # Duplicate every sample to get interleaved L/R pairs in one packed row.
        stereo = np.repeat(mono, 2).reshape(1, -1)  # type: ignore

        frame = AudioFrame.from_ndarray(stereo, format="s16", layout="stereo")  # type: ignore
        frame.sample_rate = self.sample_rate
        frame.pts = pts
        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)

        self._samples_generated += self.samples_per_frame
        return frame

    def _generate_tone(self) -> Any:
        """Generate sine wave tone"""
        # Offsetting by the running sample count keeps the phase continuous
        # from one frame to the next.
        t = (np.arange(self.samples_per_frame) + self._samples_generated) / self.sample_rate  # type: ignore
        return np.sin(2 * np.pi * self.frequency * t).astype(np.float32)  # type: ignore

    def _generate_noise(self) -> Any:
        """Generate white noise"""
        return np.random.uniform(-1, 1, self.samples_per_frame).astype(np.float32)  # type: ignore

    def _generate_silence(self) -> Any:
        """Generate silence"""
        return np.zeros(self.samples_per_frame, dtype=np.float32)  # type: ignore
|
|
|
|
|
|
def create_minimal_bot_tracks(session_name: str, config: Optional[Dict[str, Any]] = None) -> Dict[str, MediaStreamTrack]:
    """
    Build the video and audio tracks for one minimal-bot session.

    Caller-supplied options are layered over the built-in defaults, a textual
    ``static_color`` is converted to a tuple, and the resulting tracks are
    recorded in the module registry under ``session_name``.

    Args:
        session_name: Name for the session
        config: Configuration dictionary with options:
            - visualization: 'ball', 'waveform', 'static'
            - audio_mode: 'tone', 'noise', 'silence'
            - width: video width (default 320)
            - height: video height (default 240)
            - fps: frames per second (default 15)
            - sample_rate: audio sample rate (default 48000)
            - frequency: tone frequency in Hz (default 440)
            - volume: audio volume 0-1 (default 0.5)
            - static_color: RGB color tuple or string (default (128, 128, 128))

    Returns:
        Dictionary containing 'video' and 'audio' tracks
    """
    # Defaults first, caller overrides on top.
    settings: Dict[str, Any] = {
        'visualization': 'ball',
        'audio_mode': 'tone',
        'width': 320,
        'height': 240,
        'fps': 15,
        'sample_rate': 48000,
        'samples_per_frame': 960,
        'frequency': 440.0,
        'volume': 0.5,
        'static_color': (128, 128, 128),
        **(config or {}),
    }

    # A "r,g,b" string becomes an int tuple; anything unparsable reverts to grey.
    raw_color = settings.get('static_color')
    if isinstance(raw_color, str):
        try:
            red, green, blue = (int(part) for part in raw_color.split(','))
        except (ValueError, TypeError):
            logger.warning(f"Invalid static_color format: {raw_color}, using default")  # type: ignore
            settings['static_color'] = (128, 128, 128)
        else:
            settings['static_color'] = (red, green, blue)

    # Both tracks receive the same clock and the same settings dict.
    shared_clock = MediaClock()
    video = ConfigurableVideoTrack(shared_clock, settings)  # type: ignore
    audio = ConfigurableAudioTrack(shared_clock, settings)  # type: ignore

    logger.info(f"Created minimal bot tracks for {session_name} with config: {settings}")

    # Remember the tracks so later configuration updates can reach them.
    _active_tracks[session_name] = {"video": video, "audio": audio}

    return {"video": video, "audio": audio}
|
|
|
|
|
|
# Agent descriptor exported for dynamic discovery by the FastAPI service.
# AGENT_NAME is also reported as "bot_name" by get_config_schema() and, with
# AGENT_DESCRIPTION, forms the metadata returned by agent_info().
AGENT_NAME = "Minimal Configurable Bot"
AGENT_DESCRIPTION = "Minimal bot with configurable audio/video generation modes"
|
|
|
|
def agent_info() -> Dict[str, str]:
    """Describe this agent for discovery: name, description and capability flags."""
    # Flags are the strings "true", not booleans, matching the Dict[str, str] return type.
    metadata: Dict[str, str] = dict(
        name=AGENT_NAME,
        description=AGENT_DESCRIPTION,
        has_media="true",
        configurable="true",
    )
    return metadata
|
|
|
|
|
|
def get_config_schema() -> Dict[str, Any]:
    """Describe every configurable parameter of the Minimal Configurable Bot.

    The returned schema lists each parameter (type, label, default and
    constraints) plus a grouping into categories, so that frontend
    applications can dynamically generate configuration UIs.
    """

    def select_param(name: str, label: str, description: str, default: str,
                     choices: Any) -> Dict[str, Any]:
        # Required drop-down parameter; `choices` is a sequence of (value, label).
        return {
            "name": name,
            "type": "select",
            "label": label,
            "description": description,
            "default_value": default,
            "required": True,
            "options": [{"value": value, "label": text} for value, text in choices],
        }

    def numeric_param(name: str, label: str, description: str, default: Any,
                      low: Any, high: Any, step: Any, kind: str = "number") -> Dict[str, Any]:
        # Optional bounded numeric parameter ("number" input or "range" slider).
        return {
            "name": name,
            "type": kind,
            "label": label,
            "description": description,
            "default_value": default,
            "required": False,
            "min_value": low,
            "max_value": high,
            "step": step,
        }

    parameters = [
        select_param(
            "visualization",
            "Video Visualization Mode",
            "Choose the type of video visualization to display",
            "ball",
            (
                ("ball", "Bouncing Ball"),
                ("waveform", "Sine Wave Animation"),
                ("static", "Static Color Frame"),
            ),
        ),
        select_param(
            "audio_mode",
            "Audio Generation Mode",
            "Choose the type of audio to generate",
            "tone",
            (
                ("tone", "Sine Wave Tone"),
                ("noise", "White Noise"),
                ("silence", "Silence"),
            ),
        ),
        numeric_param("width", "Video Width",
                      "Width of the video frame in pixels", 320, 160, 1920, 1),
        numeric_param("height", "Video Height",
                      "Height of the video frame in pixels", 240, 120, 1080, 1),
        numeric_param("fps", "Frames Per Second",
                      "Video frame rate", 15, 1, 60, 1),
        numeric_param("sample_rate", "Audio Sample Rate",
                      "Audio sample rate in Hz", 48000, 8000, 96000, 1000),
        numeric_param("frequency", "Audio Frequency (Hz)",
                      "Frequency of the generated tone in Hz", 440.0, 20.0, 20000.0, 1.0),
        numeric_param("volume", "Audio Volume",
                      "Volume level (0.0 to 1.0)", 0.5, 0.0, 1.0, 0.1, kind="range"),
        {
            "name": "static_color",
            "type": "string",
            "label": "Static Color (RGB)",
            "description": "RGB color tuple for static mode (e.g., '128,128,128')",
            "default_value": "128,128,128",
            "required": False,
            "pattern": r"^\d{1,3},\d{1,3},\d{1,3}$",
            "max_length": 11,
        },
    ]

    # Each category is a one-entry mapping from its title to the member names.
    categories = [
        {"Video Settings": ["visualization", "width", "height", "fps"]},
        {"Audio Settings": ["audio_mode", "sample_rate", "frequency", "volume"]},
        {"Advanced": ["static_color"]},
    ]

    return {
        "bot_name": AGENT_NAME,
        "version": "1.0",
        "parameters": parameters,
        "categories": categories,
    }
|
|
|
|
|
|
def create_agent_tracks(session_name: str) -> Dict[str, MediaStreamTrack]:
    """Service-facing factory: build this agent's tracks with the default configuration."""
    tracks = create_minimal_bot_tracks(session_name)
    return tracks
|
|
|
|
|
|
|
|
async def handle_chat_message(
    chat_message: ChatMessageModel,
    send_message_func: Callable[[Union[str, ChatMessageModel]], Awaitable[None]]
) -> Optional[str]:
    """Log an incoming chat message; this bot never produces a reply.

    ``send_message_func`` is accepted but not called here.
    """
    reply: Optional[str] = None
    logger.info(f"Received chat message: {chat_message.message}")
    return reply
|
|
|
|
def _parse_static_color(color_value: Any) -> Tuple[int, int, int]:
    """Convert a 'r,g,b' string or 3-item sequence to an (r, g, b) int tuple.

    Raises:
        ValueError/TypeError: if there are not exactly three integer
            components or any component is outside 0-255.
    """
    parts = color_value.split(',') if isinstance(color_value, str) else list(color_value)
    if len(parts) != 3:
        raise ValueError("static_color needs exactly three components")
    red, green, blue = (int(part) for part in parts)
    if not all(0 <= channel <= 255 for channel in (red, green, blue)):
        raise ValueError("static_color components must be within 0-255")
    return (red, green, blue)


def handle_config_update(lobby_id: str, config_values: Dict[str, Any]) -> bool:
    """
    Handle runtime configuration updates for the minimal bot.

    Recognised keys are copied into a validated dict (``static_color`` is
    normalised to an int tuple), which is then handed to the ``update_config``
    method of the video and audio tracks registered under ``lobby_id``.

    Args:
        lobby_id: ID of the lobby/bot instance
        config_values: Dictionary of configuration values to update

    Returns:
        bool: True if every track that was updated accepted the new values,
        False if validation failed, no recognised key was supplied, no tracks
        are registered for ``lobby_id``, or any track rejected the update
    """
    try:
        # Validate configuration values
        validated_config: Dict[str, Any] = {}

        # Parse and validate static_color if provided
        if 'static_color' in config_values:
            color_value = config_values['static_color']
            if not isinstance(color_value, (str, tuple, list)):
                logger.warning(f"Invalid static_color type: {type(color_value)}, ignoring update")  # type: ignore
                return False
            try:
                validated_config['static_color'] = _parse_static_color(color_value)
            except (ValueError, TypeError):
                logger.warning(f"Invalid static_color format: {color_value}, ignoring update")
                return False

        # Copy other valid configuration values
        valid_keys = ('visualization', 'audio_mode', 'width', 'height', 'fps',
                      'sample_rate', 'frequency', 'volume')
        validated_config.update(
            {key: config_values[key] for key in valid_keys if key in config_values}
        )

        if not validated_config:
            logger.warning(f"No valid configuration values provided for {lobby_id}")
            return False

        logger.info(f"Configuration updated for {lobby_id}: {validated_config}")

        # Update running tracks if they exist in the registry
        tracks = _active_tracks.get(lobby_id)
        if tracks is None:
            logger.warning(f"No active tracks found for session {lobby_id}")
            return False

        # Apply to both tracks even if the first one fails, but report the
        # failure to the caller instead of always answering True.
        all_applied = True
        for track_kind in ("video", "audio"):
            track = tracks.get(track_kind)
            if track is None or not hasattr(track, 'update_config'):
                continue
            if track.update_config(validated_config):
                logger.info(f"{track_kind.capitalize()} track configuration updated for {lobby_id}")
            else:
                logger.warning(f"Failed to update {track_kind} track configuration for {lobby_id}")
                all_applied = False

        return all_applied

    except Exception as e:
        logger.error(f"Error updating configuration for {lobby_id}: {e}")
        return False