Fix linting errors
parent 6857dd66fa
commit cc9a7caa78
@@ -10,7 +10,9 @@ import time
 import threading
 from collections import deque
 from queue import Queue, Empty
-from typing import Dict, Optional, Callable, Awaitable, Deque, Any
+from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
+import numpy.typing as npt
+from pydantic import BaseModel
 
 # Core dependencies
 import librosa
@@ -28,6 +30,23 @@ from shared.models import ChatMessageModel
 
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
+# Type definitions
+AudioArray = npt.NDArray[np.float32]
+
+class AudioQueueItem(BaseModel):
+    """Audio data with timestamp for processing queue."""
+    audio: AudioArray
+    timestamp: float
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class TranscriptionHistoryItem(BaseModel):
+    """Transcription history item with metadata."""
+    message: str
+    timestamp: float
+    is_final: bool
+
 AGENT_NAME = "whisper"
 AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
 sample_rate = 16000  # Whisper expects 16kHz
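Note: holding numpy arrays in a Pydantic model hinges on `arbitrary_types_allowed`; Pydantic then accepts the field via an `isinstance` check rather than validating the dtype, so the `float32` part of `AudioArray` is enforced by the type checker, not at runtime. A minimal sketch of how the new model behaves (names taken from the diff; exact validation behavior may differ across Pydantic versions):

    import numpy as np
    import numpy.typing as npt
    from pydantic import BaseModel

    AudioArray = npt.NDArray[np.float32]

    class AudioQueueItem(BaseModel):
        audio: AudioArray
        timestamp: float

        class Config:
            arbitrary_types_allowed = True  # numpy arrays are not native Pydantic types

    item = AudioQueueItem(audio=np.zeros(480, dtype=np.float32), timestamp=0.0)
    print(item.timestamp)  # 0.0; item.audio is the array, dtype unchecked at runtime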
@@ -80,23 +99,23 @@ class AudioProcessor:
         self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)
 
         # Audio buffering
-        self.audio_buffer: Deque[np.ndarray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
+        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
         self.phrase_timeout = 3.0  # seconds of silence before considering phrase complete
         self.last_activity_time = time.time()
 
         # Transcription state
-        self.current_phrase_audio = np.array([], dtype=np.float32)
-        self.transcription_history = []
+        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
+        self.transcription_history: list[TranscriptionHistoryItem] = []
 
         # Background processing
-        self.processing_queue = Queue()
+        self.processing_queue: Queue[AudioQueueItem] = Queue()
         self.is_running = True
         self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
         self.processor_thread.start()
 
         logger.info("AudioProcessor initialized for real-time transcription")
 
-    def add_audio_data(self, audio_data: np.ndarray):
+    def add_audio_data(self, audio_data: AudioArray):
         """Add new audio data to the processing buffer."""
         if not self.is_running:
             return
@@ -121,10 +140,11 @@ class AudioProcessor:
 
         # Add to processing queue
         try:
-            self.processing_queue.put_nowait({
-                'audio': combined_audio,
-                'timestamp': time.time()
-            })
+            queue_item = AudioQueueItem(
+                audio=combined_audio,
+                timestamp=time.time()
+            )
+            self.processing_queue.put_nowait(queue_item)
         except Exception:
             # Queue full, skip this chunk
             logger.debug("Audio processing queue full, dropping audio chunk")
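One detail the bare `except Exception:` papers over: `put_nowait` raises only `queue.Full`, and only on a bounded queue; the `Queue()` created in `__init__` has no `maxsize`, so the drop path here is purely defensive. A sketch of the failure mode it guards against, assuming a bound were added:

    from queue import Full, Queue

    q: "Queue[int]" = Queue(maxsize=1)
    q.put_nowait(1)
    try:
        q.put_nowait(2)  # bounded queue is full, raises queue.Full
    except Full:         # narrower and clearer than `except Exception`
        pass             # drop the item, as the audio path drops a chunk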
@@ -141,8 +161,8 @@ class AudioProcessor:
             except Empty:
                 continue
 
-            audio_array = audio_data['audio']
-            chunk_timestamp = audio_data['timestamp']
+            audio_array = audio_data.audio
+            chunk_timestamp = audio_data.timestamp
 
             # Check if this is a new phrase (gap in audio)
             time_since_last = chunk_timestamp - self.last_activity_time
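The payoff of the model over the dict shows up here: field access becomes attribute access, so a misspelled field name is a static type error instead of a silent `.get(...)` fallback or a runtime `KeyError`. A small illustration, reusing the `AudioQueueItem` model defined earlier in this commit:

    import time
    import numpy as np

    item = AudioQueueItem(
        audio=np.zeros(480, dtype=np.float32),
        timestamp=time.time(),
    )
    audio_array = item.audio          # field names checked by mypy/pyright
    chunk_timestamp = item.timestamp  # a typo like item.timestmap fails type checking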
@@ -179,7 +199,7 @@ class AudioProcessor:
             except Exception as e:
                 logger.error(f"Error in audio processing loop: {e}", exc_info=True)
 
-    async def _transcribe_and_send(self, audio_array: np.ndarray, is_final: bool):
+    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
         """Transcribe audio and send result as chat message."""
         global sample_rate
 
@@ -203,15 +223,16 @@ class AudioProcessor:
             message = f"{prefix}{text}"
 
             # Avoid sending duplicate messages
-            if is_final or message not in [h.get('message', '') for h in self.transcription_history[-3:]]:
+            if is_final or message not in [h.message for h in self.transcription_history[-3:]]:
                 await self.send_chat_func(message)
 
                 # Keep history for deduplication
-                self.transcription_history.append({
-                    'message': message,
-                    'timestamp': time.time(),
-                    'is_final': is_final
-                })
+                history_item = TranscriptionHistoryItem(
+                    message=message,
+                    timestamp=time.time(),
+                    is_final=is_final
+                )
+                self.transcription_history.append(history_item)
 
                 # Limit history size
                 if len(self.transcription_history) > 10:
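Since `collections.deque` is already imported, the manual size cap could arguably become `deque(maxlen=10)`, which discards the oldest entry automatically; the dedup slice would then need an explicit list conversion, since deques do not support slicing. A sketch of that alternative (not what the commit does):

    from collections import deque

    history: "deque[str]" = deque(maxlen=10)  # oldest entries fall off automatically
    for i in range(12):
        history.append(f"message {i}")
    assert len(history) == 10
    recent = list(history)[-3:]               # explicit conversion replaces [-3:] on a list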
@@ -265,9 +286,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack):
                 target_sr=sample_rate
             )
 
+            # Ensure audio_data is AudioArray (float32)
+            audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))
+
             # Send to audio processor
             if _audio_processor:
-                _audio_processor.add_audio_data(audio_data)
+                _audio_processor.add_audio_data(audio_data_float32)
 
     except Exception as e:
         logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)
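Worth spelling out: `typing.cast` changes nothing at runtime; the actual float32 conversion is done by `astype`, and the cast exists solely to tell the type checker the result matches the `AudioArray` alias:

    from typing import cast

    import numpy as np
    import numpy.typing as npt

    AudioArray = npt.NDArray[np.float32]

    x = np.zeros(480)                          # float64 by default
    y = cast(AudioArray, x.astype(np.float32))
    assert y.dtype == np.float32               # astype did the real work
    assert cast(AudioArray, x) is x            # cast is an identity function at runtime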