113 lines
4.0 KiB
Python

"""Bots package whisper agent (bots/whisper)
Lightweight agent descriptor; heavy model loading must be done by a controller
when the agent is actually used.
"""
from typing import Dict, Any
import librosa
from logger import logger
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
from aiortc import MediaStreamTrack
# Static metadata describing this agent: a short identifier and a
# human-readable summary of what the agent does.
AGENT_NAME: str = "whisper"
AGENT_DESCRIPTION: str = "Speech recognition agent (Whisper) - processes incoming audio"
def agent_info() -> Dict[str, str]:
    """Return this agent's descriptor.

    Returns:
        A new dict with two keys: ``"name"`` (the value of ``AGENT_NAME``)
        and ``"description"`` (the value of ``AGENT_DESCRIPTION``).
    """
    descriptor = dict(name=AGENT_NAME, description=AGENT_DESCRIPTION)
    return descriptor
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
    """Return the local media tracks this agent provides for a session.

    Whisper is not a media source, so the result is always an empty mapping.

    Args:
        session_name: Name of the session; accepted but not used here.

    Returns:
        An empty dict.
    """
    no_tracks: dict[str, MediaStreamTrack] = {}
    return no_tracks
def do_work(audio_file_path: str = "/voicebot/F_0818_15y11m_1.wav") -> None:
    """Transcribe one audio file with a Distil-Whisper model and print the text.

    Loads the processor and model for the first entry of the "Distil-Whisper"
    list below, reads ``audio_file_path``, resamples it to 16 kHz when needed,
    runs generation and prints the first decoded transcription.

    Exceptions raised while loading or transcribing the audio are logged and
    printed, not re-raised. Exceptions raised while loading the processor or
    the model are not caught here and propagate to the caller.

    Args:
        audio_file_path: Path of the audio file to transcribe. Defaults to the
            sample path that was previously hard-coded in the body, so
            existing zero-argument calls behave as before.
    """
    # Known checkpoints grouped by family. Only the first "Distil-Whisper"
    # entry is actually loaded below; the rest are kept for reference.
    model_ids = {
        "Distil-Whisper": [
            "distil-whisper/distil-large-v2",
            "distil-whisper/distil-medium.en",
            "distil-whisper/distil-small.en",
        ],
        "Whisper": [
            "openai/whisper-large-v3",
            "openai/whisper-large-v2",
            "openai/whisper-large",
            "openai/whisper-medium",
            "openai/whisper-small",
            "openai/whisper-base",
            "openai/whisper-tiny",
            "openai/whisper-medium.en",
            "openai/whisper-small.en",
            "openai/whisper-base.en",
            "openai/whisper-tiny.en",
        ],
    }

    model_type = model_ids["Distil-Whisper"]
    logger.info(model_type)
    model_id = model_type[0]

    processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
    pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
    # Inference only: switch the model to evaluation mode.
    pt_model.eval()  # type: ignore

    def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
        """Run the processor on the audio and return its ``input_features``."""
        processor_output = processor(  # type: ignore
            audio_array,
            sampling_rate=sampling_rate,
            return_tensors="pt",
        )
        input_features: Any = processor_output.input_features  # type: ignore
        return input_features  # type: ignore

    def load_audio_file(file_path: str) -> tuple[Any, int]:
        """Load an audio file and return ``(audio_array, sampling_rate)``.

        The returned sampling rate is always 16000; the audio is resampled
        when the file's native rate differs. Any exception is logged and
        re-raised unchanged.
        """
        # Whisper models expect 16kHz sample rate
        target_sample_rate = 16000

        try:
            # sr=None keeps the file's native rate so it can be logged before
            # any resampling happens.
            audio_array, original_sampling_rate = librosa.load(file_path, sr=None)  # type: ignore
            logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array)/original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz")  # type: ignore

            # Resample if necessary
            if original_sampling_rate != target_sample_rate:
                audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate)  # type: ignore
                logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz")

            return audio_array, target_sample_rate  # type: ignore
        except Exception as e:
            logger.error(f"Error loading audio file {file_path}: {e}")
            raise

    # Load audio from file, transcribe it and report the result. Errors are
    # reported here instead of propagating so a bad path does not crash the
    # caller.
    try:
        audio_array, sampling_rate = load_audio_file(audio_file_path)
        input_features = extract_input_features(audio_array, sampling_rate)

        predicted_ids = pt_model.generate(input_features)  # type: ignore
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore

        print(f"Audio file: {audio_file_path}")
        print(f"Transcription: {transcription[0]}")
    except FileNotFoundError:
        logger.error(f"Audio file not found: {audio_file_path}")
        print("Please update the audio_file_path variable with a valid path to your wav file")
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        print(f"Error: {e}")