ai-voicebot/voicebot/bots/vibevoicetts.py

540 lines
22 KiB
Python

import os
import re
import traceback
from typing import Any, List, Tuple, Union, Optional
import time
import torch
import numpy as np
import sys
# Defer importing the external `vibevoice` package until it is actually needed.
# In some environments `vibevoice` is not installed into site-packages, but the
# repo contains a local copy under `voicebot/VibeVoice`; the helper below falls
# back to that copy when the package itself cannot be found.
def _import_vibevoice_symbols():
    """Lazily import and return the VibeVoice model and processor classes.

    Resolution order:

    1. Import ``vibevoice`` normally (site-packages / current ``sys.path``).
    2. If ``vibevoice`` itself (or one of its own submodules) cannot be found,
       or the import fails for any reason other than ``ModuleNotFoundError``,
       prepend the repository-local ``VibeVoice`` checkouts that exist on disk
       to ``sys.path`` and retry once.

    Returns:
        ``(VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor)``.

    Raises:
        ModuleNotFoundError: When a third-party dependency of VibeVoice (for
            example ``diffusers``) is missing -- with install instructions --
            or when ``vibevoice`` cannot be found even after the local fallback.
    """

    def _load():
        from vibevoice.modular.modeling_vibevoice_inference import (
            VibeVoiceForConditionalGenerationInference,
        )
        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

        return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor

    def _is_vibevoice_module(exc):
        # `exc.name` is the module the import system failed to locate.
        name = exc.name or ""
        return name == "vibevoice" or name.startswith("vibevoice.")

    def _missing_dependency(exc):
        # Prefer the structured attribute; fall back to parsing the message.
        text = str(exc)
        missing = exc.name or (text.split("'")[1] if "'" in text else text)
        return ModuleNotFoundError(
            f"Missing dependency when importing VibeVoice: {missing}.\n"
            "Install required packages inside the voicebot container.\n"
            "Example (inside container):\n"
            " PYTHONPATH=/shared:/voicebot uv run python3 -m pip install diffusers accelerate safetensors\n"
            "Or add the packages to the voicebot service environment / pyproject and rebuild.",
            name=exc.name,
        )

    try:
        return _load()
    except ModuleNotFoundError as exc:
        # A *dependency* of VibeVoice is missing: a local checkout won't help.
        if not _is_vibevoice_module(exc):
            raise _missing_dependency(exc) from exc
        # Otherwise `vibevoice` itself is absent -> try the local checkouts.
    except Exception:
        # Best-effort, as before: any other import-time failure also gets one
        # retry with the local checkouts on sys.path (the retry re-raises).
        pass

    base = os.path.dirname(__file__)  # voicebot/bots
    candidates = (
        os.path.abspath(os.path.join(base, "..", "VibeVoice")),
        os.path.abspath(os.path.join("/", "voicebot", "VibeVoice")),
    )
    for path in candidates:
        if os.path.isdir(path) and path not in sys.path:
            sys.path.insert(0, path)

    try:
        return _load()
    except ModuleNotFoundError as exc:
        if not _is_vibevoice_module(exc):
            raise _missing_dependency(exc) from exc
        raise ModuleNotFoundError(
            "Could not import 'vibevoice': it is not installed and no usable "
            f"local checkout was found in: {', '.join(candidates)}",
            name=exc.name,
        ) from exc
