Move vibevoice to GPU

commit 30886d9fa8 (parent 781748a182)

File diff suppressed because it is too large.
@@ -5,9 +5,47 @@ from typing import Any, List, Tuple, Union, Optional
 import time
 import torch
 import numpy as np
+import sys
 
-from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
-from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
+# Defer importing the external `vibevoice` package until we actually need it.
+# In some environments the `vibevoice` package isn't installed into site-packages
+# but the repo contains a local copy under `voicebot/VibeVoice`. Attempt a lazy
+# import and, if that fails, add the local path(s) to sys.path and retry.
+
+def _import_vibevoice_symbols():
+    try:
+        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
+        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
+        return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor
+    except Exception:
+        # If a required package (like `diffusers`) is missing inside the
+        # container or venv, importing deeper VibeVoice modules will raise
+        # ModuleNotFoundError. Detect that and raise a clearer error that
+        # includes install instructions.
+        import traceback as _tb
+        exc_type, exc_val, exc_tb = _tb.sys.exc_info()
+        if isinstance(exc_val, ModuleNotFoundError):
+            missing = str(exc_val).split("'")[1] if "'" in str(exc_val) else str(exc_val)
+            raise ModuleNotFoundError(
+                f"Missing dependency when importing VibeVoice: {missing}.\n"
+                "Install required packages inside the voicebot container.\n"
+                "Example (inside container):\n"
+                "  PYTHONPATH=/shared:/voicebot uv run python3 -m pip install diffusers accelerate safetensors\n"
+                "Or add the packages to the voicebot service environment / pyproject and rebuild."
+            ) from exc_val
+        # Try adding likely repository-local paths where VibeVoice lives
+        base = os.path.dirname(__file__)  # voicebot/bots
+        candidates = [
+            os.path.abspath(os.path.join(base, "..", "VibeVoice")),
+            os.path.abspath(os.path.join("/", "voicebot", "VibeVoice")),
+        ]
+        for p in candidates:
+            if os.path.isdir(p) and p not in sys.path:
+                sys.path.insert(0, p)
+        # Retry import
+        from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
+        from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
+        return VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor
 
 from shared.logger import logger
 
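The fallback pattern this helper implements (try the normal import, extend sys.path with repo-local candidates, retry once) can be exercised on its own. A minimal sketch, with a hypothetical helper name and candidate path rather than code from this commit:

    import importlib
    import os
    import sys

    def import_with_local_fallback(module_name, candidate_dirs):
        """Import module_name; on failure, add repo-local dirs to sys.path and retry once."""
        try:
            return importlib.import_module(module_name)
        except ModuleNotFoundError:
            for p in candidate_dirs:
                if os.path.isdir(p) and p not in sys.path:
                    sys.path.insert(0, p)  # make the local checkout importable
            return importlib.import_module(module_name)  # retry; raises if still missing

    # e.g. import_with_local_fallback("vibevoice", ["/voicebot/VibeVoice"])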
@@ -155,6 +193,8 @@ class VibeVoiceTTS:
     def _load_model(self):
         """Load the model and processor with device-specific configuration."""
         logger.info(f"Loading processor & model from {self.model_path}")
+        # Ensure external vibevoice symbols are available (lazy import)
+        VibeVoiceForConditionalGenerationInference, VibeVoiceProcessor = _import_vibevoice_symbols()
         self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
 
         # Decide dtype & attention implementation
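The hunk ends just before the module's `# Decide dtype & attention implementation` block, which this diff does not show. A sketch of the kind of device/dtype selection implied by the commit's switch to XPU wheels; the function name and dtype choices here are assumptions, not the module's actual code:

    import torch

    def pick_device_and_dtype():
        # Prefer Intel XPU (matching the +xpu wheels added below), then CUDA, then CPU.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu", torch.bfloat16
        if torch.cuda.is_available():
            return "cuda", torch.float16
        return "cpu", torch.float32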
@@ -401,15 +441,16 @@ class VibeVoiceTTS:
         logger.info(f"Generated tokens: {generated_tokens}")
         logger.info(f"Total tokens: {output_tokens}")
 
+        # Return audio data as numpy array
         # Return audio data as numpy array
         if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
             audio_tensor = outputs.speech_outputs[0]
 
-            # Convert to numpy array on CPU
+            # Convert to numpy array on CPU, ensuring compatible dtype
             if hasattr(audio_tensor, 'cpu'):
-                audio_data = audio_tensor.cpu().numpy()
+                audio_data = audio_tensor.cpu().float().numpy()  # Convert to float32 first
             else:
-                audio_data = np.array(audio_tensor)
+                audio_data = np.array(audio_tensor, dtype=np.float32)
 
             # Ensure it's a 1D array
             if audio_data.ndim > 1:
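The added `.float()` call matters because NumPy has no bfloat16 type: calling `.numpy()` directly on a bfloat16 (or other unsupported-dtype) tensor raises a TypeError. A small standalone illustration of the conversion the diff adopts:

    import numpy as np
    import torch

    t = torch.randn(4, dtype=torch.bfloat16)  # typical model output dtype on GPU/XPU
    # t.numpy() would raise TypeError here: NumPy does not support bfloat16.
    audio = t.cpu().float().numpy()           # upcast to float32 first, as in the diff
    assert audio.dtype == np.float32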
@@ -1,3 +1,4 @@
+--extra-index-url https://download.pytorch.org/whl/xpu
 about-time==4.2.1
 aiofiles==24.1.0
 aiohappyeyeballs==2.6.1
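With the extra index line in place, a plain `pip install -r requirements.txt` can resolve the `+xpu` wheels pinned below; it is roughly equivalent to passing the index on the command line (illustrative):

    pip install --extra-index-url https://download.pytorch.org/whl/xpu torch==2.8.0+xpu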
@@ -74,20 +75,6 @@ ninja==1.13.0
 nncf==2.18.0
 numba==0.61.2
 numpy==2.2.6
-nvidia-cublas-cu12==12.8.4.1
-nvidia-cuda-cupti-cu12==12.8.90
-nvidia-cuda-nvrtc-cu12==12.8.93
-nvidia-cuda-runtime-cu12==12.8.90
-nvidia-cudnn-cu12==9.10.2.21
-nvidia-cufft-cu12==11.3.3.83
-nvidia-cufile-cu12==1.13.1.3
-nvidia-curand-cu12==10.3.9.90
-nvidia-cusolver-cu12==11.7.3.90
-nvidia-cusparse-cu12==12.5.8.93
-nvidia-cusparselt-cu12==0.7.1
-nvidia-nccl-cu12==2.27.3
-nvidia-nvjitlink-cu12==12.8.93
-nvidia-nvtx-cu12==12.8.90
 onnx==1.19.0
 openai==1.107.2
 openai-whisper @ git+https://github.com/openai/whisper.git@c0d2f624c09dc18e709e37c2ad90c039a4eb72a2
@@ -157,11 +144,13 @@ threadpoolctl==3.6.0
 tiktoken==0.11.0
 tokenizers==0.21.4
 tomlkit==0.13.3
-torch==2.8.0
-torchvision==0.23.0
 tqdm==4.67.1
+torch==2.8.0+xpu
+torchvision==0.23.0+xpu
+torchaudio==2.8.0+xpu
 transformers==4.53.3
-triton==3.4.0
+diffusers
+accelerate
 typer==0.17.4
 typing-extensions==4.15.0
 typing-inspection==0.4.1
@@ -172,4 +161,4 @@ watchdog==6.0.0
 websockets==15.0.1
 wrapt==1.17.3
 xxhash==3.5.0
 yarl==1.20.1