Fix linting errors

James Ketr 2025-09-04 14:34:22 -07:00
parent 6857dd66fa
commit cc9a7caa78


@@ -10,7 +10,9 @@ import time
 import threading
 from collections import deque
 from queue import Queue, Empty
-from typing import Dict, Optional, Callable, Awaitable, Deque, Any
+from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
+import numpy.typing as npt
+from pydantic import BaseModel
 
 # Core dependencies
 import librosa
@@ -28,6 +30,23 @@ from shared.models import ChatMessageModel
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
+# Type definitions
+AudioArray = npt.NDArray[np.float32]
+
+class AudioQueueItem(BaseModel):
+    """Audio data with timestamp for processing queue."""
+    audio: AudioArray
+    timestamp: float
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class TranscriptionHistoryItem(BaseModel):
+    """Transcription history item with metadata."""
+    message: str
+    timestamp: float
+    is_final: bool
+
 AGENT_NAME = "whisper"
 AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
 sample_rate = 16000 # Whisper expects 16kHz
@@ -80,23 +99,23 @@ class AudioProcessor:
         self.samples_per_frame = 480 # Common WebRTC frame size at 16kHz (30ms)
 
         # Audio buffering
-        self.audio_buffer: Deque[np.ndarray] = deque(maxlen=1000) # ~30 seconds at 30ms frames
+        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000) # ~30 seconds at 30ms frames
         self.phrase_timeout = 3.0 # seconds of silence before considering phrase complete
         self.last_activity_time = time.time()
 
         # Transcription state
-        self.current_phrase_audio = np.array([], dtype=np.float32)
-        self.transcription_history = []
+        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
+        self.transcription_history: list[TranscriptionHistoryItem] = []
 
         # Background processing
-        self.processing_queue = Queue()
+        self.processing_queue: Queue[AudioQueueItem] = Queue()
         self.is_running = True
         self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
         self.processor_thread.start()
 
         logger.info("AudioProcessor initialized for real-time transcription")
 
-    def add_audio_data(self, audio_data: np.ndarray):
+    def add_audio_data(self, audio_data: AudioArray):
         """Add new audio data to the processing buffer."""
         if not self.is_running:
             return
@@ -121,10 +140,11 @@ class AudioProcessor:
 
             # Add to processing queue
             try:
-                self.processing_queue.put_nowait({
-                    'audio': combined_audio,
-                    'timestamp': time.time()
-                })
+                queue_item = AudioQueueItem(
+                    audio=combined_audio,
+                    timestamp=time.time()
+                )
+                self.processing_queue.put_nowait(queue_item)
             except Exception:
                 # Queue full, skip this chunk
                 logger.debug("Audio processing queue full, dropping audio chunk")
@@ -141,8 +161,8 @@
                 except Empty:
                     continue
 
-                audio_array = audio_data['audio']
-                chunk_timestamp = audio_data['timestamp']
+                audio_array = audio_data.audio
+                chunk_timestamp = audio_data.timestamp
 
                 # Check if this is a new phrase (gap in audio)
                 time_since_last = chunk_timestamp - self.last_activity_time
@@ -179,7 +199,7 @@
             except Exception as e:
                 logger.error(f"Error in audio processing loop: {e}", exc_info=True)
 
-    async def _transcribe_and_send(self, audio_array: np.ndarray, is_final: bool):
+    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
         """Transcribe audio and send result as chat message."""
 
         global sample_rate
@@ -203,15 +223,16 @@
             message = f"{prefix}{text}"
 
             # Avoid sending duplicate messages
-            if is_final or message not in [h.get('message', '') for h in self.transcription_history[-3:]]:
+            if is_final or message not in [h.message for h in self.transcription_history[-3:]]:
                 await self.send_chat_func(message)
 
                 # Keep history for deduplication
-                self.transcription_history.append({
-                    'message': message,
-                    'timestamp': time.time(),
-                    'is_final': is_final
-                })
+                history_item = TranscriptionHistoryItem(
+                    message=message,
+                    timestamp=time.time(),
+                    is_final=is_final
+                )
+                self.transcription_history.append(history_item)
 
                 # Limit history size
                 if len(self.transcription_history) > 10:
@@ -265,9 +286,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack):
                 target_sr=sample_rate
             )
 
+            # Ensure audio_data is AudioArray (float32)
+            audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))
+
             # Send to audio processor
             if _audio_processor:
-                _audio_processor.add_audio_data(audio_data)
+                _audio_processor.add_audio_data(audio_data_float32)
 
     except Exception as e:
         logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)