from shared.logger import logger
class VoiceMapper:
    """Maps speaker names to voice file paths.

    Every ``*.wav`` file directly inside ``voices_dir`` is registered under its
    file stem (e.g. ``en-Alice_woman``). A short alias is registered too: the
    stem is cut at the first ``_`` and then after the last ``-``
    (``en-Alice_woman`` -> ``Alice``). Aliases never replace a real file stem,
    and empty aliases (e.g. from ``_foo.wav``) are skipped because an empty key
    would partially match every speaker name in :meth:`get_voice_path`.
    """

    def __init__(self, voices_dir: Optional[str] = None):
        """Scan ``voices_dir`` (default: ``voices/`` next to this module)."""
        if voices_dir is None:
            voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        self.voices_dir = voices_dir
        self.setup_voice_presets()
        # Add short aliases after the real stems; `available_voices` keeps
        # listing only the real files.
        self.voice_presets.update(self._short_aliases(self.voice_presets))

    @staticmethod
    def _short_aliases(presets: dict) -> dict:
        """Return ``{alias: path}`` for stems whose short alias is new and non-empty."""
        aliases = {}
        for name, path in presets.items():
            alias = name
            if "_" in alias:
                alias = alias.split("_")[0]
            if "-" in alias:
                alias = alias.split("-")[-1]
            # Exact file stems stay authoritative; "" would match everything.
            if alias and alias not in presets:
                aliases[alias] = path
        return aliases

    def setup_voice_presets(self):
        """Populate ``voice_presets`` and ``available_voices`` from ``voices_dir``.

        Both map file stem -> ``os.path.join(voices_dir, filename)``, sorted by
        stem. If ``voices_dir`` is not a directory, both are left empty.
        """
        if not os.path.isdir(self.voices_dir):
            logger.info(f"Warning: Voices directory not found at {self.voices_dir}")
            self.voice_presets = {}
            self.available_voices = {}
            return

        presets = {}
        for filename in os.listdir(self.voices_dir):
            full_path = os.path.join(self.voices_dir, filename)
            if filename.lower().endswith(".wav") and os.path.isfile(full_path):
                presets[os.path.splitext(filename)[0]] = full_path

        # Alphabetical order gives a stable UI listing and a stable default voice.
        self.voice_presets = dict(sorted(presets.items()))
        # Redundant after the isfile() check above, but kept as a safety net.
        self.available_voices = {
            name: path for name, path in self.voice_presets.items() if os.path.exists(path)
        }
        logger.info(f"Found {len(self.available_voices)} voice files in {self.voices_dir}")
        if self.available_voices:
            logger.info(f"Available voices: {', '.join(self.available_voices.keys())}")

    def get_voice_path(self, speaker_name: str) -> str:
        """Resolve ``speaker_name`` to a voice file path.

        Tries an exact key match, then a case-insensitive substring match in
        either direction, then falls back to the first preset (the
        alphabetically first file stem).

        Raises:
            ValueError: If no voice files were found.
        """
        if speaker_name in self.voice_presets:
            return self.voice_presets[speaker_name]

        speaker_lower = speaker_name.lower()
        for preset_name, path in self.voice_presets.items():
            preset_lower = preset_name.lower()
            if preset_lower in speaker_lower or speaker_lower in preset_lower:
                return path

        if not self.voice_presets:
            raise ValueError("No voice files available")
        default_voice = next(iter(self.voice_presets.values()))
        logger.info(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}")
        return default_voice
