Refactoring synthetic audio/video to its own file
commit 7c5616fbd9 (parent 2910789c86)

voicebot/main.py (187 lines changed)
@@ -1,8 +1,8 @@
 """
 WebRTC Media Agent for Python
 
-This module provides synthetic audio/video track creation and WebRTC signaling
-server communication, ported from the JavaScript MediaControl implementation.
+This module provides WebRTC signaling server communication and peer connection management.
+Synthetic audio/video track creation is handled by the synthetic_media module.
 """
 
 from __future__ import annotations
@@ -10,9 +10,6 @@ from __future__ import annotations
 import asyncio
 import json
 import websockets
-import numpy as np
-import cv2
-import fractions
 from typing import (
     Dict,
     Optional,
@@ -57,12 +54,12 @@ from aiortc import (
     RTCIceCandidate,
     MediaStreamTrack,
 )
-from av import VideoFrame, AudioFrame
-import time
 from logger import logger
+from synthetic_media import create_synthetic_tracks
 
 # import debug_aioice
 
 
 # TypedDict for ICE candidate payloads received from signalling
 class ICECandidateDict(TypedDict, total=False):
     candidate: str
@@ -108,7 +105,6 @@ class IceCandidatePayload(TypedDict):
     candidate: ICECandidateDict
 
 
-
 class WebSocketProtocol(Protocol):
     def send(self, message: object, text: Optional[bool] = None) -> Awaitable[None]: ...
     def close(self, code: int = 1000, reason: str = "") -> Awaitable[None]: ...
@@ -134,171 +130,6 @@ class Peer:
     connection: Optional[RTCPeerConnection] = None
 
 
-class AnimatedVideoTrack(MediaStreamTrack):
-    async def next_timestamp(self):
-        # Returns (pts, time_base) for 15 FPS video
-        pts = int(self.frame_count * (1 / 15) * 90000)
-        time_base = 1 / 90000
-        return pts, time_base
-
-    """
-    Synthetic video track that generates animated content with a bouncing ball.
-    Ported from JavaScript createAnimatedVideoTrack function.
-    """
-
-    kind = "video"
-
-    def __init__(self, width: int = 320, height: int = 240, name: str = ""):
-        super().__init__()
-        self.width = width
-        self.height = height
-        self.name = name
-
-        # Generate color from name hash (similar to JavaScript nameToColor)
-        self.ball_color = (
-            self._name_to_color(name) if name else (0, 255, 136)
-        )  # Default green
-
-        # Ball properties
-        self.ball = {
-            "x": width / 2,
-            "y": height / 2,
-            "radius": min(width, height) * 0.06,
-            "dx": 3.0,
-            "dy": 2.0,
-        }
-
-        self.frame_count = 0
-        self._start_time = time.time()
-
-    def _name_to_color(self, name: str) -> tuple[int, int, int]:
-        """Convert name to HSL color, then to RGB tuple"""
-        # Simple hash function (djb2)
-        hash_value = 5381
-        for char in name:
-            hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF
-
-        # Generate HSL color from hash
-        hue = abs(hash_value) % 360
-        sat = 60 + (abs(hash_value) % 30)  # 60-89%
-        light = 45 + (abs(hash_value) % 30)  # 45-74%
-
-        # Convert HSL to RGB
-        h = hue / 360.0
-        s = sat / 100.0
-        lightness = light / 100.0
-
-        c = (1 - abs(2 * lightness - 1)) * s
-        x = c * (1 - abs((h * 6) % 2 - 1))
-        m = lightness - c / 2
-
-        if h < 1 / 6:
-            r, g, b = c, x, 0
-        elif h < 2 / 6:
-            r, g, b = x, c, 0
-        elif h < 3 / 6:
-            r, g, b = 0, c, x
-        elif h < 4 / 6:
-            r, g, b = 0, x, c
-        elif h < 5 / 6:
-            r, g, b = x, 0, c
-        else:
-            r, g, b = c, 0, x
-
-        return (
-            int((b + m) * 255),
-            int((g + m) * 255),
-            int((r + m) * 255),
-        )  # BGR for OpenCV
-
-    async def recv(self):
-        """Generate video frames at 15 FPS"""
-        pts, time_base = await self.next_timestamp()
-
-        # Create black background
-        frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8)
-
-        # Update ball position
-        ball = self.ball
-        ball["x"] += ball["dx"]
-        ball["y"] += ball["dy"]
-
-        # Bounce off walls
-        if ball["x"] + ball["radius"] >= self.width or ball["x"] - ball["radius"] <= 0:
-            ball["dx"] = -ball["dx"]
-        if ball["y"] + ball["radius"] >= self.height or ball["y"] - ball["radius"] <= 0:
-            ball["dy"] = -ball["dy"]
-
-        # Keep ball in bounds
-        ball["x"] = max(ball["radius"], min(self.width - ball["radius"], ball["x"]))
-        ball["y"] = max(ball["radius"], min(self.height - ball["radius"], ball["y"]))
-
-        # Draw ball
-        cv2.circle(
-            frame_array,
-            (int(ball["x"]), int(ball["y"])),
-            int(ball["radius"]),
-            self.ball_color,
-            -1,
-        )
-
-        # Add frame counter text
-        frame_text = f"Frame: {int(time.time() * 1000) % 10000}"
-        # logger.info(frame_text)
-        cv2.putText(
-            frame_array,
-            frame_text,
-            (10, 20),
-            cv2.FONT_HERSHEY_SIMPLEX,
-            0.5,
-            (255, 255, 255),
-            1,
-        )
-
-        # Convert to VideoFrame
-        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
-        frame.pts = pts
-        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
-
-        self.frame_count += 1
-        return frame
-
-
-class SilentAudioTrack(MediaStreamTrack):
-    async def next_timestamp(self):
-        # Returns (pts, time_base) for 20ms audio frames at 48kHz
-        pts = int(time.time() * self.sample_rate)
-        time_base = 1 / self.sample_rate
-        return pts, time_base
-
-    """
-    Synthetic audio track that generates silence.
-    Ported from JavaScript createSilentAudioTrack function.
-    """
-
-    kind = "audio"
-
-    def __init__(self):
-        super().__init__()
-        self.sample_rate = 48000
-        self.samples_per_frame = 960  # 20ms at 48kHz
-
-    async def recv(self):
-        """Generate silent audio frames"""
-        pts, time_base = await self.next_timestamp()
-
-        # Create silent audio data in s16 format (required by Opus encoder)
-        samples = np.zeros((self.samples_per_frame,), dtype=np.int16)
-
-        # Convert to AudioFrame
-        frame = AudioFrame.from_ndarray(
-            samples.reshape(1, -1), format="s16", layout="mono"
-        )
-        frame.sample_rate = self.sample_rate
-        frame.pts = pts
-        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
-
-        return frame
-
-
 class WebRTCSignalingClient:
     """
     WebRTC signaling client that communicates with the FastAPI signaling server.
@@ -417,13 +248,9 @@ class WebRTCSignalingClient:
 
     async def _setup_local_media(self):
         """Create local synthetic media tracks"""
-        # Create synthetic video track
-        video_track = AnimatedVideoTrack(name=self.session_name)
-        self.local_tracks["video"] = video_track
-
-        # Create synthetic audio track
-        audio_track = SilentAudioTrack()
-        self.local_tracks["audio"] = audio_track
+        # Create synthetic tracks using the new module
+        tracks = create_synthetic_tracks(self.session_name)
+        self.local_tracks.update(tracks)
 
         # Add local peer to peers dict
        local_peer = Peer(
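The tracks returned by create_synthetic_tracks are ordinary aiortc MediaStreamTrack objects, so the refactored call site can hand them to a peer connection unchanged. A minimal sketch of that wiring, assuming an aiortc RTCPeerConnection (the helper name attach_synthetic_tracks is illustrative, not part of this commit):

    # Illustrative sketch, not part of this commit.
    from aiortc import RTCPeerConnection
    from synthetic_media import create_synthetic_tracks

    async def attach_synthetic_tracks(pc: RTCPeerConnection, session_name: str) -> None:
        tracks = create_synthetic_tracks(session_name)  # {"video": ..., "audio": ...}
        for track in tracks.values():
            pc.addTrack(track)  # aiortc creates one RTCRtpSender per added track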
voicebot/synthetic_media.py (195 lines, new file)

@@ -0,0 +1,195 @@
+"""
+Synthetic Media Tracks Module
+
+This module provides synthetic audio and video track creation for WebRTC media streaming.
+Contains AnimatedVideoTrack and SilentAudioTrack implementations ported from JavaScript.
+"""
+
+import numpy as np
+import cv2
+import fractions
+import time
+from aiortc import MediaStreamTrack
+from av import VideoFrame, AudioFrame
+
+
+class AnimatedVideoTrack(MediaStreamTrack):
+    """
+    Synthetic video track that generates animated content with a bouncing ball.
+    Ported from JavaScript createAnimatedVideoTrack function.
+    """
+
+    kind = "video"
+
+    def __init__(self, width: int = 320, height: int = 240, name: str = ""):
+        super().__init__()
+        self.width = width
+        self.height = height
+        self.name = name
+
+        # Generate color from name hash (similar to JavaScript nameToColor)
+        self.ball_color = (
+            self._name_to_color(name) if name else (0, 255, 136)
+        )  # Default green
+
+        # Ball properties
+        self.ball = {
+            "x": width / 2,
+            "y": height / 2,
+            "radius": min(width, height) * 0.06,
+            "dx": 3.0,
+            "dy": 2.0,
+        }
+
+        self.frame_count = 0
+        self._start_time = time.time()
+
+    async def next_timestamp(self):
+        """Returns (pts, time_base) for 15 FPS video"""
+        pts = int(self.frame_count * (1 / 15) * 90000)
+        time_base = 1 / 90000
+        return pts, time_base
+
+    def _name_to_color(self, name: str) -> tuple[int, int, int]:
+        """Convert name to HSL color, then to RGB tuple"""
+        # Simple hash function (djb2)
+        hash_value = 5381
+        for char in name:
+            hash_value = ((hash_value << 5) + hash_value + ord(char)) & 0xFFFFFFFF
+
+        # Generate HSL color from hash
+        hue = abs(hash_value) % 360
+        sat = 60 + (abs(hash_value) % 30)  # 60-89%
+        light = 45 + (abs(hash_value) % 30)  # 45-74%
+
+        # Convert HSL to RGB
+        h = hue / 360.0
+        s = sat / 100.0
+        lightness = light / 100.0
+
+        c = (1 - abs(2 * lightness - 1)) * s
+        x = c * (1 - abs((h * 6) % 2 - 1))
+        m = lightness - c / 2
+
+        if h < 1 / 6:
+            r, g, b = c, x, 0
+        elif h < 2 / 6:
+            r, g, b = x, c, 0
+        elif h < 3 / 6:
+            r, g, b = 0, c, x
+        elif h < 4 / 6:
+            r, g, b = 0, x, c
+        elif h < 5 / 6:
+            r, g, b = x, 0, c
+        else:
+            r, g, b = c, 0, x
+
+        return (
+            int((b + m) * 255),
+            int((g + m) * 255),
+            int((r + m) * 255),
+        )  # BGR for OpenCV
+
+    async def recv(self):
+        """Generate video frames at 15 FPS"""
+        pts, time_base = await self.next_timestamp()
+
+        # Create black background
+        frame_array = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+
+        # Update ball position
+        ball = self.ball
+        ball["x"] += ball["dx"]
+        ball["y"] += ball["dy"]
+
+        # Bounce off walls
+        if ball["x"] + ball["radius"] >= self.width or ball["x"] - ball["radius"] <= 0:
+            ball["dx"] = -ball["dx"]
+        if ball["y"] + ball["radius"] >= self.height or ball["y"] - ball["radius"] <= 0:
+            ball["dy"] = -ball["dy"]
+
+        # Keep ball in bounds
+        ball["x"] = max(ball["radius"], min(self.width - ball["radius"], ball["x"]))
+        ball["y"] = max(ball["radius"], min(self.height - ball["radius"], ball["y"]))
+
+        # Draw ball
+        cv2.circle(
+            frame_array,
+            (int(ball["x"]), int(ball["y"])),
+            int(ball["radius"]),
+            self.ball_color,
+            -1,
+        )
+
+        # Add frame counter text
+        frame_text = f"Frame: {int(time.time() * 1000) % 10000}"
+        cv2.putText(
+            frame_array,
+            frame_text,
+            (10, 20),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (255, 255, 255),
+            1,
+        )
+
+        # Convert to VideoFrame
+        frame = VideoFrame.from_ndarray(frame_array, format="bgr24")
+        frame.pts = pts
+        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
+
+        self.frame_count += 1
+        return frame
+
+
+class SilentAudioTrack(MediaStreamTrack):
+    """
+    Synthetic audio track that generates silence.
+    Ported from JavaScript createSilentAudioTrack function.
+    """
+
+    kind = "audio"
+
+    def __init__(self):
+        super().__init__()
+        self.sample_rate = 48000
+        self.samples_per_frame = 960  # 20ms at 48kHz
+
+    async def next_timestamp(self):
+        """Returns (pts, time_base) for 20ms audio frames at 48kHz"""
+        pts = int(time.time() * self.sample_rate)
+        time_base = 1 / self.sample_rate
+        return pts, time_base
+
+    async def recv(self):
+        """Generate silent audio frames"""
+        pts, time_base = await self.next_timestamp()
+
+        # Create silent audio data in s16 format (required by Opus encoder)
+        samples = np.zeros((self.samples_per_frame,), dtype=np.int16)
+
+        # Convert to AudioFrame
+        frame = AudioFrame.from_ndarray(
+            samples.reshape(1, -1), format="s16", layout="mono"
+        )
+        frame.sample_rate = self.sample_rate
+        frame.pts = pts
+        frame.time_base = fractions.Fraction(time_base).limit_denominator(1000000)
+
+        return frame
+
+
+def create_synthetic_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
+    """
+    Create synthetic audio and video tracks for WebRTC streaming.
+
+    Args:
+        session_name: Name to use for generating video track colors
+
+    Returns:
+        Dictionary containing 'video' and 'audio' tracks
+    """
+    return {
+        "video": AnimatedVideoTrack(name=session_name),
+        "audio": SilentAudioTrack()
+    }
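As a quick sanity check on the new module's timing constants: the video track steps pts by (1/15) × 90000 = 6000 ticks per frame against the standard 90 kHz RTP video clock, and each silent audio frame carries 48000 × 0.020 = 960 samples. A small illustrative snippet exercising the public entry point (the session name "alice" is made up; this is not part of the commit):

    # Illustrative usage of the new module, not part of this commit.
    import fractions
    from synthetic_media import create_synthetic_tracks

    tracks = create_synthetic_tracks("alice")
    assert tracks["video"].kind == "video" and tracks["audio"].kind == "audio"

    # 90 kHz clock at 15 FPS -> 6000 ticks between video frames
    assert int((1 / 15) * 90000) == 6000
    # 20 ms of 48 kHz audio -> 960 samples per frame
    assert 48000 * 20 // 1000 == 960
    # limit_denominator recovers the exact 1/90000 time base from the float
    assert fractions.Fraction(1 / 90000).limit_denominator(1000000) == fractions.Fraction(1, 90000)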