"""Transcribe one audio file with a Whisper / Distil-Whisper checkpoint.

Importing this module loads the processor and model for the selected
checkpoint and then runs the example transcription at the bottom of the file.
"""
from typing import Any

import librosa
from logger import logger
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Selectable checkpoints, grouped by model family.
model_ids = {
    "Distil-Whisper": [
        "distil-whisper/distil-large-v2",
        "distil-whisper/distil-medium.en",
        "distil-whisper/distil-small.en"
    ],
    "Whisper": [
        "openai/whisper-large-v3",
        "openai/whisper-large-v2",
        "openai/whisper-large",
        "openai/whisper-medium",
        "openai/whisper-small",
        "openai/whisper-base",
        "openai/whisper-tiny",
        "openai/whisper-medium.en",
        "openai/whisper-small.en",
        "openai/whisper-base.en",
        "openai/whisper-tiny.en",
    ]
}

# Choose a family, log its candidate checkpoints, and take the first entry.
model_type = model_ids["Distil-Whisper"]
logger.info(model_type)
model_id = model_type[0]

# The processor and model are shared module-level objects; the helper below
# reads `processor` directly.
processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
pt_model.eval()  # type: ignore


def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
    """Run the module-level processor on raw audio and return its features.

    Args:
        audio_array: Audio samples handed to the processor unchanged.
        sampling_rate: Sampling rate of ``audio_array`` in Hz.

    Returns:
        The ``input_features`` attribute of the processor output, requested
        as PyTorch tensors (``return_tensors="pt"``).
    """
    # NOTE(review): the Whisper feature extractor is understood to pad or
    # truncate its input to a fixed 30-second window, so longer recordings
    # may only be partially transcribed -- confirm against the docs.
    return processor(
        audio_array,
        sampling_rate=sampling_rate,
        return_tensors="pt",
    ).input_features


def load_audio_file(file_path: str) -> tuple[Any, int]:
    """Read an audio file and return its samples at 16 kHz.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A ``(audio_array, sampling_rate)`` pair; the rate is always 16000.

    Raises:
        Exception: Any error raised while loading, resampling or logging is
            logged and then re-raised unchanged.
    """
    # Whisper checkpoints are fed 16 kHz audio.
    target_sample_rate = 16000

    try:
        # sr=None keeps the file's native rate so it can be reported below.
        audio_array, original_sampling_rate = librosa.load(file_path, sr=None)  # type: ignore
        logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz")  # type: ignore

        # Already at the target rate: nothing more to do.
        if original_sampling_rate == target_sample_rate:
            return audio_array, target_sample_rate  # type: ignore

        audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate)  # type: ignore
        logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz")
        return audio_array, target_sample_rate  # type: ignore
    except Exception as e:
        logger.error(f"Error loading audio file {file_path}: {e}")
        raise


# Example usage - point this at your own audio file.
audio_file_path = "/voicebot/F_0818_15y11m_1.wav"

# Pipeline: load from disk -> extract features -> generate token ids -> decode.
try:
    audio_array, sampling_rate = load_audio_file(audio_file_path)
    input_features = extract_input_features(audio_array, sampling_rate)

    predicted_ids = pt_model.generate(input_features)  # type: ignore
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore

    print(f"Audio file: {audio_file_path}")
    print(f"Transcription: {transcription[0]}")
except FileNotFoundError:
    # A missing file gets a targeted hint instead of the generic error report.
    logger.error(f"Audio file not found: {audio_file_path}")
    print("Please update the audio_file_path variable with a valid path to your wav file")
except Exception as e:
    logger.error(f"Error processing audio: {e}")
    print(f"Error: {e}")