class VibeVoiceTTS:
    """
    A reusable Text-to-Speech engine using VibeVoice model.

    Example usage:
        tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
        audio_data = tts_engine.text_to_speech(
            text="Speaker 1: Hello world!\\nSpeaker 2: How are you?",
            speaker_names=["Andrew", "Ava"]
        )
    """

    # Sample rate (Hz) of the audio returned by text_to_speech().
    SAMPLE_RATE = 24000

    # "Speaker N: text" (case-insensitive); N must be a non-negative integer.
    _SPEAKER_PATTERN = re.compile(r"^Speaker\s+(\d+):\s*(.*)$", re.IGNORECASE)

    def __init__(
        self,
        model_path: str = "microsoft/VibeVoice-1.5b",
        device: Optional[str] = None,
        voices_dir: Optional[str] = None,
        cfg_scale: float = 1.3,
        ddpm_steps: int = 10
    ):
        """
        Initialize the TTS engine with model and configuration.

        Args:
            model_path: HuggingFace model id or local model directory
            device: Device for inference ('xpu', 'cuda', 'mps', 'cpu'). Auto-detected if None
            voices_dir: Directory containing voice sample .wav files
            cfg_scale: CFG (Classifier-Free Guidance) scale for generation
            ddpm_steps: Number of DDPM inference steps
        """
        self.model_path = model_path
        self.cfg_scale = cfg_scale
        self.ddpm_steps = ddpm_steps
        self.device = self._resolve_device(device)
        logger.info(f"Using device: {self.device}")
        self.voice_mapper = VoiceMapper(voices_dir)
        self._load_model()

    @staticmethod
    def _resolve_device(device: Optional[str]) -> str:
        """Pick a device when none is given and sanitise the requested one.

        Auto-detection order is xpu, cuda, mps, cpu. The typo 'mpx' is treated
        as 'mps', and 'mps' falls back to 'cpu' when MPS is unavailable.
        """
        if device is None:
            # torch.xpu does not exist on older torch builds.
            xpu = getattr(torch, "xpu", None)
            if xpu is not None and xpu.is_available():
                device = "xpu"
            elif torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        if device.lower() == "mpx":
            logger.info("Note: device 'mpx' detected, treating it as 'mps'.")
            device = "mps"

        if device == "mps" and not torch.backends.mps.is_available():
            logger.info("Warning: MPS not available. Falling back to CPU.")
            device = "cpu"
        return device

    @staticmethod
    def _dtype_and_attention(device: str) -> Tuple[torch.dtype, str]:
        """Return ``(torch_dtype, attn_implementation)`` used to load the model on ``device``."""
        if device == "mps":
            return torch.float32, "sdpa"  # MPS requires float32; no flash_attention_2
        if device == "cuda":
            return torch.bfloat16, "flash_attention_2"
        if device == "xpu":
            return torch.bfloat16, "sdpa"  # flash_attention_2 not supported on XPU
        return torch.float32, "sdpa"  # cpu (and anything unrecognised)

    def _load_model(self):
        """Load the processor and model for ``self.device``.

        Sets ``self.processor`` and ``self.model``. On CUDA the model is first
        loaded with ``flash_attention_2``; if that raises, the error is logged
        and loading is retried once with ``sdpa``. Load failures on other
        devices are re-raised.
        """
        logger.info(f"Loading processor & model from {self.model_path}")
        # Lazy import: importing this module must not require vibevoice.
        VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor = _import_vibevoice_symbols()
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)

        load_dtype, attn_impl_primary = self._dtype_and_attention(self.device)
        logger.info(f"Using torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

        try:
            if self.device == "mps":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map=None,  # load then move
                )
                self.model.to("mps")
            elif self.device == "cuda":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cuda",
                    attn_implementation=attn_impl_primary,
                )
            elif self.device == "xpu":
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    attn_implementation=attn_impl_primary,
                    device_map={"": self.device},  # place every module on the XPU
                    low_cpu_mem_usage=True,
                )
                try:
                    import intel_extension_for_pytorch as ipex
                    logger.info("Applying IPEX optimizations")
                    self.model = ipex.optimize(self.model, dtype=torch.bfloat16, inplace=True)
                except ImportError:
                    logger.info("intel_extension_for_pytorch not found, proceeding without IPEX optimizations.")
            else:  # cpu
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=load_dtype,
                    device_map="cpu",
                    attn_implementation=attn_impl_primary,
                )
        except Exception as e:
            # Only the flash_attention_2 path has an SDPA fallback.
            if attn_impl_primary != "flash_attention_2":
                raise
            logger.info(f"[ERROR] : {type(e).__name__}: {e}")
            logger.info(traceback.format_exc())
            logger.info("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=load_dtype,
                device_map=(self.device if self.device in ("cuda", "cpu", "xpu") else None),
                attn_implementation="sdpa",
            )
            if self.device == "mps":
                self.model.to("mps")

        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)
        if hasattr(self.model.model, "language_model"):
            logger.info(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")
        logger.info("Model loaded successfully!")

    def _parse_script(self, text: str) -> Tuple[List[str], List[str]]:
        """
        Parse script text and extract speakers and their text.

        Supports lines like "Speaker 1: text" (case-insensitive label).
        Unlabelled lines are appended to the current speaker's text; blank
        lines are skipped; text before the first label is dropped.

        Returns: (scripts, speaker_numbers), e.g.
            (["Speaker 1: Hello world"], ["1"])
        """
        scripts: List[str] = []
        speaker_numbers: List[str] = []
        current_speaker: Optional[str] = None
        current_text = ""

        def flush():
            if current_speaker and current_text:
                scripts.append(f"Speaker {current_speaker}: {current_text.strip()}")
                speaker_numbers.append(current_speaker)

        for raw_line in text.strip().split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            match = self._SPEAKER_PATTERN.match(line)
            if match:
                flush()
                current_speaker = match.group(1).strip()
                current_text = match.group(2).strip()
            elif current_text:
                current_text += " " + line
            else:
                current_text = line

        flush()  # the last speaker
        return scripts, speaker_numbers

    @staticmethod
    def _build_full_script(scripts: List[str]) -> str:
        """Join speaker segments into one newline-separated script.

        Typographic apostrophes (U+2019) are normalised to ASCII ``'``.
        """
        return "\n".join(scripts).replace("\u2019", "'")

    def text_to_speech(
        self,
        text: str,
        speaker_names: Optional[Union[str, List[str]]] = None,
        cfg_scale: Optional[float] = None,
        verbose: bool = False
    ) -> np.ndarray:
        """
        Convert text to speech and return audio data.

        Args:
            text: Input text with speaker labels (e.g., "Speaker 1: Hello\\nSpeaker 2: Hi there")
            speaker_names: Speaker name(s) mapped, in order, to "Speaker 1", "Speaker 2", ...
                Can be a single string or a list. Defaults to ["Andrew"].
            cfg_scale: Override default CFG scale for this generation
            verbose: Log detailed generation info

        Returns:
            numpy.ndarray: float32 audio samples at SAMPLE_RATE (24 kHz), squeezed

        Raises:
            ValueError: If no "Speaker N:" segments are found in ``text``.
            RuntimeError: If the model produced no audio.
        """
        if cfg_scale is None:
            cfg_scale = self.cfg_scale

        scripts, speaker_numbers = self._parse_script(text)
        if not scripts:
            raise ValueError("No valid speaker scripts found in the input text")

        if verbose:
            logger.info(f"Found {len(scripts)} speaker segments:")
            for i, (script, speaker_num) in enumerate(zip(scripts, speaker_numbers)):
                logger.info(f" {i+1}. Speaker {speaker_num}")
                logger.info(f" Text preview: {script[:100]}...")

        if speaker_names is None:
            speaker_names = ["Andrew"]  # Default speaker
        elif isinstance(speaker_names, str):
            speaker_names = [speaker_names]

        # "1" -> first name, "2" -> second name, ...
        speaker_name_mapping = {str(i): name for i, name in enumerate(speaker_names, 1)}
        # Unique speaker numbers in order of first appearance.
        unique_speaker_numbers = list(dict.fromkeys(speaker_numbers))

        if verbose:
            logger.info("\nSpeaker mapping:")
            for speaker_num in unique_speaker_numbers:
                mapped_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
                logger.info(f" Speaker {speaker_num} -> {mapped_name}")

        voice_samples = []
        for speaker_num in unique_speaker_numbers:
            speaker_name = speaker_name_mapping.get(speaker_num, f"Speaker {speaker_num}")
            voice_path = self.voice_mapper.get_voice_path(speaker_name)
            voice_samples.append(voice_path)
            if verbose:
                logger.info(f"Speaker {speaker_num} ('{speaker_name}') -> Voice: {os.path.basename(voice_path)}")

        full_script = self._build_full_script(scripts)

        inputs = self.processor(
            text=[full_script],  # batch of one
            voice_samples=[voice_samples],  # batch of one
            padding=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to(self.device, non_blocking=True)

        if verbose:
            logger.info(f"Starting generation with cfg_scale: {cfg_scale}")

        start_time = time.perf_counter()
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg_scale,
            tokenizer=self.processor.tokenizer,
            generation_config={'do_sample': False},
            verbose=verbose,
        )
        generation_time = time.perf_counter() - start_time
        if verbose:
            logger.info(f"Generation time: {generation_time:.2f} seconds")

        speech = outputs.speech_outputs[0] if outputs.speech_outputs else None
        if speech is None:
            raise RuntimeError("No audio output generated")

        audio_samples = speech.shape[-1] if speech.ndim > 0 else speech.numel()
        audio_duration = audio_samples / self.SAMPLE_RATE
        rtf = generation_time / audio_duration if audio_duration > 0 else float('inf')
        logger.info(f"Generated audio duration: {audio_duration:.2f} seconds")
        logger.info(f"RTF (Real Time Factor): {rtf:.2f}x")

        input_tokens = inputs['input_ids'].shape[1]
        output_tokens = outputs.sequences.shape[1]
        logger.info(f"Prefilling tokens: {input_tokens}")
        logger.info(f"Generated tokens: {output_tokens - input_tokens}")
        logger.info(f"Total tokens: {output_tokens}")

        # Cast on-device (numpy has no bfloat16), then do a *blocking* copy to
        # host: a non_blocking device->host copy may not have finished when
        # .numpy() reads the buffer.
        audio_data = speech.detach().float().cpu().numpy()
        return audio_data.squeeze()

    def get_available_voices(self) -> List[str]:
        """Get list of available voice names (real file stems, no aliases)."""
        return list(self.voice_mapper.available_voices.keys())

    def get_sample_rate(self) -> int:
        """Get the sample rate of generated audio."""
        return self.SAMPLE_RATE
# Global instance for easy access
# Process-wide engine, created lazily by get_tts_engine() and then reused.
_global_tts_engine = None


def get_tts_engine(**kwargs) -> VibeVoiceTTS:
    """Return the shared VibeVoiceTTS engine, building it on first use.

    Args:
        **kwargs: Forwarded to the VibeVoiceTTS constructor. They only matter
            while no engine exists yet; once it has been built, later calls
            return it unchanged and ignore these arguments.

    Returns:
        VibeVoiceTTS: The module-level engine instance.
    """
    global _global_tts_engine
    existing = _global_tts_engine
    if existing is not None:
        return existing
    # If construction raises, the global is left as None so the next call retries.
    _global_tts_engine = existing = VibeVoiceTTS(**kwargs)
    return existing
# Convenience function for quick TTS
# Keyword arguments understood by the VibeVoiceTTS constructor.
_ENGINE_KWARG_NAMES = frozenset({"model_path", "device", "voices_dir", "cfg_scale", "ddpm_steps"})
# Constructor arguments that VibeVoiceTTS.text_to_speech also accepts per call.
_PER_CALL_ENGINE_KWARGS = frozenset({"cfg_scale"})


def _split_tts_kwargs(kwargs: dict) -> Tuple[dict, dict]:
    """Split convenience kwargs into ``(engine_kwargs, call_kwargs)``.

    ``cfg_scale`` is routed to both: the engine is built once and reused, so
    sending it only to the constructor would silently ignore it on every call
    after the first.
    """
    engine_kwargs = {k: v for k, v in kwargs.items() if k in _ENGINE_KWARG_NAMES}
    call_kwargs = {
        k: v
        for k, v in kwargs.items()
        if k not in _ENGINE_KWARG_NAMES or k in _PER_CALL_ENGINE_KWARGS
    }
    return engine_kwargs, call_kwargs


def text_to_speech(text: str, speaker_names: Optional[Union[str, List[str]]] = None, **kwargs: Any) -> np.ndarray:
    """
    Quick text-to-speech conversion using the global engine.

    Args:
        text: Input text with speaker labels
        speaker_names: Speaker name(s) to use
        **kwargs: Engine constructor arguments (model_path, device, voices_dir,
            ddpm_steps -- only used when the global engine is first created)
            and per-call arguments for VibeVoiceTTS.text_to_speech (e.g.
            verbose). cfg_scale is applied to every call.

    Returns:
        numpy.ndarray: Audio data
    """
    engine_kwargs, call_kwargs = _split_tts_kwargs(kwargs)
    engine = get_tts_engine(**engine_kwargs)
    return engine.text_to_speech(text, speaker_names, **call_kwargs)
# Example usage:
# Method 1: Create instance
# tts_engine = VibeVoiceTTS(model_path="microsoft/VibeVoice-1.5b")
# audio_data = tts_engine.text_to_speech(
# "Speaker 1: Hello world!\nSpeaker 2: How are you?",
# speaker_names=["Andrew", "Ava"]
# )
# # Method 2: Use global instance
# audio_data = text_to_speech(
# "Speaker 1: Hello world!",
# speaker_names="Andrew",
# verbose=True
# )
# # Method 3: Global engine with custom config
# engine = get_tts_engine(device="cuda", cfg_scale=1.5)
# audio_data = engine.text_to_speech("Speaker 1: Hello!", ["Andrew"])