Reload working

commit e64edf92ca (parent a313209768)
@@ -2,9 +2,10 @@
 Shared Pydantic models for API communication between voicebot and server components.
 
 This module contains all the shared data models used for:
 - HTTP API requests and responses
 - WebSocket message payloads
 - Data persistence structures
+Test comment for shared reload detection - updated again
 """
 
 from __future__ import annotations
@@ -21,7 +21,8 @@ export PATH="$VIRTUAL_ENV/bin:$PATH"
 # Launch voicebot in production or development mode
 if [ "$PRODUCTION" != "true" ]; then
     echo "Starting voicebot in development mode with auto-reload..."
-    python3 scripts/reload_runner.py --watch /voicebot --watch /shared -- uv run main.py \
+    # Fix: Use single --watch argument with multiple paths instead of multiple --watch arguments
+    python3 -u scripts/reload_runner.py --watch . /shared --verbose --interval 0.5 -- uv run main.py \
        --insecure \
        --server-url https://ketrenos.com/ai-voicebot \
        --lobby default \
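Why the single --watch flag matters: with argparse's default "store" action, each occurrence of an option overwrites the previous value, so `--watch /voicebot --watch /shared` leaves only `['/shared']` in `args.watch` and `/voicebot` is silently never watched. One flag with several paths keeps them all, since the runner declares the option with `nargs="+"`. A minimal standalone sketch of this behavior (not the runner itself):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--watch", "-w", nargs="+", default=["."])

    # Repeating the option overwrites the earlier value...
    print(p.parse_args(["--watch", "/voicebot", "--watch", "/shared"]).watch)  # ['/shared']
    # ...while one flag with several paths keeps them all.
    print(p.parse_args(["--watch", ".", "/shared"]).watch)  # ['.', '/shared']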
@@ -48,9 +48,9 @@ def _setup_logging(level: str=logging_level) -> logging.Logger:
     warnings.filterwarnings("ignore", message="n_jobs value 1 overridden")
     warnings.filterwarnings("ignore", message=".*websocket.*is deprecated")
 
-    logging.getLogger("aiortc").setLevel(logging.INFO)
-    logging.getLogger("aioice").setLevel(logging.INFO)
-    logging.getLogger("asyncio").setLevel(logging.INFO)
+    logging.getLogger("aiortc").setLevel(logging.WARNING)
+    logging.getLogger("aioice").setLevel(logging.WARNING)
+    logging.getLogger("asyncio").setLevel(logging.WARNING)
 
     numeric_level = getattr(logging, level.upper(), None)
     if not isinstance(numeric_level, int):
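Raising these third-party loggers to WARNING silences their per-packet INFO chatter without touching the application's own level: child loggers with no explicit level inherit the effective level of the nearest configured ancestor. A quick standard-library-only illustration:

    import logging

    logging.basicConfig(level=logging.INFO)
    logging.getLogger("aiortc").setLevel(logging.WARNING)

    logging.getLogger("aiortc").info("suppressed")            # below WARNING, dropped
    logging.getLogger("aiortc.rtp").info("also suppressed")   # child inherits WARNING
    logging.getLogger("aiortc").warning("still shown")
    logging.getLogger("voicebot").info("app logs unaffected")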
voicebot/main.py (144 lines changed)
@@ -3,6 +3,7 @@ WebRTC Media Agent for Python
 
 This module provides WebRTC signaling server communication and peer connection management.
 Synthetic audio/video track creation is handled by the synthetic_media module.
+Test change to trigger reload - TESTING RELOAD NOW
 """
 
 from __future__ import annotations
@@ -20,6 +21,7 @@ from typing import (
     cast,
 )
 
+# test
 from dataclasses import dataclass, field
 from pydantic import ValidationError
 
@@ -60,7 +62,6 @@ from synthetic_media import create_synthetic_tracks
 
 # import debug_aioice
 
-
 # Generic message payload type
 MessageData = dict[str, object]
 
@@ -237,9 +238,9 @@ class WebRTCSignalingClient:
 
         ws = cast(WebSocketProtocol, self.websocket)
         try:
-            logger.info(f"_send_message: Sending {message_type} with data: {data}")
+            logger.debug(f"_send_message: Sending {message_type} with data: {data}")
             await ws.send(json.dumps(message))
-            logger.info(f"_send_message: Sent message: {message_type}")
+            logger.debug(f"_send_message: Sent message: {message_type}")
         except Exception as e:
             logger.error(
                 f"_send_message: Failed to send {message_type}: {e}", exc_info=True
@@ -250,7 +251,7 @@ class WebRTCSignalingClient:
         try:
             ws = cast(WebSocketProtocol, self.websocket)
             async for message in ws:
-                logger.info(f"_handle_messages: Received raw message: {message}")
+                logger.debug(f"_handle_messages: Received raw message: {message}")
                 try:
                     data = cast(MessageData, json.loads(message))
                 except Exception as e:
@@ -275,7 +276,7 @@ class WebRTCSignalingClient:
             logger.error(f"Invalid message structure: {e}", exc_info=True)
             return
 
-        logger.info(
+        logger.debug(
             f"_process_message: Received message type: {msg_type} with data: {data}"
         )
 
@@ -350,14 +351,14 @@ class WebRTCSignalingClient:
         logger.info(
             f"Adding peer: {peer_name} (should_create_offer: {should_create_offer})"
         )
-        logger.info(
+        logger.debug(
             f"_handle_add_peer: peer_id={peer_id}, peer_name={peer_name}, should_create_offer={should_create_offer}"
         )
 
         # Check if peer already exists
         if peer_id in self.peer_connections:
             pc = self.peer_connections[peer_id]
-            logger.info(
+            logger.debug(
                 f"_handle_add_peer: Existing connection state: {pc.connectionState}"
             )
             if pc.connectionState in ["new", "connected", "connecting"]:
@@ -365,7 +366,7 @@ class WebRTCSignalingClient:
                 return
             else:
                 # Clean up stale connection
-                logger.info(
+                logger.debug(
                     f"_handle_add_peer: Closing stale connection for {peer_name}"
                 )
                 await pc.close()
@@ -390,7 +391,7 @@ class WebRTCSignalingClient:
                 RTCIceServer(urls="stun:stun.l.google.com:19302"),
             ],
         )
-        logger.info(
+        logger.debug(
             f"_handle_add_peer: Creating RTCPeerConnection for {peer_name} with config: {config}"
         )
         pc = RTCPeerConnection(configuration=config)
@@ -400,7 +401,9 @@ class WebRTCSignalingClient:
             logger.info(f"ICE gathering state: {pc.iceGatheringState}")
             # Debug: Check if we have any local candidates when gathering is complete
             if pc.iceGatheringState == "complete":
-                logger.info(f"ICE gathering complete for {peer_name} - checking if candidates were generated...")
+                logger.info(
+                    f"ICE gathering complete for {peer_name} - checking if candidates were generated..."
+                )
 
         pc.on("icegatheringstatechange")(on_ice_gathering_state_change)
 
@@ -425,9 +428,13 @@ class WebRTCSignalingClient:
 
         def on_ice_candidate(candidate: Optional[RTCIceCandidate]) -> None:
             logger.info(f"on_ice_candidate: {candidate}")
-            logger.info(f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}")
+            logger.info(
+                f"on_ice_candidate CALLED for {peer_name}: candidate={candidate}"
+            )
             if not candidate:
-                logger.info(f"on_ice_candidate: End of candidates signal for {peer_name}")
+                logger.info(
+                    f"on_ice_candidate: End of candidates signal for {peer_name}"
+                )
                 return
 
             # Raw SDP fragment for the candidate
||||||
@ -467,7 +474,7 @@ class WebRTCSignalingClient:
|
|||||||
|
|
||||||
# Add local tracks
|
# Add local tracks
|
||||||
for track in self.local_tracks.values():
|
for track in self.local_tracks.values():
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
|
f"_handle_add_peer: Adding local track {track.kind} to {peer_name}"
|
||||||
)
|
)
|
||||||
pc.addTrack(track)
|
pc.addTrack(track)
|
||||||
@@ -479,34 +486,42 @@ class WebRTCSignalingClient:
             self.is_negotiating[peer_id] = True
 
             try:
-                logger.info(f"_handle_add_peer: Creating offer for {peer_name}")
+                logger.debug(f"_handle_add_peer: Creating offer for {peer_name}")
                 offer = await pc.createOffer()
-                logger.info(f"_handle_add_peer: Offer created for {peer_name}: {offer}")
+                logger.debug(
+                    f"_handle_add_peer: Offer created for {peer_name}: {offer}"
+                )
                 await pc.setLocalDescription(offer)
-                logger.info(f"_handle_add_peer: Local description set for {peer_name}")
+                logger.debug(f"_handle_add_peer: Local description set for {peer_name}")
 
                 # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
                 # Use Method 2: Complete SDP approach to extract ICE candidates
-                logger.info(f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}")
+                logger.debug(
+                    f"_handle_add_peer: Waiting for ICE gathering to complete for {peer_name}"
+                )
                 while pc.iceGatheringState != "complete":
                     await asyncio.sleep(0.1)
 
-                logger.info(f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}")
+                logger.debug(
+                    f"_handle_add_peer: ICE gathering complete, extracting candidates from SDP for {peer_name}"
+                )
 
                 # Parse ICE candidates from the local SDP
-                sdp_lines = pc.localDescription.sdp.split('\n')
-                candidate_lines = [line for line in sdp_lines if line.startswith('a=candidate:')]
+                sdp_lines = pc.localDescription.sdp.split("\n")
+                candidate_lines = [
+                    line for line in sdp_lines if line.startswith("a=candidate:")
+                ]
 
                 # Track which media section we're in to determine sdpMid and sdpMLineIndex
                 current_media_index = -1
                 current_mid = None
 
                 for line in sdp_lines:
-                    if line.startswith('m='):  # Media section
+                    if line.startswith("m="):  # Media section
                         current_media_index += 1
-                    elif line.startswith('a=mid:'):  # Media ID
-                        current_mid = line.split(':', 1)[1].strip()
-                    elif line.startswith('a=candidate:'):
+                    elif line.startswith("a=mid:"):  # Media ID
+                        current_mid = line.split(":", 1)[1].strip()
+                    elif line.startswith("a=candidate:"):
                         candidate_sdp = line[2:]  # Remove 'a=' prefix
 
                         candidate_model = ICECandidateDictModel(
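The workaround above sidesteps aiortc's unreliable icecandidate callback (upstream issue #1344) by waiting for gathering to finish and scraping candidates out of the local SDP. The same parsing pattern as a standalone sketch; ICECandidateDictModel and the relay message belong to this app, so plain dicts stand in here:

    def extract_candidates(sdp: str) -> list[dict]:
        """Pull a=candidate: lines out of an SDP blob, tagging each with the
        sdpMid / sdpMLineIndex of the media section it appears under."""
        candidates = []
        mline_index = -1   # index of the current m= section
        mid = None         # value of the most recent a=mid: line
        for line in sdp.split("\n"):
            line = line.strip()  # tolerate \r\n line endings
            if line.startswith("m="):
                mline_index += 1
            elif line.startswith("a=mid:"):
                mid = line.split(":", 1)[1].strip()
            elif line.startswith("a=candidate:"):
                candidates.append({
                    "candidate": line[2:],  # drop the "a=" prefix
                    "sdpMid": mid,
                    "sdpMLineIndex": mline_index,
                })
        return candidates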
@@ -519,12 +534,14 @@ class WebRTCSignalingClient:
                             peer_name=peer_name,
                             candidate=candidate_model,
                         )
 
-                        logger.info(f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}...")
+                        logger.debug(
+                            f"_handle_add_peer: Sending extracted ICE candidate for {peer_name}: {candidate_sdp[:60]}..."
+                        )
                         await self._send_message(
                             "relayICECandidate", payload_candidate.model_dump()
                         )
 
                 # Send end-of-candidates signal (empty candidate)
                 end_candidate_model = ICECandidateDictModel(
                     candidate="",
@@ -534,10 +551,14 @@ class WebRTCSignalingClient:
                 payload_end = IceCandidateModel(
                     peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
                 )
-                logger.info(f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}")
+                logger.debug(
+                    f"_handle_add_peer: Sending end-of-candidates signal for {peer_name}"
+                )
                 await self._send_message("relayICECandidate", payload_end.model_dump())
 
-                logger.info(f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}")
+                logger.debug(
+                    f"_handle_add_peer: Sent {len(candidate_lines)} ICE candidates to {peer_name}"
+                )
 
                 session_desc_typed = SessionDescriptionTypedModel(
                     type=offer.type, sdp=offer.sdp
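The empty candidate string sent last follows the trickle-ICE convention: a null or empty candidate tells the remote peer that gathering is finished. The receiving side then needs a guard of roughly this shape (a hypothetical handler name, not code from this repo) so the sentinel is never fed to addIceCandidate:

    # Hypothetical receiver-side guard for the end-of-candidates sentinel.
    async def on_relay_ice_candidate(pc, payload: dict) -> None:
        cand = payload.get("candidate") or {}
        if not cand.get("candidate"):
            # Empty string marks end-of-candidates; nothing to add.
            return
        # ...otherwise parse the SDP fragment and hand it to pc.addIceCandidate(...)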
@@ -660,29 +681,35 @@ class WebRTCSignalingClient:
             try:
                 answer = await pc.createAnswer()
                 await pc.setLocalDescription(answer)
 
                 # WORKAROUND for aiortc icecandidate event not firing (GitHub issue #1344)
                 # Use Method 2: Complete SDP approach to extract ICE candidates
-                logger.info(f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)")
+                logger.debug(
+                    f"_handle_session_description: Waiting for ICE gathering to complete for {peer_name} (answer)"
+                )
                 while pc.iceGatheringState != "complete":
                     await asyncio.sleep(0.1)
 
-                logger.info(f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)")
+                logger.debug(
+                    f"_handle_session_description: ICE gathering complete, extracting candidates from SDP for {peer_name} (answer)"
+                )
 
                 # Parse ICE candidates from the local SDP
-                sdp_lines = pc.localDescription.sdp.split('\n')
-                candidate_lines = [line for line in sdp_lines if line.startswith('a=candidate:')]
+                sdp_lines = pc.localDescription.sdp.split("\n")
+                candidate_lines = [
+                    line for line in sdp_lines if line.startswith("a=candidate:")
+                ]
 
                 # Track which media section we're in to determine sdpMid and sdpMLineIndex
                 current_media_index = -1
                 current_mid = None
 
                 for line in sdp_lines:
-                    if line.startswith('m='):  # Media section
+                    if line.startswith("m="):  # Media section
                         current_media_index += 1
-                    elif line.startswith('a=mid:'):  # Media ID
-                        current_mid = line.split(':', 1)[1].strip()
-                    elif line.startswith('a=candidate:'):
+                    elif line.startswith("a=mid:"):  # Media ID
+                        current_mid = line.split(":", 1)[1].strip()
+                    elif line.startswith("a=candidate:"):
                         candidate_sdp = line[2:]  # Remove 'a=' prefix
 
                         candidate_model = ICECandidateDictModel(
@@ -695,12 +722,14 @@ class WebRTCSignalingClient:
                             peer_name=peer_name,
                             candidate=candidate_model,
                         )
 
-                        logger.info(f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}...")
+                        logger.debug(
+                            f"_handle_session_description: Sending extracted ICE candidate for {peer_name} (answer): {candidate_sdp[:60]}..."
+                        )
                         await self._send_message(
                             "relayICECandidate", payload_candidate.model_dump()
                         )
 
                 # Send end-of-candidates signal (empty candidate)
                 end_candidate_model = ICECandidateDictModel(
                     candidate="",
@@ -710,10 +739,14 @@ class WebRTCSignalingClient:
                 payload_end = IceCandidateModel(
                     peer_id=peer_id, peer_name=peer_name, candidate=end_candidate_model
                 )
-                logger.info(f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)")
+                logger.debug(
+                    f"_handle_session_description: Sending end-of-candidates signal for {peer_name} (answer)"
+                )
                 await self._send_message("relayICECandidate", payload_end.model_dump())
 
-                logger.info(f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)")
+                logger.debug(
+                    f"_handle_session_description: Sent {len(candidate_lines)} ICE candidates to {peer_name} (answer)"
+                )
 
                 session_desc_typed = SessionDescriptionTypedModel(
                     type=answer.type, sdp=answer.sdp
@@ -1056,3 +1089,6 @@ if __name__ == "__main__":
     # pip install aiortc websockets opencv-python numpy
 
     asyncio.run(main())
+# test modification
+# Test comment Mon Sep 1 03:48:19 PM PDT 2025
+# Test change at Mon Sep 1 03:52:13 PM PDT 2025
scripts/reload_runner.py

@@ -11,6 +11,7 @@ and inside containers without installing extra packages.
 from __future__ import annotations
 
 import argparse
+import hashlib
 import os
 import signal
 import subprocess
@@ -21,28 +22,90 @@ from types import FrameType
 
 
 def scan_py_mtimes(paths: List[str]) -> Dict[str, float]:
+    # Directories to skip during scanning
+    SKIP_DIRS = {
+        ".venv",
+        "__pycache__",
+        ".git",
+        "node_modules",
+        ".mypy_cache",
+        ".pytest_cache",
+        "build",
+        "dist",
+    }
+
     mtimes: Dict[str, float] = {}
     for p in paths:
         if os.path.isfile(p) and p.endswith('.py'):
             try:
-                mtimes[p] = os.path.getmtime(p)
+                # Use both mtime and ctime to catch more changes in Docker environments
+                stat = os.stat(p)
+                mtimes[p] = max(stat.st_mtime, stat.st_ctime)
             except OSError:
                 pass
             continue
 
-        for root, _, files in os.walk(p):
+        for root, dirs, files in os.walk(p):
+            # Skip common directories that shouldn't trigger reloads
+            dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
+
             for f in files:
                 if not f.endswith('.py'):
                     continue
                 fp = os.path.join(root, f)
                 try:
-                    mtimes[fp] = os.path.getmtime(fp)
+                    # Use both mtime and ctime to catch more changes in Docker environments
+                    stat = os.stat(fp)
+                    mtimes[fp] = max(stat.st_mtime, stat.st_ctime)
                 except OSError:
                     # file might disappear between walk and stat
                     pass
     return mtimes
 
 
+def scan_py_hashes(paths: List[str]) -> Dict[str, str]:
+    """Fallback method: scan file content hashes for change detection."""
+    # Directories to skip during scanning
+    SKIP_DIRS = {
+        ".venv",
+        "__pycache__",
+        ".git",
+        "node_modules",
+        ".mypy_cache",
+        ".pytest_cache",
+        "build",
+        "dist",
+    }
+
+    hashes: Dict[str, str] = {}
+    for p in paths:
+        if os.path.isfile(p) and p.endswith(".py"):
+            try:
+                with open(p, "rb") as f:
+                    content = f.read()
+                    hashes[p] = hashlib.md5(content).hexdigest()
+            except OSError:
+                pass
+            continue
+
+        for root, dirs, files in os.walk(p):
+            # Skip common directories that shouldn't trigger reloads
+            dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
+
+            for f in files:
+                if not f.endswith(".py"):
+                    continue
+                fp = os.path.join(root, f)
+                try:
+                    with open(fp, "rb") as file:
+                        content = file.read()
+                        hashes[fp] = hashlib.md5(content).hexdigest()
+                except OSError:
+                    # file might disappear between walk and read
+                    pass
+    return hashes
+
+
 def start_process(cmd: List[str]) -> subprocess.Popen[bytes]:
     print("Starting:", " ".join(cmd))
     return subprocess.Popen(cmd)
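Two details in this hunk are worth noting. Taking `max(st_mtime, st_ctime)` can catch bind-mount edits where the container sees the inode's ctime move even though mtime looks stale, and assigning to `dirs[:]` (rather than rebinding `dirs`) prunes os.walk in place so skipped directories are never descended into. The MD5 here is purely a change fingerprint, not a security measure. A sketch of how two snapshots pair up for change detection, assuming the scan_py_hashes helper above (the paths are illustrative):

    old = scan_py_hashes([".", "/shared"])

    # ... one polling interval later ...
    new = scan_py_hashes([".", "/shared"])

    edited = [p for p, h in new.items() if p in old and old[p] != h]
    added = [p for p in new if p not in old]
    deleted = [p for p in old if p not in new]
    if edited or added or deleted:
        print("restart needed:", edited + added + deleted)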
@@ -66,10 +129,20 @@ def terminate_process(p: subprocess.Popen[bytes], timeout: float = 5.0) -> None:
 def main() -> int:
     parser = argparse.ArgumentParser(description="Restart a command when .py files change")
     parser.add_argument("--watch", "-w", nargs="+", default=["."], help="Directories or files to watch")
-    parser.add_argument("--interval", "-i", type=float, default=1.0, help="Polling interval in seconds")
+    parser.add_argument(
+        "--interval", "-i", type=float, default=0.5, help="Polling interval in seconds"
+    )
     parser.add_argument("--delay-restart", type=float, default=0.1, help="Delay after change before restarting")
     parser.add_argument("--no-restart-on-exit", action="store_true", help="Don't restart if the process exits on its own")
    parser.add_argument("--pass-sigterm", action="store_true", help="Forward SIGTERM to child and exit when received")
+    parser.add_argument(
+        "--verbose", "-v", action="store_true", help="Enable verbose logging"
+    )
+    parser.add_argument(
+        "--use-hash-fallback",
+        action="store_true",
+        help="Use content hashing as fallback for Docker environments",
+    )
     # Accept the command to run as a positional "remainder" so callers can
     # separate options with `--` and have everything after it treated as the
     # command. Defining an option named "--" doesn't work reliably with
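For reference, the attribute names these options produce (argparse turns dashes into underscores), shown with a throwaway parser that mirrors only the new flags:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--watch", "-w", nargs="+", default=["."])
    parser.add_argument("--interval", "-i", type=float, default=0.5)
    parser.add_argument("--verbose", "-v", action="store_true")
    parser.add_argument("--use-hash-fallback", action="store_true")

    args = parser.parse_args(
        ["--watch", ".", "/shared", "--interval", "0.5", "--verbose", "--use-hash-fallback"]
    )
    assert args.watch == [".", "/shared"]
    assert args.interval == 0.5
    assert args.verbose is True
    assert args.use_hash_fallback is True  # dest derived from --use-hash-fallback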
@@ -93,6 +166,20 @@ def main() -> int:
     watch_paths = args.watch
 
     last_mtimes = scan_py_mtimes(watch_paths)
+    last_hashes = scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
+
+    if args.verbose:
+        print(f"Watching {len(last_mtimes)} Python files in paths: {watch_paths}")
+        print(f"Working directory: {os.getcwd()}")
+        print(f"Resolved watch paths: {[os.path.abspath(p) for p in watch_paths]}")
+        print(f"Polling interval: {args.interval}s")
+        if args.use_hash_fallback:
+            print("Using content hash fallback for change detection")
+        print("Sample files being watched:")
+        for fp in sorted(last_mtimes.keys())[:5]:
+            print(f" {fp}")
+        if len(last_mtimes) > 5:
+            print(f" ... and {len(last_mtimes) - 5} more")
 
     child = start_process(cmd)
 
@@ -125,31 +212,70 @@ def main() -> int:
                 # else restart immediately
                 child = start_process(cmd)
                 last_mtimes = scan_py_mtimes(watch_paths)
+                last_hashes = (
+                    scan_py_hashes(watch_paths) if args.use_hash_fallback else {}
+                )
                 continue
 
             # Check for source changes
             current = scan_py_mtimes(watch_paths)
             changed = False
+            change_reason = ""
+
             # Check for new or changed files
             for fp, m in current.items():
                 if fp not in last_mtimes or last_mtimes.get(fp) != m:
                     print("Detected change in:", fp)
+                    if args.verbose:
+                        old_mtime = last_mtimes.get(fp, 0)
+                        print(f" Old mtime: {old_mtime}, New mtime: {m}")
                     changed = True
+                    change_reason = f"mtime change in {fp}"
                     break
 
+            # Hash-based fallback check if mtime didn't detect changes
+            if not changed and args.use_hash_fallback:
+                current_hashes = scan_py_hashes(watch_paths)
+                for fp, h in current_hashes.items():
+                    if fp not in last_hashes or last_hashes.get(fp) != h:
+                        print("Detected content change in:", fp)
+                        if args.verbose:
+                            print(
+                                f" Hash changed: {last_hashes.get(fp, 'None')} -> {h}"
+                            )
+                        changed = True
+                        change_reason = f"content change in {fp}"
+                        break
+                # Update hash cache
+                last_hashes = current_hashes
+
             # Check for deleted files
             if not changed:
                 for fp in list(last_mtimes.keys()):
                     if fp not in current:
                         print("Detected deleted file:", fp)
                         changed = True
+                        change_reason = f"deleted file {fp}"
                         break
 
+            # Additional debug output
+            if args.verbose and not changed:
+                num_files = len(current)
+                if num_files != len(last_mtimes):
+                    print(f"File count changed: {len(last_mtimes)} -> {num_files}")
+                    changed = True
+                    change_reason = "file count change"
+
             if changed:
+                if args.verbose:
+                    print(f"Restarting due to: {change_reason}")
                 # Small debounce
                 time.sleep(args.delay_restart)
                 terminate_process(child)
                 child = start_process(cmd)
                 last_mtimes = scan_py_mtimes(watch_paths)
+                if args.use_hash_fallback:
+                    last_hashes = scan_py_hashes(watch_paths)
 
     except KeyboardInterrupt:
         print("Interrupted, shutting down.")