Features added:
- WebSocket chat message handling in WebRTC signaling client
- Bot chat handler discovery and automatic setup
- Chat message sending/receiving capabilities
- Example chatbot with conversation features
- Enhanced whisper bot with chat commands
- Comprehensive error handling and logging
- Full integration with existing WebRTC infrastructure

Bots can now:
- Receive chat messages from lobby participants
- Send responses back through WebSocket
- Process commands and keywords
- Integrate seamlessly with voice/video functionality

Files modified:
- voicebot/webrtc_signaling.py: Added chat message handling
- voicebot/bot_orchestrator.py: Enhanced bot discovery for chat
- voicebot/bots/whisper.py: Added chat command processing
- voicebot/bots/chatbot.py: New conversational bot
- voicebot/bots/__init__.py: Added chatbot module
- CHAT_INTEGRATION.md: Comprehensive documentation
- README.md: Updated with chat functionality info
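
Chat handler discovery is convention-based: a bot module exposes an async
handle_chat_message coroutine and the orchestrator wires it to the lobby. A
minimal sketch of what that discovery could look like (discover_chat_handlers
and the module list are illustrative assumptions, not the actual
bot_orchestrator.py API):

    import importlib
    from typing import Any, Dict

    def discover_chat_handlers(bot_module_names: list[str]) -> Dict[str, Any]:
        """Collect bot modules that expose an async handle_chat_message coroutine."""
        handlers: Dict[str, Any] = {}
        for name in bot_module_names:
            module = importlib.import_module(f"voicebot.bots.{name}")
            handler = getattr(module, "handle_chat_message", None)
            if handler is not None:
                handlers[name] = handler
        return handlers

    # e.g. discover_chat_handlers(["whisper", "chatbot"])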
"""Bots package whisper agent (bots/whisper)
|
|
|
|
Lightweight agent descriptor; heavy model loading must be done by a controller
|
|
when the agent is actually used.
|
|
"""
|
|
|
|

from typing import Dict, Any, Optional, Callable, Awaitable

import librosa
from aiortc import MediaStreamTrack
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

from logger import logger

# Import shared models for chat functionality
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from shared.models import ChatMessageModel
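
# The module docstring above defers heavy model loading to a controller. A
# minimal sketch of that lazy-loading pattern (_MODEL_CACHE and ensure_model
# are illustrative assumptions, not the controller's actual API):
_MODEL_CACHE: Dict[str, Any] = {}


def ensure_model(model_id: str) -> Any:
    """Load and cache a Whisper checkpoint on first use rather than at import time."""
    if model_id not in _MODEL_CACHE:
        _MODEL_CACHE[model_id] = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)
    return _MODEL_CACHE[model_id]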
AGENT_NAME = "whisper"
|
|
AGENT_DESCRIPTION = "Speech recognition agent (Whisper) - processes incoming audio"
|
|
|
|
|
|
def agent_info() -> Dict[str, str]:
|
|
return {"name": AGENT_NAME, "description": AGENT_DESCRIPTION}
|
|
|
|
|
|
def create_agent_tracks(session_name: str) -> dict[str, MediaStreamTrack]:
|
|
"""Whisper is not a media source - return no local tracks."""
|
|
return {}
|
|
|
|


async def handle_chat_message(chat_message: ChatMessageModel, send_message_func: Callable[[str], Awaitable[None]]) -> Optional[str]:
    """Handle an incoming chat message and optionally return a response.

    Args:
        chat_message: The received chat message
        send_message_func: Function to send messages back to the lobby

    Returns:
        Optional response message to send back to the lobby
    """
    logger.info(f"Whisper bot received chat message from {chat_message.sender_name}: {chat_message.message}")

    # Simple command handling for demonstration: only respond to messages
    # addressed to the bot with a "whisper:" prefix
    if chat_message.message.lower().startswith("whisper:"):
        command = chat_message.message[len("whisper:"):].strip()  # Remove the "whisper:" prefix
        if command.lower() == "hello":
            return f"Hello {chat_message.sender_name}! I'm the Whisper speech recognition bot."
        elif command.lower() == "help":
            return "I can process speech and respond to simple commands. Try 'whisper: hello' or 'whisper: status'"
        elif command.lower() == "status":
            return "Whisper bot is running and ready to process audio and chat messages."
        else:
            return f"I heard you say: {command}. Try 'whisper: help' for available commands."

    # Don't respond to other messages
    return None
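

# send_message_func also supports unsolicited sends, not just the returned
# responses above. An illustrative sketch, not wired up by the current bot:
async def announce_ready(send_message_func: Callable[[str], Awaitable[None]]) -> None:
    """Illustrative helper: proactively push a message to the lobby."""
    await send_message_func("Whisper bot is online - say 'whisper: help' for commands.")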


def do_work():
    model_ids = {
        "Distil-Whisper": [
            "distil-whisper/distil-large-v2",
            "distil-whisper/distil-medium.en",
            "distil-whisper/distil-small.en",
        ],
        "Whisper": [
            "openai/whisper-large-v3",
            "openai/whisper-large-v2",
            "openai/whisper-large",
            "openai/whisper-medium",
            "openai/whisper-small",
            "openai/whisper-base",
            "openai/whisper-tiny",
            "openai/whisper-medium.en",
            "openai/whisper-small.en",
            "openai/whisper-base.en",
            "openai/whisper-tiny.en",
        ],
    }

    model_type = model_ids["Distil-Whisper"]
    logger.info(model_type)
    model_id = model_type[0]

    processor: Any = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
    pt_model: Any = AutoModelForSpeechSeq2Seq.from_pretrained(pretrained_model_name_or_path=model_id)  # type: ignore
    pt_model.eval()  # type: ignore

    def extract_input_features(audio_array: Any, sampling_rate: int) -> Any:
        """Extract input features from an audio array and sampling rate."""
        processor_output = processor(  # type: ignore
            audio_array,
            sampling_rate=sampling_rate,
            return_tensors="pt",
        )
        input_features: Any = processor_output.input_features  # type: ignore
        return input_features  # type: ignore

    def load_audio_file(file_path: str) -> tuple[Any, int]:
        """Load an audio file from disk and return the audio array and sampling rate."""
        # Whisper models expect a 16 kHz sample rate
        target_sample_rate = 16000

        try:
            # Load at the file's native rate, then resample to the target rate if needed
            audio_array, original_sampling_rate = librosa.load(file_path, sr=None)  # type: ignore
            logger.info(f"Loaded audio file: {file_path}, duration: {len(audio_array) / original_sampling_rate:.2f}s, original sample rate: {original_sampling_rate}Hz")  # type: ignore

            # Resample if necessary
            if original_sampling_rate != target_sample_rate:
                audio_array = librosa.resample(audio_array, orig_sr=original_sampling_rate, target_sr=target_sample_rate)  # type: ignore
                logger.info(f"Resampled audio from {original_sampling_rate}Hz to {target_sample_rate}Hz")

            return audio_array, target_sample_rate  # type: ignore
        except Exception as e:
            logger.error(f"Error loading audio file {file_path}: {e}")
            raise

    # Example usage - replace with your audio file path
    audio_file_path = "/voicebot/F_0818_15y11m_1.wav"

    # Load audio from a file instead of a dataset
    try:
        audio_array, sampling_rate = load_audio_file(audio_file_path)
        input_features = extract_input_features(audio_array, sampling_rate)

        predicted_ids = pt_model.generate(input_features)  # type: ignore
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)  # type: ignore

        print(f"Audio file: {audio_file_path}")
        print(f"Transcription: {transcription[0]}")

    except FileNotFoundError:
        logger.error(f"Audio file not found: {audio_file_path}")
        print("Please update the audio_file_path variable with a valid path to your wav file")
    except Exception as e:
        logger.error(f"Error processing audio: {e}")
        print(f"Error: {e}")