import os
import re
import sys
import time
import traceback
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch

# Defer importing the external `vibevoice` package until we actually need it.
# In some environments the `vibevoice` package isn't installed into site-packages,
# but the repo contains a local copy under `voicebot/VibeVoice`. Attempt a lazy
# import and, if that fails, add the local path(s) to sys.path and retry.
def _import_vibevoice_symbols():
    """Import the VibeVoice model and processor classes, lazily.

    Falls back to a repository-local copy of the package if the import fails.
    """
    try:
        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
        return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor
    except Exception:
        # If a required package (like `diffusers`) is missing inside the
        # container or venv, importing deeper VibeVoice modules will raise
        # ModuleNotFoundError. Detect that and raise a clearer error that
        # includes install instructions. A missing `vibevoice` itself is
        # handled below by trying the repo-local copy instead.
        exc_val = sys.exc_info()[1]
        if isinstance(exc_val, ModuleNotFoundError):
            missing = exc_val.name or str(exc_val)
            if missing.split(".")[0] != "vibevoice":
                raise ModuleNotFoundError(
                    f"Missing dependency when importing VibeVoice: {missing}.\n"
                    "Install required packages inside the voicebot container.\n"
                    "Example (inside container):\n"
                    "  PYTHONPATH=/shared:/voicebot uv run python3 -m pip install diffusers accelerate safetensors\n"
                    "Or add the packages to the voicebot service environment / pyproject and rebuild."
                ) from exc_val

        # Try adding likely repository-local paths where VibeVoice lives.
        base = os.path.dirname(__file__)  # voicebot/bots
        candidates = [
            os.path.abspath(os.path.join(base, "..", "VibeVoice")),
            os.path.abspath(os.path.join("/", "voicebot", "VibeVoice")),
        ]
        for p in candidates:
            if os.path.isdir(p) and p not in sys.path:
                sys.path.insert(0, p)

        # Retry the import; let any remaining failure propagate.
        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
        return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor
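
# The sys.path fallback above assumes a checkout layout roughly like this
# (illustrative sketch; only the `vibevoice/` package root needs to be on the path):
#   voicebot/
#     bots/          <- this file lives here
#     VibeVoice/
#       vibevoice/   <- importable package root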

from shared.logger import logger

class VoiceMapper:
    """Maps speaker names to voice file paths."""

    def __init__(self, voices_dir: Optional[str] = None):
        if voices_dir is None:
            voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        self.voices_dir = voices_dir
        self.setup_voice_presets()

        # Also register short aliases derived from the preset file names,
        # e.g. "en-Alice_woman" -> "Alice" (drop the "_..." suffix, then
        # keep the part after the last "-").
        new_dict = {}
        for name, path in self.voice_presets.items():
            if '_' in name:
                name = name.split('_')[0]
            if '-' in name:
                name = name.split('-')[-1]
            new_dict[name] = path
        self.voice_presets.update(new_dict)

    def setup_voice_presets(self):
        """Set up voice presets by scanning the voices directory."""
        # Check that the voices directory exists
        if not os.path.exists(self.voices_dir):
            logger.warning(f"Voices directory not found at {self.voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return

        # Scan for all .wav files in the voices directory
        self.voice_presets = {}
        wav_files = [f for f in os.listdir(self.voices_dir)
                     if f.lower().endswith('.wav') and os.path.isfile(os.path.join(self.voices_dir, f))]

        # Use the filename (without extension) as the preset name
        for wav_file in wav_files:
            name = os.path.splitext(wav_file)[0]
            full_path = os.path.join(self.voices_dir, wav_file)
            self.voice_presets[name] = full_path

        # Sort the voice presets alphabetically by name for better UI
        self.voice_presets = dict(sorted(self.voice_presets.items()))

        # Filter out voices whose files don't exist (redundant after the scan
        # above, but kept for safety)
        self.available_voices = {
            name: path for name, path in self.voice_presets.items()
            if os.path.exists(path)
        }

        logger.info(f"Found {len(self.available_voices)} voice files in {self.voices_dir}")
        if self.available_voices:
            logger.info(f"Available voices: {', '.join(self.available_voices.keys())}")

    def get_voice_path(self, speaker_name: str) -> str:
        """Get the voice file path for a given speaker name."""
        # First try an exact match
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]

        # Then try partial matching (case-insensitive, in either direction)
        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                return path

        # Fall back to the first available voice
        if not self.voice_presets:
            raise ValueError("No voice files available")

        default_voice = list(self.voice_presets.values())[0]
        logger.warning(f"No voice preset found for '{speaker_name}', using default voice: {default_voice}")
        return default_voice
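
# A quick sketch of how VoiceMapper resolves names (illustrative only; the
# file name "en-Alice_woman.wav" is a hypothetical example, not shipped here):
#   mapper = VoiceMapper()
#   mapper.get_voice_path("en-Alice_woman")  # exact preset-name match
#   mapper.get_voice_path("Alice")           # short alias built in __init__
#   mapper.get_voice_path("Zoe")             # no match -> falls back to first voice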


class VibeVoiceTTS:
    """
    A reusable text-to-speech engine using the VibeVoice model.

    Example usage:
        tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
        audio_data = tts_engine.text_to_speech(
            text="Speaker 1: Hello world!\nSpeaker 2: How are you?",
            speaker_names=["Andrew", "Ava"]
        )
    """
    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-1.5b",
        device: Optional[str] = None,
        voices_dir: Optional[str] = None,
        cfg_scale: float = 1.3,
        ddpm_steps: int = 10,
    ):
        """
        Initialize the TTS engine with model and configuration.

        Args:
            model_path: Path to the HuggingFace model directory
            device: Device for inference ('xpu', 'cuda', 'mps', 'cpu'). Auto-detected if None
            voices_dir: Directory containing voice sample .wav files
            cfg_scale: CFG (Classifier-Free Guidance) scale for generation
            ddpm_steps: Number of DDPM inference steps
        """
        self.model_path = model_path
        self.cfg_scale = cfg_scale
        self.ddpm_steps = ddpm_steps

        # Auto-detect device if not specified
        if device is None:
            if torch.xpu.is_available():
                device = "xpu"
            elif torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        # Handle a common typo for the Apple Silicon backend
        if device.lower() == "mpx":
            logger.info("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"

        # Validate MPS availability
        if device == "mps" and not torch.backends.mps.is_available():
            logger.warning("MPS not available. Falling back to CPU.")
            device = "cpu"

        self.device = device
        logger.info(f"Using device: {self.device}")

        # Initialize voice mapper
        self.voice_mapper = VoiceMapper(voices_dir)

        # Load model and processor
        self._load_model()

    def _load_model(self):
        """Load the model and processor with device-specific configuration."""
        logger.info(f"Loading processor & model from {self.model_path}")
        # Ensure external vibevoice symbols are available (lazy import)
        VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor = _import_vibevoice_symbols()
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)

        # Decide dtype & attention implementation per device
        if self.device == "mps":
            load_dtype = torch.float32   # MPS requires float32
            attn_impl_primary = "sdpa"   # flash_attention_2 not supported on MPS
        elif self.device == "cuda":
            load_dtype = torch.bfloat16
            attn_impl_primary = "flash_attention_2"
        elif self.device == "xpu":
            load_dtype = torch.bfloat16
            attn_impl_primary = "sdpa"   # flash_attention_2 not supported on XPU
        else:  # cpu
            load_dtype = torch.float32
            attn_impl_primary = "sdpa"

        logger.info(f"Using torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        # Load the model with device-specific logic
        try:
            if self.device == "mps":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map=None,  # load first, then move to MPS
                )
                self.model.to("mps")
            elif self.device == "cuda":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cuda",
                    attn_implementation=attn_impl_primary,
                )
            elif self.device == "xpu":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map={"": self.device},  # ensure the XPU is used
                    low_cpu_mem_usage=True,        # optimized for Intel GPUs
                )
                try:
                    import intel_extension_for_pytorch as ipex
                    logger.info("Applying IPEX optimizations")
                    self.model = ipex.optimize(self.model, dtype=torch.bfloat16, inplace=True)
                except ImportError:
                    logger.info("intel_extension_for_pytorch not found, proceeding without IPEX optimizations.")
            else:  # cpu
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cpu",
                    attn_implementation=attn_impl_primary,
                )
        except Exception as e:
            if attn_impl_primary == 'flash_attention_2':
                logger.error(f"{type(e).__name__}: {e}")
                logger.error(traceback.format_exc())
                logger.warning(
                    "Error loading the model with flash_attention_2; retrying with SDPA. "
                    "Note that only flash_attention_2 has been fully tested, and using "
                    "SDPA may result in lower audio quality."
                )
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map=(self.device if self.device in ("cuda", "cpu", "xpu") else None),
                    attn_implementation='sdpa',
                )
                if self.device == "mps":
                    self.model.to("mps")
            else:
                raise

        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)

        if hasattr(self.model.model, 'language_model'):
            logger.info(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")

        logger.info("Model loaded successfully!")

    def _parse_script(self, text: str) -> Tuple[List[str], List[str]]:
        """
        Parse script text and extract speakers and their text.
        Supports the format "Speaker 1: text", "Speaker 2: text", etc.

        Returns: (scripts, speaker_numbers)
        """
        lines = text.strip().split('\n')
        scripts = []
        speaker_numbers = []

        # Pattern to match "Speaker X:" where X is a number
        speaker_pattern = r'^Speaker\s+(\d+):\s*(.*)$'

        current_speaker = None
        current_text = ""

        for line in lines:
            line = line.strip()
            if not line:
                continue

            match = re.match(speaker_pattern, line, re.IGNORECASE)
            if match:
                # Save any accumulated text from the previous speaker
                if current_speaker and current_text:
                    scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
                    speaker_numbers.append(current_speaker)

                # Start a new speaker
                current_speaker = match.group(1).strip()
                current_text = match.group(2).strip()
            else:
                # Continuation line for the current speaker
                if current_text:
                    current_text += " " + line
                else:
                    current_text = line

        # Don't forget the last speaker
        if current_speaker and current_text:
            scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
            speaker_numbers.append(current_speaker)

        return scripts, speaker_numbers
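
    # Illustration of the parsing above (hedged sketch, traced by hand rather
    # than captured from a run):
    #   _parse_script("Speaker 1: Hi\nstill speaker one\nSpeaker 2: Hello")
    #   -> (["Speaker 1: Hi still speaker one", "Speaker 2: Hello"], ["1", "2"])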

    def text_to_speech(
        self,
        text: str,
        speaker_names: Optional[Union[str, List[str]]] = None,
        cfg_scale: Optional[float] = None,
        verbose: bool = False,
    ) -> np.ndarray:
        """
        Convert text to speech and return audio data.

        Args:
            text: Input text with speaker labels (e.g., "Speaker 1: Hello\nSpeaker 2: Hi there")
            speaker_names: Speaker name(s) to map to voice files. Can be a single string or a list.
            cfg_scale: Override the default CFG scale for this generation
            verbose: Log detailed generation info

        Returns:
            numpy.ndarray: Audio data as a floating-point array (sample rate: 24 kHz)
        """
        if cfg_scale is None:
            cfg_scale = self.cfg_scale

        # Parse the script into per-speaker segments
        scripts, speaker_numbers = self._parse_script(text)

        if not scripts:
            raise ValueError("No valid speaker scripts found in the input text")

        if verbose:
            logger.info(f"Found {len(scripts)} speaker segments:")
            for i, (script, speaker_num) in enumerate(zip(scripts, speaker_numbers)):
                logger.info(f"  {i+1}. Speaker {speaker_num}")
                logger.info(f"     Text preview: {script[:100]}...")

        # Handle speaker names
        if speaker_names is None:
            speaker_names = ["Andrew"]  # Default speaker
        elif isinstance(speaker_names, str):
            speaker_names = [speaker_names]

        # Map speaker numbers to the provided speaker names:
        # speaker_names[0] serves "Speaker 1", speaker_names[1] serves "Speaker 2", ...
        speaker_name_mapping = {}
        for i, name in enumerate(speaker_names, 1):
            speaker_name_mapping[str(i)] = name

        if verbose:
            logger.info("\nSpeaker mapping:")
            for speaker_num in set(speaker_numbers):
                mapped_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
                logger.info(f"  Speaker {speaker_num} -> {mapped_name}")

        # Map speakers to voice files
        voice_samples = []
        actual_speakers = []

        # Get unique speaker numbers in order of first appearance
        unique_speaker_numbers = []
        seen = set()
        for speaker_num in speaker_numbers:
            if speaker_num not in seen:
                unique_speaker_numbers.append(speaker_num)
                seen.add(speaker_num)

        for speaker_num in unique_speaker_numbers:
            speaker_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
            voice_path = self.voice_mapper.get_voice_path(speaker_name)
            voice_samples.append(voice_path)
            actual_speakers.append(speaker_name)
            if verbose:
                logger.info(f"Speaker {speaker_num} ('{speaker_name}') -> Voice: {os.path.basename(voice_path)}")

        # Prepare data for the model; normalize curly apostrophes to straight ones
        full_script = '\n'.join(scripts)
        full_script = full_script.replace("\u2019", "'")

        # Prepare inputs for the model
        inputs = self.processor(
            text=[full_script],             # wrap in a list for batch processing
            voice_samples=[voice_samples],  # wrap in a list for batch processing
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )

        # Move tensors to the target device
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device, non_blocking=True)

        if verbose:
            logger.info(f"Starting generation with cfg_scale: {cfg_scale}")

        # Generate audio
        start_time = time.time()
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=self.processor.tokenizer,
            generation_config={'do_sample': False},
            verbose=verbose,
        )
        generation_time = time.time() - start_time

        if verbose:
            logger.info(f"Generation time: {generation_time:.2f} seconds")

        # Calculate metrics
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            sample_rate = 24000
            audio = outputs.speech_outputs[0]
            audio_samples = audio.shape[-1] if len(audio.shape) > 0 else len(audio)
            audio_duration = audio_samples / sample_rate
            rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')

            logger.info(f"Generated audio duration: {audio_duration:.2f} seconds")
            logger.info(f"RTF (Real Time Factor): {rtf:.2f}x")

            # Token metrics
            input_tokens = inputs['input_ids'].shape[1]
            output_tokens = outputs.sequences.shape[1]
            generated_tokens = output_tokens - input_tokens

            logger.info(f"Prefilling tokens: {input_tokens}")
            logger.info(f"Generated tokens: {generated_tokens}")
            logger.info(f"Total tokens: {output_tokens}")

        # Return audio data as a numpy array
        if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
            # Synchronous copy to CPU: a non-blocking device-to-host transfer
            # could be read before it completes, so plain .cpu() is safer here.
            audio_tensor = outputs.speech_outputs[0].detach().cpu().float()
            audio_data = audio_tensor.numpy()
            return audio_data.squeeze()
        else:
            raise RuntimeError("No audio output generated")

    def get_available_voices(self) -> List[str]:
        """Get the list of available voice names."""
        return list(self.voice_mapper.available_voices.keys())

    def get_sample_rate(self) -> int:
        """Get the sample rate of generated audio."""
        return 24000  # VibeVoice outputs 24 kHz audio
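
# End-to-end playback sketch (hedged: `soundfile` is an optional extra and is
# not imported or required by this module):
#   import soundfile as sf
#   engine = VibeVoiceTTS()
#   audio = engine.text_to_speech("Speaker 1: Hello!", speaker_names=["Andrew"])
#   sf.write("out.wav", audio, engine.get_sample_rate())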

# Global instance for easy access
_global_tts_engine = None


def get_tts_engine(**kwargs) -> VibeVoiceTTS:
    """
    Get or create the global TTS engine instance.

    Args:
        **kwargs: Arguments passed to the VibeVoiceTTS constructor (only used
            on the first call; later calls return the existing engine)

    Returns:
        VibeVoiceTTS: Global TTS engine instance
    """
    global _global_tts_engine
    if _global_tts_engine is None:
        _global_tts_engine = VibeVoiceTTS(**kwargs)
    return _global_tts_engine


# Convenience function for quick TTS
def text_to_speech(text: str, speaker_names: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> np.ndarray:
    """
    Quick text-to-speech conversion using the global engine.

    Args:
        text: Input text with speaker labels
        speaker_names: Speaker name(s) to use
        **kwargs: Additional arguments for the TTS engine or the text_to_speech method

    Returns:
        numpy.ndarray: Audio data
    """
    # Separate engine-construction kwargs from per-call TTS kwargs
    engine_kwargs = {k: v for k, v in kwargs.items()
                     if k in ['model_path', 'device', 'voices_dir', 'cfg_scale', 'ddpm_steps']}
    tts_kwargs = {k: v for k, v in kwargs.items() if k not in engine_kwargs}

    engine = get_tts_engine(**engine_kwargs)
    return engine.text_to_speech(text, speaker_names, **tts_kwargs)


# Example usage:
#
# Method 1: Create an instance
# tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
# audio_data = tts_engine.text_to_speech(
#     "Speaker 1: Hello world!\nSpeaker 2: How are you?",
#     speaker_names=["Andrew", "Ava"]
# )
#
# Method 2: Use the global instance
# audio_data = text_to_speech(
#     "Speaker 1: Hello world!",
#     speaker_names="Andrew",
#     verbose=True
# )
#
# Method 3: Global engine with custom config
# engine = get_tts_engine(device="cuda", cfg_scale=1.5)
# audio_data = engine.text_to_speech("Speaker 1: Hello!", ["Andrew"])