James Ketr 2025-09-19 12:38:08 -07:00
parent 4058f729e2
commit 781748a182
3 changed files with 522 additions and 19 deletions

View File

@@ -249,24 +249,37 @@ class WaveformVideoTrack(MediaStreamTrack):
                 # Non-critical overlay; ignore failures
                 pass
 
-        # Select the most active audio buffer and get its speech status
+        # Prefer the generated TTS audio buffer if present. vibevoice ignores
+        # incoming WebRTC audio; the global `_audio_data` is populated by the
+        # TTS background worker and represents the audio we want to visualize.
         best_proc = None
         best_rms = 0.0
         speech_info = None
         try:
-            for pname, arr in self.__class__.buffer.items():
-                try:
-                    if len(arr) == 0:
-                        rms = 0.0
-                    else:
-                        rms = float(np.sqrt(np.mean(arr**2)))
-                    if rms > best_rms:
-                        best_rms = rms
-                        best_proc = (pname, arr.copy())
-                        speech_info = self.__class__.speech_status.get(pname, {})
-                except Exception:
-                    continue
+            if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
+                # Use a synthetic pname to indicate TTS-generated audio and
+                # copy the buffer for safe local use.
+                try:
+                    tts_arr = np.asarray(_audio_data, dtype=np.float32)
+                    best_proc = ("__tts__", tts_arr.copy())
+                    # Mark as speech for coloring purposes
+                    speech_info = {"is_speech": True, "energy_check": True}
+                except Exception:
+                    best_proc = None
+            else:
+                for pname, arr in self.__class__.buffer.items():
+                    try:
+                        if len(arr) == 0:
+                            rms = 0.0
+                        else:
+                            rms = float(np.sqrt(np.mean(arr**2)))
+                        if rms > best_rms:
+                            best_rms = rms
+                            best_proc = (pname, arr.copy())
+                            speech_info = self.__class__.speech_status.get(pname, {})
+                    except Exception:
+                        continue
         except Exception:
             best_proc = None
@@ -509,6 +522,14 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
         logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
         return
 
+    # This bot (vibevoice) does not use incoming WebRTC audio for TTS or
+    # waveform rendering. Ignore audio tracks entirely to avoid populating
+    # the shared waveform buffer with remote audio. The generated TTS audio
+    # is stored in the module-global `_audio_data` and will be used for the
+    # waveform and silent audio track playback instead.
+    logger.info(f"vibevoice: ignoring incoming audio track from {peer.peer_name}")
+    return
+
     # Initialize raw audio buffer for immediate graphing
     if peer.peer_name not in WaveformVideoTrack.buffer:
         WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32)

View File

@@ -0,0 +1,490 @@
import os
import re
import time
import traceback
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

from shared.logger import logger


class VoiceMapper:
    """Maps speaker names to voice file paths"""

    def __init__(self, voices_dir: Optional[str] = None):
        if voices_dir is None:
            voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        self.voices_dir = voices_dir
        self.setup_voice_presets()

        # Also register shortened aliases derived from the preset wav file
        # names (strip the part after the first '_' and the part up to the
        # last '-').
        new_dict = {}
        for name, path in self.voice_presets.items():
            if '_' in name:
                name = name.split('_')[0]
            if '-' in name:
                name = name.split('-')[-1]
            new_dict[name] = path
        self.voice_presets.update(new_dict)
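        # Illustration (hypothetical file name): "en-Alice_woman.wav" is first
        # registered under the key "en-Alice_woman" by setup_voice_presets(),
        # and the loop above also makes it reachable as "Alice"
        # ("en-Alice_woman" -> split('_')[0] -> "en-Alice" -> split('-')[-1] -> "Alice").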

    def setup_voice_presets(self):
        """Setup voice presets by scanning the voices directory."""
        # Check if voices directory exists
        if not os.path.exists(self.voices_dir):
            logger.info(f"Warning: Voices directory not found at {self.voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return

        # Scan for all WAV files in the voices directory
        self.voice_presets = {}

        # Get all .wav files in the voices directory
        wav_files = [f for f in os.listdir(self.voices_dir)
                     if f.lower().endswith('.wav') and os.path.isfile(os.path.join(self.voices_dir, f))]

        # Create dictionary with filename (without extension) as key
        for wav_file in wav_files:
            # Remove .wav extension to get the name
            name = os.path.splitext(wav_file)[0]
            # Create full path
            full_path = os.path.join(self.voices_dir, wav_file)
            self.voice_presets[name] = full_path

        # Sort the voice presets alphabetically by name for better UI
        self.voice_presets = dict(sorted(self.voice_presets.items()))

        # Filter out voices that don't exist (this is now redundant but kept for safety)
        self.available_voices = {
            name: path for name, path in self.voice_presets.items()
            if os.path.exists(path)
        }

        logger.info(f"Found {len(self.available_voices)} voice files in {self.voices_dir}")
        if self.available_voices:
            logger.info(f"Available voices: {', '.join(self.available_voices.keys())}")

    def get_voice_path(self, speaker_name: str) -> str:
        """Get voice file path for a given speaker name"""
        # First try exact match
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]

        # Try partial matching (case insensitive)
        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                return path

        # Default to first voice if no match found
        if not self.voice_presets:
            raise ValueError("No voice files available")
        default_voice = list(self.voice_presets.values())[0]
        logger.info(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}")
        return default_voice
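
    # For illustration, with a preset named "Alice" (hypothetical), both
    # get_voice_path("Alice") and get_voice_path("Speaker Alice") resolve to
    # the same file: the partial match checks substring containment in either
    # direction, case-insensitively.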


class VibeVoiceTTS:
    """
    A reusable Text-to-Speech engine using the VibeVoice model.

    Example usage:
        tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
        audio_data = tts_engine.text_to_speech(
            text="Speaker 1: Hello world!\nSpeaker 2: How are you?",
            speaker_names=["Andrew", "Ava"]
        )
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-1.5b",
        device: Optional[str] = None,
        voices_dir: Optional[str] = None,
        cfg_scale: float = 1.3,
        ddpm_steps: int = 10
    ):
        """
        Initialize the TTS engine with model and configuration.

        Args:
            model_path: Path to the HuggingFace model directory
            device: Device for inference ('cuda', 'xpu', 'mps', 'cpu'). Auto-detected if None
            voices_dir: Directory containing voice sample .wav files
            cfg_scale: CFG (Classifier-Free Guidance) scale for generation
            ddpm_steps: Number of DDPM inference steps
        """
        self.model_path = model_path
        self.cfg_scale = cfg_scale
        self.ddpm_steps = ddpm_steps

        # Auto-detect device if not specified
        if device is None:
            if hasattr(torch, "xpu") and torch.xpu.is_available():
                device = "xpu"
            elif torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        # Handle potential typos
        if device.lower() == "mpx":
            logger.info("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"

        # Validate mps availability
        if device == "mps" and not torch.backends.mps.is_available():
            logger.info("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"

        self.device = device
        logger.info(f"Using device: {self.device}")

        # Initialize voice mapper
        self.voice_mapper = VoiceMapper(voices_dir)

        # Load model and processor
        self._load_model()

    def _load_model(self):
        """Load the model and processor with device-specific configuration."""
        logger.info(f"Loading processor & model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)

        # Decide dtype & attention implementation
        if self.device == "mps":
            load_dtype = torch.float32   # MPS requires float32
            attn_impl_primary = "sdpa"   # flash_attention_2 not supported on MPS
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        elif self.device == "xpu":
            load_dtype = torch.bfloat16
            attn_impl_primary = "sdpa"   # flash_attention_2 not supported on XPU
        else:  # cpu
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"
        logger.info(f"Using torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        # Load model with device-specific logic
        try:
            if self.device == "mps":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map=None,  # load then move
                )
                self.model.to("mps")
            elif self.device == "cuda":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cuda",
                    attn_implementation=attn_impl_primary,
                )
            elif self.device == "xpu":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="xpu",
                    attn_implementation=attn_impl_primary,
                )
            else:  # cpu
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cpu",
                    attn_implementation=attn_impl_primary,
                )
        except Exception as e:
            if attn_impl_primary == 'flash_attention_2':
                logger.info(f"[ERROR] : {type(e).__name__}: {e}")
                logger.info(traceback.format_exc())
                logger.info("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=(self.device if self.device in ("cuda", "cpu", "xpu") else None),
                    attn_implementation='sdpa'
                )
                if self.device == "mps":
                    self.model.to("mps")
            else:
                raise e

        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)

        if hasattr(self.model.model, 'language_model'):
            logger.info(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")

        logger.info("Model loaded successfully!")

    def _parse_script(self, text: str) -> Tuple[List[str], List[str]]:
        """
        Parse script text and extract speakers and their text.

        Supports format: "Speaker 1: text", "Speaker 2: text", etc.
        Returns: (scripts, speaker_numbers)
        """
        lines = text.strip().split('\n')
        scripts = []
        speaker_numbers = []

        # Pattern to match "Speaker X:" format where X is a number
        speaker_pattern = r'^Speaker\s+(\d+):\s*(.*)$'

        current_speaker = None
        current_text = ""

        for line in lines:
            line = line.strip()
            if not line:
                continue

            match = re.match(speaker_pattern, line, re.IGNORECASE)
            if match:
                # If we have accumulated text from previous speaker, save it
                if current_speaker and current_text:
                    scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
                    speaker_numbers.append(current_speaker)

                # Start new speaker
                current_speaker = match.group(1).strip()
                current_text = match.group(2).strip()
            else:
                # Continue text for current speaker
                if current_text:
                    current_text += " " + line
                else:
                    current_text = line

        # Don't forget the last speaker
        if current_speaker and current_text:
            scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
            speaker_numbers.append(current_speaker)

        return scripts, speaker_numbers
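
    # For illustration (hypothetical input),
    #   _parse_script("Speaker 1: Hello there.\nAnd welcome.\nSpeaker 2: Thanks!")
    # returns
    #   (["Speaker 1: Hello there. And welcome.", "Speaker 2: Thanks!"], ["1", "2"]);
    # continuation lines without a "Speaker N:" prefix are folded into the
    # current speaker's text.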

    def text_to_speech(
        self,
        text: str,
        speaker_names: Optional[Union[str, List[str]]] = None,
        cfg_scale: Optional[float] = None,
        verbose: bool = False
    ) -> np.ndarray:
        """
        Convert text to speech and return audio data.

        Args:
            text: Input text with speaker labels (e.g., "Speaker 1: Hello\nSpeaker 2: Hi there")
            speaker_names: Speaker name(s) to map to voice files. Can be single string or list.
            cfg_scale: Override default CFG scale for this generation
            verbose: Print detailed generation info

        Returns:
            numpy.ndarray: Audio data as floating point array (sample rate: 24kHz)
        """
        if cfg_scale is None:
            cfg_scale = self.cfg_scale

        # Parse the script to get speaker segments
        scripts, speaker_numbers = self._parse_script(text)
        if not scripts:
            raise ValueError("No valid speaker scripts found in the input text")

        if verbose:
            logger.info(f"Found {len(scripts)} speaker segments:")
            for i, (script, speaker_num) in enumerate(zip(scripts, speaker_numbers)):
                logger.info(f"  {i+1}. Speaker {speaker_num}")
                logger.info(f"     Text preview: {script[:100]}...")

        # Handle speaker names
        if speaker_names is None:
            speaker_names = ["Andrew"]  # Default speaker
        elif isinstance(speaker_names, str):
            speaker_names = [speaker_names]

        # Map speaker numbers to provided speaker names
        speaker_name_mapping = {}
        for i, name in enumerate(speaker_names, 1):
            speaker_name_mapping[str(i)] = name

        if verbose:
            logger.info("\nSpeaker mapping:")
            for speaker_num in set(speaker_numbers):
                mapped_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
                logger.info(f"  Speaker {speaker_num} -> {mapped_name}")

        # Map speakers to voice files
        voice_samples = []
        actual_speakers = []

        # Get unique speaker numbers in order of first appearance
        unique_speaker_numbers = []
        seen = set()
        for speaker_num in speaker_numbers:
            if speaker_num not in seen:
                unique_speaker_numbers.append(speaker_num)
                seen.add(speaker_num)

        for speaker_num in unique_speaker_numbers:
            speaker_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
            voice_path = self.voice_mapper.get_voice_path(speaker_name)
            voice_samples.append(voice_path)
            actual_speakers.append(speaker_name)
            if verbose:
                logger.info(f"Speaker {speaker_num} ('{speaker_name}') -> Voice: {os.path.basename(voice_path)}")

        # Prepare data for model
        full_script = '\n'.join(scripts)
        # Normalize curly apostrophes to plain ASCII quotes
        full_script = full_script.replace("\u2019", "'")

        # Prepare inputs for the model
        inputs = self.processor(
            text=[full_script],             # Wrap in list for batch processing
            voice_samples=[voice_samples],  # Wrap in list for batch processing
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Move tensors to the target device
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device)

        if verbose:
            logger.info(f"Starting generation with cfg_scale: {cfg_scale}")

        # Generate audio
        start_time = time.time()
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=self.processor.tokenizer,
            generation_config={'do_sample': False},
            verbose=verbose,
        )
        generation_time = time.time() - start_time
        if verbose:
            logger.info(f"Generation time: {generation_time:.2f} seconds")

        # Calculate metrics
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            sample_rate = 24000
            audio_samples = outputs.speech_outputs[0].shape[-1] if len(outputs.speech_outputs[0].shape) > 0 else len(outputs.speech_outputs[0])
            audio_duration = audio_samples / sample_rate
            rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')
            logger.info(f"Generated audio duration: {audio_duration:.2f} seconds")
            logger.info(f"RTF (Real Time Factor): {rtf:.2f}x")

            # Token metrics
            input_tokens = inputs['input_ids'].shape[1]
            output_tokens = outputs.sequences.shape[1]
            generated_tokens = output_tokens - input_tokens
            logger.info(f"Prefilling tokens: {input_tokens}")
            logger.info(f"Generated tokens: {generated_tokens}")
            logger.info(f"Total tokens: {output_tokens}")

        # Return audio data as numpy array
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            audio_tensor = outputs.speech_outputs[0]
            # Convert to numpy array on CPU
            if hasattr(audio_tensor, 'cpu'):
                audio_data = audio_tensor.cpu().numpy()
            else:
                audio_data = np.array(audio_tensor)
            # Ensure it's a 1D array
            if audio_data.ndim > 1:
                audio_data = audio_data.squeeze()
            return audio_data
        else:
            raise RuntimeError("No audio output generated")

    def get_available_voices(self) -> List[str]:
        """Get list of available voice names."""
        return list(self.voice_mapper.available_voices.keys())

    def get_sample_rate(self) -> int:
        """Get the sample rate of generated audio."""
        return 24000  # VibeVoice uses 24kHz


# Global instance for easy access
_global_tts_engine = None


def get_tts_engine(**kwargs) -> VibeVoiceTTS:
    """
    Get or create a global TTS engine instance.

    Args:
        **kwargs: Arguments to pass to VibeVoiceTTS constructor (only used on first call)

    Returns:
        VibeVoiceTTS: Global TTS engine instance
    """
    global _global_tts_engine
    if _global_tts_engine is None:
        _global_tts_engine = VibeVoiceTTS(**kwargs)
    return _global_tts_engine
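
# Note: constructor kwargs are only honored on the first call, when the global
# engine is created; later calls return the existing instance and ignore any
# new kwargs. Construct VibeVoiceTTS directly if different settings are needed.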


# Convenience function for quick TTS
def text_to_speech(text: str, speaker_names: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> np.ndarray:
    """
    Quick text-to-speech conversion using the global engine.

    Args:
        text: Input text with speaker labels
        speaker_names: Speaker name(s) to use
        **kwargs: Additional arguments for the TTS engine or the text_to_speech method

    Returns:
        numpy.ndarray: Audio data
    """
    # Separate engine kwargs from TTS kwargs
    engine_kwargs = {k: v for k, v in kwargs.items()
                     if k in ['model_path', 'device', 'voices_dir', 'cfg_scale', 'ddpm_steps']}
    tts_kwargs = {k: v for k, v in kwargs.items() if k not in engine_kwargs}

    engine = get_tts_engine(**engine_kwargs)
    return engine.text_to_speech(text, speaker_names, **tts_kwargs)


# Example usage:
#
# Method 1: Create an instance
#   tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
#   audio_data = tts_engine.text_to_speech(
#       "Speaker 1: Hello world!\nSpeaker 2: How are you?",
#       speaker_names=["Andrew", "Ava"]
#   )
#
# Method 2: Use the global instance
#   audio_data = text_to_speech(
#       "Speaker 1: Hello world!",
#       speaker_names="Andrew",
#       verbose=True
#   )
#
# Method 3: Global engine with custom config
#   engine = get_tts_engine(device="cuda", cfg_scale=1.5)
#   audio_data = engine.text_to_speech("Speaker 1: Hello!", ["Andrew"])
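#
# Method 4 (sketch): persist the returned array as a 24 kHz mono WAV file.
#   Illustrative only; it assumes the `soundfile` package is available in the
#   environment (it is not a dependency of this module).
#   import soundfile as sf
#   audio_data = text_to_speech("Speaker 1: Hello!", speaker_names="Andrew")
#   sf.write("output.wav", audio_data, samplerate=24000)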

View File

@@ -1,6 +1,4 @@
 about-time==4.2.1
-absl-py==2.3.1
-accelerate==1.6.0
 aiofiles==24.1.0
 aiohappyeyeballs==2.6.1
 aiohttp==3.12.15
@@ -27,7 +25,6 @@ cycler==0.12.1
 datasets==4.1.0
 decorator==5.2.1
 deprecated==1.2.18
-diffusers==0.35.1
 dill==0.4.0
 distro==1.9.0
 dnspython==2.8.0
@@ -50,7 +47,6 @@ httpx==0.28.1
 huggingface-hub==0.34.5
 idna==3.10
 ifaddr==0.2.0
-importlib-metadata==8.7.0
 iniconfig==2.1.0
 jinja2==3.1.6
 jiter==0.11.0
@@ -66,7 +62,6 @@ markdown-it-py==4.0.0
 markupsafe==3.0.2
 matplotlib==3.10.6
 mdurl==0.1.2
-ml-collections==1.1.0
 ml-dtypes==0.5.3
 more-itertools==10.8.0
 mpmath==1.3.0
@@ -106,7 +101,6 @@ optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@b9c151fec6b
 orjson==3.11.3
 packaging==25.0
 pandas==2.3.2
-peft==0.17.1
 pillow==11.3.0
 platformdirs==4.4.0
 pluggy==1.6.0
@@ -174,10 +168,8 @@ typing-inspection==0.4.1
 tzdata==2025.2
 urllib3==2.5.0
 uvicorn==0.35.0
--e file:///voicebot/VibeVoice
 watchdog==6.0.0
 websockets==15.0.1
 wrapt==1.17.3
 xxhash==3.5.0
 yarl==1.20.1
-zipp==3.23.0