Working

parent 4058f729e2
commit 781748a182
@@ -249,24 +249,37 @@ class WaveformVideoTrack(MediaStreamTrack):
             # Non-critical overlay; ignore failures
             pass
 
-        # Select the most active audio buffer and get its speech status
+        # Prefer the generated TTS audio buffer if present. vibevoice ignores
+        # incoming WebRTC audio; the global `_audio_data` is populated by the
+        # TTS background worker and represents the audio we want to visualize.
         best_proc = None
         best_rms = 0.0
         speech_info = None
 
         try:
-            for pname, arr in self.__class__.buffer.items():
-                try:
-                    if len(arr) == 0:
-                        rms = 0.0
-                    else:
-                        rms = float(np.sqrt(np.mean(arr**2)))
-                    if rms > best_rms:
-                        best_rms = rms
-                        best_proc = (pname, arr.copy())
-                        speech_info = self.__class__.speech_status.get(pname, {})
-                except Exception:
-                    continue
+            if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
+                # Use a synthetic pname to indicate TTS-generated audio and
+                # copy the buffer for safe local use.
+                try:
+                    tts_arr = np.asarray(_audio_data, dtype=np.float32)
+                    best_proc = ("__tts__", tts_arr.copy())
+                    # Mark as speech for coloring purposes
+                    speech_info = {"is_speech": True, "energy_check": True}
+                except Exception:
+                    best_proc = None
+            else:
+                for pname, arr in self.__class__.buffer.items():
+                    try:
+                        if len(arr) == 0:
+                            rms = 0.0
+                        else:
+                            rms = float(np.sqrt(np.mean(arr**2)))
+                        if rms > best_rms:
+                            best_rms = rms
+                            best_proc = (pname, arr.copy())
+                            speech_info = self.__class__.speech_status.get(pname, {})
+                    except Exception:
+                        continue
         except Exception:
             best_proc = None
 
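For reference, the buffer-selection rule added above can be read in isolation. A minimal standalone sketch, assuming `_audio_data` is a numpy array (or None), `buffers` maps peer names to numpy arrays, and `speech_status` maps peer names to dicts; the helper name `select_waveform_source` is illustrative and not part of the commit:

import numpy as np

def select_waveform_source(_audio_data, buffers, speech_status):
    """Prefer generated TTS audio; otherwise pick the loudest peer buffer by RMS."""
    if _audio_data is not None and getattr(_audio_data, "size", 0) > 0:
        # TTS audio wins and is always treated as speech for coloring.
        tts_arr = np.asarray(_audio_data, dtype=np.float32)
        return ("__tts__", tts_arr.copy()), {"is_speech": True, "energy_check": True}

    best_proc, best_rms, speech_info = None, 0.0, None
    for pname, arr in buffers.items():
        rms = float(np.sqrt(np.mean(arr ** 2))) if len(arr) else 0.0
        if rms > best_rms:
            best_rms = rms
            best_proc = (pname, arr.copy())
            speech_info = speech_status.get(pname, {})
    return best_proc, speech_info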
@@ -509,6 +522,14 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack) -> None:
         logger.info(f"Ignoring non-audio track from {peer.peer_name}: {track.kind}")
         return
 
+    # This bot (vibevoice) does not use incoming WebRTC audio for TTS or
+    # waveform rendering. Ignore audio tracks entirely to avoid populating
+    # the shared waveform buffer with remote audio. The generated TTS audio
+    # is stored in the module-global `_audio_data` and will be used for the
+    # waveform and silent audio track playback instead.
+    logger.info(f"vibevoice: ignoring incoming audio track from {peer.peer_name}")
+    return
+
     # Initialize raw audio buffer for immediate graphing
     if peer.peer_name not in WaveformVideoTrack.buffer:
         WaveformVideoTrack.buffer[peer.peer_name] = np.array([], dtype=np.float32)
voicebot/bots/vibevoicetts.py (new file, 490 lines)
@@ -0,0 +1,490 @@
import os
import re
import traceback
from typing import Any, List, Tuple, Union, Optional
import time
import torch
import numpy as np

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

from shared.logger import logger


class VoiceMapper:
    """Maps speaker names to voice file paths"""

    def __init__(self, voices_dir: Optional[str] = None):
        if voices_dir is None:
            voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        self.voices_dir = voices_dir
        self.setup_voice_presets()

        # Change name according to our preset wav file
        new_dict = {}
        for name, path in self.voice_presets.items():
            if '_' in name:
                name = name.split('_')[0]
            if '-' in name:
                name = name.split('-')[-1]
            new_dict[name] = path
        self.voice_presets.update(new_dict)

    def setup_voice_presets(self):
        """Setup voice presets by scanning the voices directory."""
        # Check if voices directory exists
        if not os.path.exists(self.voices_dir):
            logger.info(f"Warning: Voices directory not found at {self.voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return

        # Scan for all WAV files in the voices directory
        self.voice_presets = {}

        # Get all .wav files in the voices directory
        wav_files = [f for f in os.listdir(self.voices_dir)
                     if f.lower().endswith('.wav') and os.path.isfile(os.path.join(self.voices_dir, f))]

        # Create dictionary with filename (without extension) as key
        for wav_file in wav_files:
            # Remove .wav extension to get the name
            name = os.path.splitext(wav_file)[0]
            # Create full path
            full_path = os.path.join(self.voices_dir, wav_file)
            self.voice_presets[name] = full_path

        # Sort the voice presets alphabetically by name for better UI
        self.voice_presets = dict(sorted(self.voice_presets.items()))

        # Filter out voices that don't exist (this is now redundant but kept for safety)
        self.available_voices = {
            name: path for name, path in self.voice_presets.items()
            if os.path.exists(path)
        }

        logger.info(f"Found {len(self.available_voices)} voice files in {self.voices_dir}")
        if self.available_voices:
            logger.info(f"Available voices: {', '.join(self.available_voices.keys())}")

    def get_voice_path(self, speaker_name: str) -> str:
        """Get voice file path for a given speaker name"""
        # First try exact match
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]

        # Try partial matching (case insensitive)
        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                return path

        # Default to first voice if no match found
        if not self.voice_presets:
            raise ValueError("No voice files available")

        default_voice = list(self.voice_presets.values())[0]
        logger.info(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}")
        return default_voice

class VibeVoiceTTS:
    """
    A reusable Text-to-Speech engine using VibeVoice model.

    Example usage:
        tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
        audio_data = tts_engine.text_to_speech(
            text="Speaker 1: Hello world!\nSpeaker 2: How are you?",
            speaker_names=["Andrew", "Ava"]
        )
    """

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-1.5b",
        device: Optional[str] = None,
        voices_dir: Optional[str] = None,
        cfg_scale: float = 1.3,
        ddpm_steps: int = 10
    ):
        """
        Initialize the TTS engine with model and configuration.

        Args:
            model_path: Path to the HuggingFace model directory
            device: Device for inference ('cuda', 'mps', 'cpu'). Auto-detected if None
            voices_dir: Directory containing voice sample .wav files
            cfg_scale: CFG (Classifier-Free Guidance) scale for generation
            ddpm_steps: Number of DDPM inference steps
        """
        self.model_path = model_path
        self.cfg_scale = cfg_scale
        self.ddpm_steps = ddpm_steps

        # Auto-detect device if not specified
        if device is None:
            if torch.xpu.is_available():
                device = "xpu"
            elif torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        # Handle potential typos
        if device.lower() == "mpx":
            logger.info("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"

        # Validate mps availability
        if device == "mps" and not torch.backends.mps.is_available():
            logger.info("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"

        self.device = device
        logger.info(f"Using device: {self.device}")

        # Initialize voice mapper
        self.voice_mapper = VoiceMapper(voices_dir)

        # Load model and processor
        self._load_model()

    def _load_model(self):
        """Load the model and processor with device-specific configuration."""
        logger.info(f"Loading processor & model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)

        # Decide dtype & attention implementation
        if self.device == "mps":
            load_dtype = torch.float32  # MPS requires float32
            attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        elif self.device == "xpu":
            load_dtype = torch.bfloat16
            attn_impl_primary = "sdpa"  # flash_attention_2 not supported on XPU
        else:  # cpu
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"

        logger.info(f"Using torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        # Load model with device-specific logic
        try:
            if self.device == "mps":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map=None,  # load then move
                )
                self.model.to("mps")
            elif self.device == "cuda":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cuda",
                    attn_implementation=attn_impl_primary,
                )
            elif self.device == "xpu":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="xpu",
                    attn_implementation=attn_impl_primary,
                )
            else:  # cpu
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cpu",
                    attn_implementation=attn_impl_primary,
                )
        except Exception as e:
            if attn_impl_primary == 'flash_attention_2':
                logger.info(f"[ERROR] : {type(e).__name__}: {e}")
                logger.info(traceback.format_exc())
                logger.info("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=(self.device if self.device in ("cuda", "cpu", "xpu") else None),
                    attn_implementation='sdpa'
                )
                if self.device == "mps":
                    self.model.to("mps")
            else:
                raise e

        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)

        if hasattr(self.model.model, 'language_model'):
            logger.info(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")

        logger.info("Model loaded successfully!")

    def _parse_script(self, text: str) -> Tuple[List[str], List[str]]:
        """
        Parse script text and extract speakers and their text.
        Supports format: "Speaker 1: text", "Speaker 2: text", etc.

        Returns: (scripts, speaker_numbers)
        """
        lines = text.strip().split('\n')
        scripts = []
        speaker_numbers = []

        # Pattern to match "Speaker X:" format where X is a number
        speaker_pattern = r'^Speaker\s+(\d+):\s*(.*)$'

        current_speaker = None
        current_text = ""

        for line in lines:
            line = line.strip()
            if not line:
                continue

            match = re.match(speaker_pattern, line, re.IGNORECASE)
            if match:
                # If we have accumulated text from previous speaker, save it
                if current_speaker and current_text:
                    scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
                    speaker_numbers.append(current_speaker)

                # Start new speaker
                current_speaker = match.group(1).strip()
                current_text = match.group(2).strip()
            else:
                # Continue text for current speaker
                if current_text:
                    current_text += " " + line
                else:
                    current_text = line

        # Don't forget the last speaker
        if current_speaker and current_text:
            scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
            speaker_numbers.append(current_speaker)

        return scripts, speaker_numbers

    def text_to_speech(
        self,
        text: str,
        speaker_names: Union[str, List[str]] = None,
        cfg_scale: Optional[float] = None,
        verbose: bool = False
    ) -> np.ndarray:
        """
        Convert text to speech and return audio data.

        Args:
            text: Input text with speaker labels (e.g., "Speaker 1: Hello\nSpeaker 2: Hi there")
            speaker_names: Speaker name(s) to map to voice files. Can be single string or list.
            cfg_scale: Override default CFG scale for this generation
            verbose: Print detailed generation info

        Returns:
            numpy.ndarray: Audio data as floating point array (sample rate: 24kHz)
        """
        if cfg_scale is None:
            cfg_scale = self.cfg_scale

        # Parse the script to get speaker segments
        scripts, speaker_numbers = self._parse_script(text)

        if not scripts:
            raise ValueError("No valid speaker scripts found in the input text")

        if verbose:
            logger.info(f"Found {len(scripts)} speaker segments:")
            for i, (script, speaker_num) in enumerate(zip(scripts, speaker_numbers)):
                logger.info(f" {i+1}. Speaker {speaker_num}")
                logger.info(f" Text preview: {script[:100]}...")

        # Handle speaker names
        if speaker_names is None:
            speaker_names = ["Andrew"]  # Default speaker
        elif isinstance(speaker_names, str):
            speaker_names = [speaker_names]

        # Map speaker numbers to provided speaker names
        speaker_name_mapping = {}
        for i, name in enumerate(speaker_names, 1):
            speaker_name_mapping[str(i)] = name

        if verbose:
            logger.info("\nSpeaker mapping:")
            for speaker_num in set(speaker_numbers):
                mapped_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
                logger.info(f" Speaker {speaker_num} -> {mapped_name}")

        # Map speakers to voice files
        voice_samples = []
        actual_speakers = []

        # Get unique speaker numbers in order of first appearance
        unique_speaker_numbers = []
        seen = set()
        for speaker_num in speaker_numbers:
            if speaker_num not in seen:
                unique_speaker_numbers.append(speaker_num)
                seen.add(speaker_num)

        for speaker_num in unique_speaker_numbers:
            speaker_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
            voice_path = self.voice_mapper.get_voice_path(speaker_name)
            voice_samples.append(voice_path)
            actual_speakers.append(speaker_name)
            if verbose:
                logger.info(f"Speaker {speaker_num} ('{speaker_name}') -> Voice: {os.path.basename(voice_path)}")

        # Prepare data for model
        full_script = '\n'.join(scripts)
        full_script = full_script.replace("’", "'")

        # Prepare inputs for the model
        inputs = self.processor(
            text=[full_script],  # Wrap in list for batch processing
            voice_samples=[voice_samples],  # Wrap in list for batch processing
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Move tensors to target device
        target_device = self.device if self.device != "cpu" else "cpu"
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(target_device)

        if verbose:
            logger.info(f"Starting generation with cfg_scale: {cfg_scale}")

        # Generate audio
        start_time = time.time()
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=self.processor.tokenizer,
            generation_config={'do_sample': False},
            verbose=verbose,
        )
        generation_time = time.time() - start_time

        if verbose:
            logger.info(f"Generation time: {generation_time:.2f} seconds")

        # Calculate metrics
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            sample_rate = 24000
            audio_samples = outputs.speech_outputs[0].shape[-1] if len(outputs.speech_outputs[0].shape) > 0 else len(outputs.speech_outputs[0])
            audio_duration = audio_samples / sample_rate
            rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')

            logger.info(f"Generated audio duration: {audio_duration:.2f} seconds")
            logger.info(f"RTF (Real Time Factor): {rtf:.2f}x")

            # Token metrics
            input_tokens = inputs['input_ids'].shape[1]
            output_tokens = outputs.sequences.shape[1]
            generated_tokens = output_tokens - input_tokens

            logger.info(f"Prefilling tokens: {input_tokens}")
            logger.info(f"Generated tokens: {generated_tokens}")
            logger.info(f"Total tokens: {output_tokens}")

        # Return audio data as numpy array
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            audio_tensor = outputs.speech_outputs[0]

            # Convert to numpy array on CPU
            if hasattr(audio_tensor, 'cpu'):
                audio_data = audio_tensor.cpu().numpy()
            else:
                audio_data = np.array(audio_tensor)

            # Ensure it's a 1D array
            if audio_data.ndim > 1:
                audio_data = audio_data.squeeze()

            return audio_data
        else:
            raise RuntimeError("No audio output generated")

    def get_available_voices(self) -> List[str]:
        """Get list of available voice names."""
        return list(self.voice_mapper.available_voices.keys())

    def get_sample_rate(self) -> int:
        """Get the sample rate of generated audio."""
        return 24000  # VibeVoice uses 24kHz


# Global instance for easy access
_global_tts_engine = None

def get_tts_engine(**kwargs) -> VibeVoiceTTS:
    """
    Get or create a global TTS engine instance.

    Args:
        **kwargs: Arguments to pass to VibeVoiceTTS constructor (only used on first call)

    Returns:
        VibeVoiceTTS: Global TTS engine instance
    """
    global _global_tts_engine
    if _global_tts_engine is None:
        _global_tts_engine = VibeVoiceTTS(**kwargs)
    return _global_tts_engine


# Convenience function for quick TTS
def text_to_speech(text: str, speaker_names: Optional[Union[str, List[str]]] = None, **kwargs: dict[str, Any]) -> np.ndarray:
    """
    Quick text-to-speech conversion using global engine.

    Args:
        text: Input text with speaker labels
        speaker_names: Speaker name(s) to use
        **kwargs: Additional arguments for TTS engine or text_to_speech method

    Returns:
        numpy.ndarray: Audio data
    """
    # Separate engine kwargs from TTS kwargs
    engine_kwargs = {k: v for k, v in kwargs.items()
                     if k in ['model_path', 'device', 'voices_dir', 'cfg_scale', 'ddpm_steps']}
    tts_kwargs = {k: v for k, v in kwargs.items() if k not in engine_kwargs}

    engine = get_tts_engine(**engine_kwargs)
    return engine.text_to_speech(text, speaker_names, **tts_kwargs)


# Example usage:
# Method 1: Create instance
# tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
# audio_data = tts_engine.text_to_speech(
#     "Speaker 1: Hello world!\nSpeaker 2: How are you?",
#     speaker_names=["Andrew", "Ava"]
# )

# # Method 2: Use global instance
# audio_data = text_to_speech(
#     "Speaker 1: Hello world!",
#     speaker_names="Andrew",
#     verbose=True
# )

# # Method 3: Global engine with custom config
# engine = get_tts_engine(device="cuda", cfg_scale=1.5)
# audio_data = engine.text_to_speech("Speaker 1: Hello!", ["Andrew"])
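The comments in the first hunk say a TTS background worker populates the module-global `_audio_data` that `WaveformVideoTrack` then visualizes. A minimal sketch of that wiring, assuming the new module is importable as `voicebot.bots.vibevoicetts`; the worker function and its import path are illustrative, not part of this commit:

import numpy as np
from voicebot.bots.vibevoicetts import text_to_speech

_audio_data = None  # module-global read by WaveformVideoTrack when rendering

def tts_worker(script: str) -> None:
    """Generate speech once and publish it for waveform rendering/playback."""
    global _audio_data
    # 24 kHz mono float array, per VibeVoiceTTS.get_sample_rate()
    audio = text_to_speech(script, speaker_names="Andrew")
    _audio_data = np.asarray(audio, dtype=np.float32)

# tts_worker("Speaker 1: Hello from vibevoice!")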
@@ -1,6 +1,4 @@
 about-time==4.2.1
-absl-py==2.3.1
-accelerate==1.6.0
 aiofiles==24.1.0
 aiohappyeyeballs==2.6.1
 aiohttp==3.12.15
@@ -27,7 +25,6 @@ cycler==0.12.1
 datasets==4.1.0
 decorator==5.2.1
 deprecated==1.2.18
-diffusers==0.35.1
 dill==0.4.0
 distro==1.9.0
 dnspython==2.8.0
@@ -50,7 +47,6 @@ httpx==0.28.1
 huggingface-hub==0.34.5
 idna==3.10
 ifaddr==0.2.0
-importlib-metadata==8.7.0
 iniconfig==2.1.0
 jinja2==3.1.6
 jiter==0.11.0
@@ -66,7 +62,6 @@ markdown-it-py==4.0.0
 markupsafe==3.0.2
 matplotlib==3.10.6
 mdurl==0.1.2
-ml-collections==1.1.0
 ml-dtypes==0.5.3
 more-itertools==10.8.0
 mpmath==1.3.0
@@ -106,7 +101,6 @@ optimum-intel @ git+https://github.com/huggingface/optimum-intel.git@b9c151fec6b
 orjson==3.11.3
 packaging==25.0
 pandas==2.3.2
-peft==0.17.1
 pillow==11.3.0
 platformdirs==4.4.0
 pluggy==1.6.0
@@ -174,10 +168,8 @@ typing-inspection==0.4.1
 tzdata==2025.2
 urllib3==2.5.0
 uvicorn==0.35.0
--e file:///voicebot/VibeVoice
 watchdog==6.0.0
 websockets==15.0.1
 wrapt==1.17.3
 xxhash==3.5.0
 yarl==1.20.1
-zipp==3.23.0