81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
from typing import Any
|
|
import librosa
|
|
import numpy as np
|
|
from logger import logger
|
|
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
|
|
|
# Checkpoint identifiers that can be passed to `from_pretrained`, keyed by
# model family. The loading code further down uses model_ids["Distil-Whisper"][0].
model_ids: dict[str, list[str]] = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en",
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ],
}
|
|
|
|
# Pick the checkpoint family, log its candidate IDs, and use the first one.
model_type = model_ids["Distil-Whisper"]
logger.info(model_type)
model_id = model_type[0]

# The processor prepares audio inputs for the model and later decodes the
# generated token IDs back into text (see extract_input_features / batch_decode).
processor: Any = AutoProcessor.from_pretrained(model_id)  # type: ignore

# Load the seq2seq model and switch it to evaluation mode; nn.Module.eval()
# returns the module itself, so pt_model is the loaded model object.
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).eval()  # type: ignore
|
|
|
|
|
|
def extract_input_features(audio_array: np.ndarray, sampling_rate: int):
|
|
"""Extract input features from audio array and sampling rate."""
|
|
input_features = processor(
|
|
audio_array,
|
|
sampling_rate=sampling_rate,
|
|
return_tensors="pt",
|
|
).input_features
|
|
return input_features
|
|
|
|
|
|
def load_audio_file(file_path: str) -> tuple[np.ndarray, int]:
    """Read an audio file from disk at its native sample rate.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A ``(samples, sample_rate)`` pair, with the sample rate cast to ``int``.

    Raises:
        Whatever ``librosa.load`` raises (for example for a missing or
        unreadable file); the error is logged before being re-raised unchanged.
    """
    try:
        # sr=None asks librosa not to resample: the file's own rate is returned.
        samples, native_rate = librosa.load(file_path, sr=None)
        seconds = len(samples) / native_rate
        logger.info(
            f"Loaded audio file: {file_path}, duration: {seconds:.2f}s, "
            f"sample rate: {native_rate}Hz"
        )
        return samples, int(native_rate)
    except Exception as err:
        logger.error(f"Error loading audio file {file_path}: {err}")
        raise
|
|
|
|
|
|
# Example usage - replace with your audio file path
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"

# Transcribe a single local file: load -> extract features -> generate -> decode.
# A missing file produces a hint to edit audio_file_path; any other failure is
# logged and printed. Neither handler re-raises the error.
try:
    audio_array, sampling_rate = load_audio_file(audio_file_path)
    input_features = extract_input_features(audio_array, sampling_rate)

    predicted_ids = pt_model.generate(input_features)  # type: ignore
    transcription = processor.batch_decode(  # type: ignore
        predicted_ids,
        skip_special_tokens=True,
    )

    # One input file -> one decoded string at index 0.
    print(f"Audio file: {audio_file_path}")
    print(f"Transcription: {transcription[0]}")
except FileNotFoundError:
    logger.error(f"Audio file not found: {audio_file_path}")
    print("Please update the audio_file_path variable with a valid path to your wav file")
except Exception as exc:
    logger.error(f"Error processing audio: {exc}")
    print(f"Error: {exc}")