"""Transcribe a single audio file with a (Distil-)Whisper checkpoint.

Importing this module loads the processor and model for ``model_id`` and then
runs the example transcription at the bottom of the file.
"""

from typing import Any

import librosa
import numpy as np
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

from logger import logger

# Hugging Face model identifiers, grouped by model family.
model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en",
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ],
}

# Select the family, then the first checkpoint listed for it.
model_type = model_ids["Distil-Whisper"]
logger.info(model_type)
model_id = model_type[0]

processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
pt_model.eval()  # type: ignore


def _expected_sampling_rate() -> int:
    """Return the sampling rate the loaded processor expects for raw audio."""
    feature_extractor = getattr(processor, "feature_extractor", None)
    # NOTE(review): 16000 is the documented Whisper rate; it is only used as a
    # fallback if the processor does not expose `feature_extractor.sampling_rate`.
    return int(getattr(feature_extractor, "sampling_rate", 16000))


def extract_input_features(audio_array: np.ndarray, sampling_rate: int):
    """Extract model input features from an audio array.

    Args:
        audio_array: Audio samples, as returned by ``load_audio_file``.
        sampling_rate: Sampling rate of ``audio_array`` in Hz.

    Returns:
        The ``input_features`` produced by the processor (``return_tensors="pt"``).
    """
    target_rate = _expected_sampling_rate()
    if sampling_rate != target_rate:
        # The Whisper feature extractor rejects audio whose declared rate
        # differs from the rate it was trained on, and `load_audio_file`
        # preserves the file's native rate, so resample here.
        logger.info(f"Resampling audio from {sampling_rate}Hz to {target_rate}Hz")
        audio_array = librosa.resample(
            audio_array,
            orig_sr=sampling_rate,
            target_sr=target_rate,
        )
        sampling_rate = target_rate

    input_features = processor(
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    ).input_features
    return input_features


def load_audio_file(file_path: str) -> tuple[np.ndarray, int]:
    """Load an audio file from disk.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A ``(audio_array, sampling_rate)`` tuple. The file's native sampling
        rate is preserved (``sr=None``) and returned as an ``int``.

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
    """
    try:
        # sr=None keeps the file's own sampling rate instead of librosa's default.
        audio_array, sampling_rate = librosa.load(file_path, sr=None)
        logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/sampling_rate:.2f}s, sample rate: {sampling_rate}Hz")
        return audio_array, int(sampling_rate)  # Ensure sampling_rate is int
    except Exception as e:
        logger.error(f"Error loading audio file {file_path}: {e}")
        raise


# Example usage - replace with your audio file path
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"

# Load audio from file instead of dataset
try:
    audio_array, sampling_rate = load_audio_file(audio_file_path)
    input_features = extract_input_features(audio_array, sampling_rate)

    predicted_ids = pt_model.generate(input_features)  # type: ignore
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore

    print(f"Audio file: {audio_file_path}")
    print(f"Transcription: {transcription[0]}")
except FileNotFoundError:
    logger.error(f"Audio file not found: {audio_file_path}")
    print("Please update the audio_file_path variable with a valid path to your wav file")
except Exception as e:
    logger.error(f"Error processing audio: {e}")
    print(f"Error: {e}")