Fix linting errors

parent 6857dd66fa
commit cc9a7caa78
@@ -10,7 +10,9 @@ import time
 import threading
 from collections import deque
 from queue import Queue, Empty
-from typing import Dict, Optional, Callable, Awaitable, Deque, Any
+from typing import Dict, Optional, Callable, Awaitable, Deque, Any, cast
+import numpy.typing as npt
+from pydantic import BaseModel
 
 # Core dependencies
 import librosa
@@ -28,6 +30,23 @@ from shared.models import ChatMessageModel
 
 from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
+# Type definitions
+AudioArray = npt.NDArray[np.float32]
+
+class AudioQueueItem(BaseModel):
+    """Audio data with timestamp for processing queue."""
+    audio: AudioArray
+    timestamp: float
+
+    class Config:
+        arbitrary_types_allowed = True
+
+class TranscriptionHistoryItem(BaseModel):
+    """Transcription history item with metadata."""
+    message: str
+    timestamp: float
+    is_final: bool
+
 AGENT_NAME = "whisper"
 AGENT_DESCRIPTION = "Real-time speech transcription (Whisper) - converts speech to text"
 sample_rate = 16000  # Whisper expects 16kHz
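
The two models above replace the untyped dicts previously used for queue items and history entries. A minimal usage sketch (field names taken from the diff; pydantic v1-style Config assumed, since ndarray fields need arbitrary_types_allowed):

    import time
    import numpy as np
    import numpy.typing as npt
    from pydantic import BaseModel

    AudioArray = npt.NDArray[np.float32]

    class AudioQueueItem(BaseModel):
        audio: AudioArray  # ndarray field: requires arbitrary_types_allowed
        timestamp: float

        class Config:
            arbitrary_types_allowed = True  # pydantic has no built-in ndarray validator

    item = AudioQueueItem(audio=np.zeros(480, dtype=np.float32), timestamp=time.time())
    print(item.audio.shape, item.timestamp)  # attribute access replaces item['audio']
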
@@ -80,23 +99,23 @@ class AudioProcessor:
         self.samples_per_frame = 480  # Common WebRTC frame size at 16kHz (30ms)
 
         # Audio buffering
-        self.audio_buffer: Deque[np.ndarray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
+        self.audio_buffer: Deque[AudioArray] = deque(maxlen=1000)  # ~30 seconds at 30ms frames
         self.phrase_timeout = 3.0  # seconds of silence before considering phrase complete
         self.last_activity_time = time.time()
 
         # Transcription state
-        self.current_phrase_audio = np.array([], dtype=np.float32)
-        self.transcription_history = []
+        self.current_phrase_audio: AudioArray = np.array([], dtype=np.float32)
+        self.transcription_history: list[TranscriptionHistoryItem] = []
 
         # Background processing
-        self.processing_queue = Queue()
+        self.processing_queue: Queue[AudioQueueItem] = Queue()
         self.is_running = True
         self.processor_thread = threading.Thread(target=self._processing_loop, daemon=True)
         self.processor_thread.start()
 
         logger.info("AudioProcessor initialized for real-time transcription")
 
-    def add_audio_data(self, audio_data: np.ndarray):
+    def add_audio_data(self, audio_data: AudioArray):
         """Add new audio data to the processing buffer."""
         if not self.is_running:
             return
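
Worth noting: the annotations added in this hunk subscript built-in and stdlib generics at runtime. A hedged compatibility sketch:

    # list[...] and queue.Queue[...] are subscriptable at runtime from
    # Python 3.9 onward; on older interpreters use typing.List or quote
    # the annotation as a string so it is never evaluated.
    from queue import Queue

    class Example:
        def __init__(self) -> None:
            self.history: list[str] = []   # built-in generic (3.9+)
            self.q: Queue[str] = Queue()   # stdlib generic (3.9+)
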
@@ -121,10 +140,11 @@ class AudioProcessor:
 
         # Add to processing queue
         try:
-            self.processing_queue.put_nowait({
-                'audio': combined_audio,
-                'timestamp': time.time()
-            })
+            queue_item = AudioQueueItem(
+                audio=combined_audio,
+                timestamp=time.time()
+            )
+            self.processing_queue.put_nowait(queue_item)
         except Exception:
             # Queue full, skip this chunk
             logger.debug("Audio processing queue full, dropping audio chunk")
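
The except Exception here guards the drop-on-backpressure path; put_nowait specifically raises queue.Full, so a narrower handler would express the same intent more precisely. A small sketch of that behavior:

    from queue import Queue, Full

    q: Queue[int] = Queue(maxsize=1)
    q.put_nowait(1)
    try:
        q.put_nowait(2)   # no free slot on a bounded queue
    except Full:          # raised by put_nowait when the queue is full
        print("queue full, dropping item")
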
@@ -141,8 +161,8 @@ class AudioProcessor:
             except Empty:
                 continue
 
-            audio_array = audio_data['audio']
-            chunk_timestamp = audio_data['timestamp']
+            audio_array = audio_data.audio
+            chunk_timestamp = audio_data.timestamp
 
             # Check if this is a new phrase (gap in audio)
             time_since_last = chunk_timestamp - self.last_activity_time
@@ -179,7 +199,7 @@ class AudioProcessor:
             except Exception as e:
                 logger.error(f"Error in audio processing loop: {e}", exc_info=True)
 
-    async def _transcribe_and_send(self, audio_array: np.ndarray, is_final: bool):
+    async def _transcribe_and_send(self, audio_array: AudioArray, is_final: bool):
         """Transcribe audio and send result as chat message."""
         global sample_rate
 
@@ -203,15 +223,16 @@ class AudioProcessor:
         message = f"{prefix}{text}"
 
         # Avoid sending duplicate messages
-        if is_final or message not in [h.get('message', '') for h in self.transcription_history[-3:]]:
+        if is_final or message not in [h.message for h in self.transcription_history[-3:]]:
             await self.send_chat_func(message)
 
             # Keep history for deduplication
-            self.transcription_history.append({
-                'message': message,
-                'timestamp': time.time(),
-                'is_final': is_final
-            })
+            history_item = TranscriptionHistoryItem(
+                message=message,
+                timestamp=time.time(),
+                is_final=is_final
+            )
+            self.transcription_history.append(history_item)
 
             # Limit history size
             if len(self.transcription_history) > 10:
@@ -265,9 +286,12 @@ async def handle_track_received(peer: Peer, track: MediaStreamTrack):
             target_sr=sample_rate
         )
 
+        # Ensure audio_data is AudioArray (float32)
+        audio_data_float32 = cast(AudioArray, audio_data.astype(np.float32))
+
         # Send to audio processor
         if _audio_processor:
-            _audio_processor.add_audio_data(audio_data)
+            _audio_processor.add_audio_data(audio_data_float32)
 
     except Exception as e:
         logger.error(f"Error processing audio track from {peer.peer_name}: {e}", exc_info=True)
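
On the cast in the last hunk: typing.cast performs no runtime conversion; astype does the actual dtype change, and cast only narrows the type for the checker. A minimal sketch:

    from typing import cast
    import numpy as np
    import numpy.typing as npt

    AudioArray = npt.NDArray[np.float32]

    samples = np.array([0.1, 0.2], dtype=np.float64)      # e.g. resampler output
    audio = cast(AudioArray, samples.astype(np.float32))  # astype converts; cast is type-only
    print(audio.dtype)  # float32
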