From cc9a7caa78ee6f36a6cbf358a4a2ca2f5df89c53 Mon Sep 17 00:00:00 2001
From: James Ketrenos
Date: Thu, 4 Sep 2025 14:34:22 -0700
Subject: [PATCH] Fix linting errors

---
 voicebot/bots/whisper.py | 64 +++++++++++++++++++++++++++-------------
 1 file changed, 44 insertions(+), 20 deletions(-)

diff --git a/voicebot/bots/whisper.py b/voicebot/bots/whisper.py
index 8b8a764..c42ae60 100644
--- a/voicebot/bots/whisper.py
+++ b/voicebot/bots/whisper.py
@@ -10,7 +10,9 @@ import time
 import threading
 from collections import deque
 from queue import Queue, Empty
-from typing import Dict, Optional, Callable, Awaitable, Deque, Any
+from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
+import numpy.typing as npt
+from pydantic import BaseModel
 
 # Core dependencies
 import librosa
@@ -28,6 +30,23 @@ from shared.models import ChatMessageModel
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
 
+# Type definitions
+AudioArray = npt.NDArray[np.float32]
+
+class AudioQueueItem(BaseModel):
+    """Audio data with timestamp for processing queue."""
+    audio: AudioArray
+    timestamp: float
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class TranscriptionHistoryItem(BaseModel):
+    """Transcription history item with metadata."""
+    message: str
+    timestamp: float
+    is_final: bool
+
 AGENT_NAME = "whisper"
 AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
 sample_rate = 16000  # Whisper expects 16kHz
@@ -80,23 +99,23 @@ class AudioProcessor:
         self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)
 
         # Audio buffering
-        self.audio_buffer: Deque[np.ndarray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
+        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
         self.phrase_timeout = 3.0  # seconds of silence before considering phrase complete
         self.last_activity_time = time.time()
 
         # Transcription state
-        self.current_phrase_audio = np.array([], dtype=np.float32)
-        self.transcription_history = []
+        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
+        self.transcription_history: list[TranscriptionHistoryItem] = []
 
         # Background processing
-        self.processing_queue = Queue()
+        self.processing_queue: Queue[AudioQueueItem] = Queue()
         self.is_running = True
         self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
         self.processor_thread.start()
 
         logger.info("AudioProcessor initialized for real-time transcription")
 
-    def add_audio_data(self, audio_data: np.ndarray):
+    def add_audio_data(self, audio_data: AudioArray):
         """Add new audio data to the processing buffer."""
         if not self.is_running:
             return
@@ -121,10 +140,11 @@
 
         # Add to processing queue
         try:
-            self.processing_queue.put_nowait({
-                'audio': combined_audio,
-                'timestamp': time.time()
-            })
+            queue_item = AudioQueueItem(
+                audio=combined_audio,
+                timestamp=time.time()
+            )
+            self.processing_queue.put_nowait(queue_item)
         except Exception:
             # Queue full, skip this chunk
             logger.debug("Audio processing queue full, dropping audio chunk")
@@ -141,8 +161,8 @@
             except Empty:
                 continue
 
-            audio_array = audio_data['audio']
-            chunk_timestamp = audio_data['timestamp']
+            audio_array = audio_data.audio
+            chunk_timestamp = audio_data.timestamp
 
             # Check if this is a new phrase (gap in audio)
             time_since_last = chunk_timestamp - self.last_activity_time
@@ -179,7 +199,7 @@
         except Exception as e:
             logger.error(f"Error in audio processing loop: {e}", exc_info=True)
 
-    async def _transcribe_and_send(self, audio_array: np.ndarray, is_final: bool):
+    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
         """Transcribe audio and send result as chat message."""
         global sample_rate
 
@@ -203,15 +223,16 @@
             message = f"{prefix}{text}"
 
             # Avoid sending duplicate messages
-            if is_final or message not in [h.get('message', '') for h in self.transcription_history[-3:]]:
+            if is_final or message not in [h.message for h in self.transcription_history[-3:]]:
                 await self.send_chat_func(message)
 
                 # Keep history for deduplication
-                self.transcription_history.append({
-                    'message': message,
-                    'timestamp': time.time(),
-                    'is_final': is_final
-                })
+                history_item = TranscriptionHistoryItem(
+                    message=message,
+                    timestamp=time.time(),
+                    is_final=is_final
+                )
+                self.transcription_history.append(history_item)
 
                 # Limit history size
                 if len(self.transcription_history) > 10:
@@ -265,9 +286,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack):
                 target_sr=sample_rate
             )
 
+            # Ensure audio_data is AudioArray (float32)
+            audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))
+
             # Send to audio processor
             if _audio_processor:
-                _audio_processor.add_audio_data(audio_data)
+                _audio_processor.add_audio_data(audio_data_float32)
 
         except Exception as e:
             logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)
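
Reviewer note (not part of the patch): the core of this change is replacing untyped dicts on the processing queue with a pydantic model. A minimal, self-contained sketch of that pattern, assuming the pydantic v1-style `class Config` used above; pydantic has no built-in validator for numpy arrays, so the model opts into isinstance-based validation:

```python
# Sketch of the dict-to-model change on the processing queue.
import time
from queue import Queue

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel

AudioArray = npt.NDArray[np.float32]


class AudioQueueItem(BaseModel):
    audio: AudioArray
    timestamp: float

    class Config:
        arbitrary_types_allowed = True  # ndarray has no pydantic-native validator


q: Queue[AudioQueueItem] = Queue(maxsize=100)
q.put_nowait(AudioQueueItem(audio=np.zeros(480, dtype=np.float32), timestamp=time.time()))

item = q.get_nowait()
assert item.audio.dtype == np.float32  # attribute access replaces item['audio'] lookups
```

Besides satisfying the linter, the model gives the consumer side (`_processing_loop`) attribute access that a type checker can verify, instead of string-keyed `dict` lookups that fail silently on a typo.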
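Reviewer note (not part of the patch): the `cast(AudioArray, ...)` in `handle_track_received` only informs the type checker; `astype(np.float32)` performs the runtime conversion, since resampled audio may arrive as float64 depending on the input dtype while `AudioArray` pins float32. A small sketch of the distinction:

```python
# cast() is a runtime no-op; astype() does the actual dtype conversion.
from typing import cast

import numpy as np
import numpy.typing as npt

AudioArray = npt.NDArray[np.float32]

samples = np.random.default_rng(0).standard_normal(480)  # float64 by default
converted = cast(AudioArray, samples.astype(np.float32))

assert converted.dtype == np.float32
assert cast(AudioArray, samples).dtype == np.float64  # cast alone changes nothing
```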