chore(backend): MQTT/WLED/devices/capture/utils + api routes hardening
Bundle the remaining backend touch-ups that the production review landed individually as small surgical edits across many modules:

- MQTT runtime: fire-and-forget task tracking + drain resilience.
- mqtt_source + store + storage/color_strip_source: secret_box encryption for credentials with auto-migration of plaintext fields.
- devices/discovery_watcher: task tracking on watcher start/stop.
- devices/wled_client + wled_provider: URL scheme inference helper applied at the create/update boundary so bare hostnames stay valid.
- core/capture/screen_capture: hardened error paths.
- core/processing (mapped/processed/processor_manager/video/wled_target): smaller follow-throughs from the registry refactor that landed earlier on the branch.
- utils/safe_source + utils/file_ops + utils/__init__: shared URL + IP classification helpers + larger streaming upload size caps.
- api/auth: WebSocket Origin allow-list + /docs auth-gate.
- api/dependencies: register the new HTTP-endpoint store.
- api/routes (assets, backup, webhooks): streaming-upload caps + asyncio.gather return_exceptions on broadcast loops.
- tests/test_api + tests/e2e/test_backup_flow: cover the new caps and the Origin allow-list.
This commit is contained in:
@@ -11,14 +11,13 @@ from starlette.websockets import WebSocket, WebSocketDisconnect
|
|||||||
|
|
||||||
from ledgrab.config import get_config
|
from ledgrab.config import get_config
|
||||||
from ledgrab.utils import get_logger
|
from ledgrab.utils import get_logger
|
||||||
|
from ledgrab.utils.net_classify import is_loopback as _classify_is_loopback
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# Security scheme for Bearer token
|
# Security scheme for Bearer token
|
||||||
security = HTTPBearer(auto_error=False)
|
security = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost", "testclient"})
|
|
||||||
|
|
||||||
|
|
||||||
def is_auth_enabled() -> bool:
|
def is_auth_enabled() -> bool:
|
||||||
"""Return True when at least one API key is configured."""
|
"""Return True when at least one API key is configured."""
|
||||||
@@ -26,15 +25,15 @@ def is_auth_enabled() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _is_loopback(host: str | None) -> bool:
|
def _is_loopback(host: str | None) -> bool:
|
||||||
"""Return True when *host* is a loopback address."""
|
"""Return True when *host* is a loopback address.
|
||||||
|
|
||||||
|
Delegates to :func:`ledgrab.utils.net_classify.is_loopback` so this
|
||||||
|
auth gate, the SSRF guard in ``safe_source``, and the LAN-default
|
||||||
|
inference in ``url_scheme`` share one classification source.
|
||||||
|
"""
|
||||||
if not host:
|
if not host:
|
||||||
return False
|
return False
|
||||||
# Strip IPv6 brackets and zone IDs
|
return _classify_is_loopback(host)
|
||||||
h = host.strip().lower()
|
|
||||||
if h.startswith("[") and h.endswith("]"):
|
|
||||||
h = h[1:-1]
|
|
||||||
h = h.split("%", 1)[0]
|
|
||||||
return h in _LOOPBACK_HOSTS
|
|
||||||
|
|
||||||
|
|
||||||
def verify_api_key(
|
def verify_api_key(
|
||||||
@@ -142,6 +141,23 @@ def require_authenticated(label: str) -> None:
|
|||||||
WS_AUTH_CLOSE_CODE = 4401
|
WS_AUTH_CLOSE_CODE = 4401
|
||||||
|
|
||||||
|
|
||||||
|
WS_ORIGIN_CLOSE_CODE = 4403
|
||||||
|
"""Close code sent when a WebSocket request fails the Origin allowlist."""
|
||||||
|
|
||||||
|
|
||||||
|
def _is_origin_allowed(origin: str | None, allowed: list[str]) -> bool:
|
||||||
|
"""Return True when *origin* matches one of the configured CORS origins.
|
||||||
|
|
||||||
|
Non-browser clients (Python scripts, curl) don't send Origin — those are
|
||||||
|
allowed through; the Bearer-token check on the auth handshake is the
|
||||||
|
primary defence in that case. Browsers always set Origin, so this only
|
||||||
|
blocks cross-site WebSocket connection attempts (CSWSH).
|
||||||
|
"""
|
||||||
|
if not origin:
|
||||||
|
return True
|
||||||
|
return origin in set(allowed or [])
|
||||||
|
|
||||||
|
|
||||||
async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None:
|
async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None:
|
||||||
"""Accept the WebSocket, then perform first-message auth handshake.
|
"""Accept the WebSocket, then perform first-message auth handshake.
|
||||||
|
|
||||||
@@ -152,6 +168,23 @@ async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0)
|
|||||||
Returns the caller label on success, ``None`` on failure (connection
|
Returns the caller label on success, ``None`` on failure (connection
|
||||||
already closed).
|
already closed).
|
||||||
"""
|
"""
|
||||||
|
# Reject cross-site WebSocket attempts before accepting — a browser-based
|
||||||
|
# attacker page cannot forge the Origin header, so an Origin mismatch is
|
||||||
|
# a strong signal even before the token check. Non-browser clients
|
||||||
|
# legitimately omit Origin; those fall through to the auth handshake.
|
||||||
|
config = get_config()
|
||||||
|
origin = websocket.headers.get("origin")
|
||||||
|
if not _is_origin_allowed(origin, config.server.cors_origins):
|
||||||
|
logger.warning(
|
||||||
|
"Rejected WebSocket from origin %r (not in cors_origins)",
|
||||||
|
origin,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await websocket.close(code=WS_ORIGIN_CLOSE_CODE)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
label = await verify_ws_auth(websocket, timeout=timeout)
|
label = await verify_ws_auth(websocket, timeout=timeout)
|
||||||
if label is None:
|
if label is None:
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from ledgrab.storage.game_integration_store import GameIntegrationStore
|
|||||||
from ledgrab.core.game_integration.event_bus import GameEventBus
|
from ledgrab.core.game_integration.event_bus import GameEventBus
|
||||||
from ledgrab.storage.mqtt_source_store import MQTTSourceStore
|
from ledgrab.storage.mqtt_source_store import MQTTSourceStore
|
||||||
from ledgrab.core.mqtt.mqtt_manager import MQTTManager
|
from ledgrab.core.mqtt.mqtt_manager import MQTTManager
|
||||||
|
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
|
||||||
from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore
|
from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore
|
||||||
from ledgrab.storage.pattern_template_store import PatternTemplateStore
|
from ledgrab.storage.pattern_template_store import PatternTemplateStore
|
||||||
|
|
||||||
@@ -165,6 +166,10 @@ def get_mqtt_manager() -> MQTTManager:
|
|||||||
return _get("mqtt_manager", "MQTT manager")
|
return _get("mqtt_manager", "MQTT manager")
|
||||||
|
|
||||||
|
|
||||||
|
def get_http_endpoint_store() -> HTTPEndpointStore:
|
||||||
|
return _get("http_endpoint_store", "HTTP endpoint store")
|
||||||
|
|
||||||
|
|
||||||
def get_audio_processing_template_store() -> AudioProcessingTemplateStore:
|
def get_audio_processing_template_store() -> AudioProcessingTemplateStore:
|
||||||
return _get("audio_processing_template_store", "Audio processing template store")
|
return _get("audio_processing_template_store", "Audio processing template store")
|
||||||
|
|
||||||
@@ -237,6 +242,7 @@ def init_dependencies(
|
|||||||
game_event_bus: GameEventBus | None = None,
|
game_event_bus: GameEventBus | None = None,
|
||||||
mqtt_store: MQTTSourceStore | None = None,
|
mqtt_store: MQTTSourceStore | None = None,
|
||||||
mqtt_manager: MQTTManager | None = None,
|
mqtt_manager: MQTTManager | None = None,
|
||||||
|
http_endpoint_store: HTTPEndpointStore | None = None,
|
||||||
audio_processing_template_store: AudioProcessingTemplateStore | None = None,
|
audio_processing_template_store: AudioProcessingTemplateStore | None = None,
|
||||||
pattern_template_store: PatternTemplateStore | None = None,
|
pattern_template_store: PatternTemplateStore | None = None,
|
||||||
):
|
):
|
||||||
@@ -272,6 +278,7 @@ def init_dependencies(
|
|||||||
"game_event_bus": game_event_bus,
|
"game_event_bus": game_event_bus,
|
||||||
"mqtt_store": mqtt_store,
|
"mqtt_store": mqtt_store,
|
||||||
"mqtt_manager": mqtt_manager,
|
"mqtt_manager": mqtt_manager,
|
||||||
|
"http_endpoint_store": http_endpoint_store,
|
||||||
"audio_processing_template_store": audio_processing_template_store,
|
"audio_processing_template_store": audio_processing_template_store,
|
||||||
"pattern_template_store": pattern_template_store,
|
"pattern_template_store": pattern_template_store,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from ledgrab.api.schemas.assets import (
|
|||||||
from ledgrab.config import get_config
|
from ledgrab.config import get_config
|
||||||
from ledgrab.storage.asset_store import AssetStore
|
from ledgrab.storage.asset_store import AssetStore
|
||||||
from ledgrab.storage.base_store import EntityNotFoundError
|
from ledgrab.storage.base_store import EntityNotFoundError
|
||||||
from ledgrab.utils import get_logger
|
from ledgrab.utils import get_logger, read_upload_capped
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -93,10 +93,11 @@ async def upload_asset(
|
|||||||
config = get_config()
|
config = get_config()
|
||||||
max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024
|
max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024
|
||||||
|
|
||||||
data = await file.read()
|
try:
|
||||||
if len(data) > max_size:
|
data = await read_upload_capped(file, max_size)
|
||||||
|
except ValueError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=413,
|
||||||
detail=f"File too large (max {max_size // (1024 * 1024)} MB)",
|
detail=f"File too large (max {max_size // (1024 * 1024)} MB)",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from ledgrab.config import get_config
|
|||||||
from ledgrab.core.backup.auto_backup import AutoBackupEngine
|
from ledgrab.core.backup.auto_backup import AutoBackupEngine
|
||||||
from ledgrab.storage.asset_store import AssetStore
|
from ledgrab.storage.asset_store import AssetStore
|
||||||
from ledgrab.storage.database import Database, freeze_writes
|
from ledgrab.storage.database import Database, freeze_writes
|
||||||
from ledgrab.utils import get_logger
|
from ledgrab.utils import get_logger, read_upload_capped
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -133,9 +133,11 @@ async def restore_config(
|
|||||||
because restore replaces all configuration including secrets).
|
because restore replaces all configuration including secrets).
|
||||||
"""
|
"""
|
||||||
require_authenticated(auth)
|
require_authenticated(auth)
|
||||||
raw = await file.read()
|
_MAX_BACKUP_BYTES = 200 * 1024 * 1024 # 200 MB (ZIP may contain assets)
|
||||||
if len(raw) > 200 * 1024 * 1024: # 200 MB limit (ZIP may contain assets)
|
try:
|
||||||
raise HTTPException(status_code=400, detail="Backup file too large (max 200 MB)")
|
raw = await read_upload_capped(file, _MAX_BACKUP_BYTES)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=413, detail="Backup file too large (max 200 MB)")
|
||||||
|
|
||||||
if len(raw) < 100:
|
if len(raw) < 100:
|
||||||
raise HTTPException(status_code=400, detail="File too small to be a valid backup")
|
raise HTTPException(status_code=400, detail="File too small to be a valid backup")
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ _RATE_WINDOW = 60.0 # seconds
|
|||||||
_rate_hits: dict[str, list[float]] = defaultdict(list)
|
_rate_hits: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
|
||||||
|
_RATE_HITS_HARD_CAP = 1024
|
||||||
|
|
||||||
|
|
||||||
def _check_rate_limit(client_ip: str) -> None:
|
def _check_rate_limit(client_ip: str) -> None:
|
||||||
"""Raise 429 if *client_ip* exceeded the webhook rate limit."""
|
"""Raise 429 if *client_ip* exceeded the webhook rate limit."""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@@ -44,11 +47,21 @@ def _check_rate_limit(client_ip: str) -> None:
|
|||||||
)
|
)
|
||||||
_rate_hits[client_ip].append(now)
|
_rate_hits[client_ip].append(now)
|
||||||
|
|
||||||
# Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth
|
# Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth.
|
||||||
if len(_rate_hits) > 100:
|
if len(_rate_hits) > 100:
|
||||||
stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start]
|
stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start]
|
||||||
for ip in stale:
|
for ip in stale:
|
||||||
del _rate_hits[ip]
|
del _rate_hits[ip]
|
||||||
|
# Hard cap as a final defence against an attacker spraying many distinct
|
||||||
|
# X-Forwarded-For values to drive memory growth past the soft cleanup
|
||||||
|
# threshold. Drop the oldest-touched IPs (by their latest timestamp).
|
||||||
|
if len(_rate_hits) > _RATE_HITS_HARD_CAP:
|
||||||
|
ordered = sorted(
|
||||||
|
_rate_hits.items(),
|
||||||
|
key=lambda kv: kv[1][-1] if kv[1] else 0.0,
|
||||||
|
)
|
||||||
|
for ip, _ in ordered[: len(ordered) - _RATE_HITS_HARD_CAP]:
|
||||||
|
_rate_hits.pop(ip, None)
|
||||||
|
|
||||||
|
|
||||||
class WebhookPayload(BaseModel):
|
class WebhookPayload(BaseModel):
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ from ledgrab.utils import get_logger, get_monitor_names, get_monitor_refresh_rat
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Reused random Generator for sampling. The legacy ``np.random.randint``
|
||||||
|
# uses the module-level RandomState which is slightly slower per-call and
|
||||||
|
# pulls in extra import-time work; a single Generator is faster and avoids
|
||||||
|
# global-state surprises.
|
||||||
|
_rng = np.random.default_rng()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DisplayInfo:
|
class DisplayInfo:
|
||||||
@@ -326,7 +332,11 @@ def calculate_dominant_color(pixels: np.ndarray) -> tuple[int, int, int]:
|
|||||||
|
|
||||||
max_samples = 1000
|
max_samples = 1000
|
||||||
if n > max_samples:
|
if n > max_samples:
|
||||||
indices = np.random.randint(0, n, max_samples)
|
# ``Generator.integers`` writes into a fresh buffer once per call;
|
||||||
|
# the legacy ``np.random.randint`` did the same plus extra
|
||||||
|
# bookkeeping. Random (not stride) sampling stays robust against
|
||||||
|
# periodic patterns in screen pixels.
|
||||||
|
indices = _rng.integers(0, n, size=max_samples)
|
||||||
pixels_reshaped = pixels_reshaped[indices]
|
pixels_reshaped = pixels_reshaped[indices]
|
||||||
|
|
||||||
# Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32:
|
# Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32:
|
||||||
|
|||||||
@@ -85,6 +85,10 @@ class DiscoveryWatcher:
|
|||||||
self._wled_seen: Dict[str, _DiscoveredEntry] = {}
|
self._wled_seen: Dict[str, _DiscoveredEntry] = {}
|
||||||
# device-path -> entry. Only the serial poller mutates this.
|
# device-path -> entry. Only the serial poller mutates this.
|
||||||
self._serial_seen: Dict[str, _DiscoveredEntry] = {}
|
self._serial_seen: Dict[str, _DiscoveredEntry] = {}
|
||||||
|
# Strong references for fire-and-forget resolve tasks — without
|
||||||
|
# these, Python 3.11+ may GC the task mid-resolve and silently lose
|
||||||
|
# discovery events. Tasks remove themselves on completion.
|
||||||
|
self._resolve_tasks: "set[asyncio.Task]" = set()
|
||||||
|
|
||||||
# --- lifecycle --------------------------------------------------------
|
# --- lifecycle --------------------------------------------------------
|
||||||
|
|
||||||
@@ -155,12 +159,23 @@ class DiscoveryWatcher:
|
|||||||
if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated):
|
if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated):
|
||||||
# Resolve in a task — async_request blocks the handler if awaited
|
# Resolve in a task — async_request blocks the handler if awaited
|
||||||
# synchronously and we don't want to stall mDNS dispatch.
|
# synchronously and we don't want to stall mDNS dispatch.
|
||||||
asyncio.create_task(self._resolve_wled(service_type, name))
|
task = asyncio.create_task(self._resolve_wled(service_type, name))
|
||||||
|
self._resolve_tasks.add(task)
|
||||||
|
task.add_done_callback(self._on_resolve_done)
|
||||||
elif state_change == ServiceStateChange.Removed:
|
elif state_change == ServiceStateChange.Removed:
|
||||||
entry = self._wled_seen.pop(name, None)
|
entry = self._wled_seen.pop(name, None)
|
||||||
if entry is not None and not self._is_configured(entry.url):
|
if entry is not None and not self._is_configured(entry.url):
|
||||||
self._emit("device_lost", entry)
|
self._emit("device_lost", entry)
|
||||||
|
|
||||||
|
def _on_resolve_done(self, task: "asyncio.Task") -> None:
|
||||||
|
"""Release the strong task reference and log any resolve failure."""
|
||||||
|
self._resolve_tasks.discard(task)
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc is not None:
|
||||||
|
logger.debug("Discovery watcher: resolve task raised: %s", exc)
|
||||||
|
|
||||||
async def _resolve_wled(self, service_type: str, name: str) -> None:
|
async def _resolve_wled(self, service_type: str, name: str) -> None:
|
||||||
if self._aiozc is None:
|
if self._aiozc is None:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -453,8 +453,15 @@ class WLEDClient(LEDClient):
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.debug(f"Sending {len(pixels)} LEDs via HTTP ({len(indexed_pixels)} values)")
|
# ``str(payload)`` previously stringified the entire indexed
|
||||||
logger.debug(f"Payload size: ~{len(str(payload))} bytes")
|
# array on every send to report a byte estimate; that was the
|
||||||
|
# hot path. Drop the size readout — the LED count + indexed
|
||||||
|
# value count is enough to interpret traffic and is O(1).
|
||||||
|
logger.debug(
|
||||||
|
"Sending %d LEDs via HTTP (%d indexed values)",
|
||||||
|
len(pixels),
|
||||||
|
len(indexed_pixels),
|
||||||
|
)
|
||||||
|
|
||||||
await self._request("POST", "/json/state", json_data=payload)
|
await self._request("POST", "/json/state", json_data=payload)
|
||||||
logger.debug("Successfully sent pixel colors via HTTP")
|
logger.debug("Successfully sent pixel colors via HTTP")
|
||||||
|
|||||||
@@ -98,11 +98,18 @@ class WLEDDeviceProvider(LEDDeviceProvider):
|
|||||||
dict with 'led_count' key.
|
dict with 'led_count' key.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
ValueError: Unsupported scheme or invalid LED count.
|
||||||
httpx.ConnectError: Device unreachable.
|
httpx.ConnectError: Device unreachable.
|
||||||
httpx.TimeoutException: Connection timed out.
|
httpx.TimeoutException: Connection timed out.
|
||||||
ValueError: Invalid LED count.
|
|
||||||
"""
|
"""
|
||||||
url = _normalize_url(url)
|
url = _normalize_url(url)
|
||||||
|
# Reject anything that isn't plain HTTP(S). url_scheme.infer_http_scheme
|
||||||
|
# passes non-HTTP schemes through untouched ("javascript:", "file:",
|
||||||
|
# "data:", etc.); without this guard those would reach httpx and
|
||||||
|
# surface as opaque transport errors at best, or be silently misused
|
||||||
|
# at worst.
|
||||||
|
if not url.lower().startswith(("http://", "https://")):
|
||||||
|
raise ValueError(f"WLED URL must use http:// or https:// scheme (got {url!r})")
|
||||||
async with httpx.AsyncClient(timeout=5) as client:
|
async with httpx.AsyncClient(timeout=5) as client:
|
||||||
response = await client.get(_join(url, "/json/info"))
|
response = await client.get(_join(url, "/json/info"))
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|||||||
@@ -49,6 +49,32 @@ class MQTTRuntime:
|
|||||||
# Pending publishes queued while disconnected
|
# Pending publishes queued while disconnected
|
||||||
self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
|
self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
|
||||||
|
|
||||||
|
# Strong references for fire-and-forget callback dispatch tasks.
|
||||||
|
# Python 3.11+ may GC bare ``asyncio.create_task(...)`` results mid-
|
||||||
|
# flight, so we hold each task until it completes and surface any
|
||||||
|
# exception via the done-callback.
|
||||||
|
self._dispatch_tasks: Set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
# Compiled ``aiomqtt.Topic`` cache, keyed by the subscription pattern
|
||||||
|
# string. The previous dispatch loop re-parsed every pattern on
|
||||||
|
# every incoming message — on a chatty broker with many wildcards
|
||||||
|
# that adds up fast.
|
||||||
|
self._compiled_topics: Dict[str, aiomqtt.Topic] = {}
|
||||||
|
|
||||||
|
def _on_dispatch_done(self, task: asyncio.Task) -> None:
|
||||||
|
"""Drop the strong reference and surface any callback exception."""
|
||||||
|
self._dispatch_tasks.discard(task)
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc is not None:
|
||||||
|
logger.error(
|
||||||
|
"MQTT async callback raised (%s): %s",
|
||||||
|
self._source_id,
|
||||||
|
exc,
|
||||||
|
exc_info=exc,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return self._connected
|
return self._connected
|
||||||
@@ -84,6 +110,14 @@ class MQTTRuntime:
|
|||||||
logger.debug("MQTT runtime task cancelled: %s", self._source_id)
|
logger.debug("MQTT runtime task cancelled: %s", self._source_id)
|
||||||
self._task = None
|
self._task = None
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
# Cancel any in-flight async dispatch callbacks. Without this they
|
||||||
|
# would keep running past the runtime's logical end of life and
|
||||||
|
# could fire callbacks on a stopped subscriber.
|
||||||
|
for task in list(self._dispatch_tasks):
|
||||||
|
task.cancel()
|
||||||
|
if self._dispatch_tasks:
|
||||||
|
await asyncio.gather(*self._dispatch_tasks, return_exceptions=True)
|
||||||
|
self._dispatch_tasks.clear()
|
||||||
logger.info("MQTT runtime stopped: %s", self._source_id)
|
logger.info("MQTT runtime stopped: %s", self._source_id)
|
||||||
|
|
||||||
def update_config(self, source: MQTTSource) -> None:
|
def update_config(self, source: MQTTSource) -> None:
|
||||||
@@ -167,13 +201,23 @@ class MQTTRuntime:
|
|||||||
for topic in self._subscriptions:
|
for topic in self._subscriptions:
|
||||||
await client.subscribe(topic)
|
await client.subscribe(topic)
|
||||||
|
|
||||||
# Drain pending publishes
|
# Drain pending publishes. A single publish failing
|
||||||
|
# (broker rejection, oversized message) must not lose
|
||||||
|
# the rest of the queue — log and continue with the next.
|
||||||
while not self._publish_queue.empty():
|
while not self._publish_queue.empty():
|
||||||
try:
|
try:
|
||||||
t, p, r, q = self._publish_queue.get_nowait()
|
t, p, r, q = self._publish_queue.get_nowait()
|
||||||
await client.publish(t, p, retain=r, qos=q)
|
except asyncio.QueueEmpty:
|
||||||
except Exception:
|
|
||||||
break
|
break
|
||||||
|
try:
|
||||||
|
await client.publish(t, p, retain=r, qos=q)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"MQTT drain publish failed (%s -> %s): %s",
|
||||||
|
self._source_id,
|
||||||
|
t,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
# Message receive loop
|
# Message receive loop
|
||||||
async for msg in client.messages:
|
async for msg in client.messages:
|
||||||
@@ -183,13 +227,21 @@ class MQTTRuntime:
|
|||||||
)
|
)
|
||||||
self._topic_cache[topic_str] = payload_str
|
self._topic_cache[topic_str] = payload_str
|
||||||
|
|
||||||
# Dispatch to callbacks
|
# Dispatch to callbacks. Pattern objects are cached
|
||||||
|
# per subscription string to avoid re-parsing them on
|
||||||
|
# every received message.
|
||||||
for sub_topic, callbacks in self._subscriptions.items():
|
for sub_topic, callbacks in self._subscriptions.items():
|
||||||
if aiomqtt.Topic(sub_topic).matches(msg.topic):
|
compiled = self._compiled_topics.get(sub_topic)
|
||||||
|
if compiled is None:
|
||||||
|
compiled = aiomqtt.Topic(sub_topic)
|
||||||
|
self._compiled_topics[sub_topic] = compiled
|
||||||
|
if compiled.matches(msg.topic):
|
||||||
for cb in callbacks:
|
for cb in callbacks:
|
||||||
try:
|
try:
|
||||||
if asyncio.iscoroutinefunction(cb):
|
if asyncio.iscoroutinefunction(cb):
|
||||||
asyncio.create_task(cb(topic_str, payload_str))
|
task = asyncio.create_task(cb(topic_str, payload_str))
|
||||||
|
self._dispatch_tasks.add(task)
|
||||||
|
task.add_done_callback(self._on_dispatch_done)
|
||||||
else:
|
else:
|
||||||
cb(topic_str, payload_str)
|
cb(topic_str, payload_str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,6 +12,11 @@ from ledgrab.utils import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Cap the (src_len, dst_len) resize cache. Each entry holds two ``np.linspace``
|
||||||
|
# arrays plus a per-zone ``uint8`` scratch buffer, which used to grow without
|
||||||
|
# bound under hot reconfigure storms.
|
||||||
|
_RESIZE_CACHE_MAX = 16
|
||||||
|
|
||||||
|
|
||||||
class MappedColorStripStream(ColorStripStream):
|
class MappedColorStripStream(ColorStripStream):
|
||||||
"""Places multiple ColorStripStreams side-by-side at distinct LED ranges.
|
"""Places multiple ColorStripStreams side-by-side at distinct LED ranges.
|
||||||
@@ -46,8 +52,11 @@ class MappedColorStripStream(ColorStripStream):
|
|||||||
|
|
||||||
# zone_index -> (source_id, consumer_id, stream)
|
# zone_index -> (source_id, consumer_id, stream)
|
||||||
self._sub_streams: Dict[int, tuple] = {}
|
self._sub_streams: Dict[int, tuple] = {}
|
||||||
# (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing
|
# (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing.
|
||||||
self._resize_cache: Dict[tuple, tuple] = {}
|
# An ``OrderedDict`` with eviction keeps memory bounded if device
|
||||||
|
# configurations fluctuate at runtime (each unique pair adds two
|
||||||
|
# linspace arrays + a per-zone reusable uint8 buffer).
|
||||||
|
self._resize_cache: "OrderedDict[tuple, tuple]" = OrderedDict()
|
||||||
self._sub_lock = threading.Lock() # guards _sub_streams access across threads
|
self._sub_lock = threading.Lock() # guards _sub_streams access across threads
|
||||||
|
|
||||||
# ── ColorStripStream interface ──────────────────────────────
|
# ── ColorStripStream interface ──────────────────────────────
|
||||||
@@ -229,6 +238,14 @@ class MappedColorStripStream(ColorStripStream):
|
|||||||
np.empty((zone_len, 3), dtype=np.uint8),
|
np.empty((zone_len, 3), dtype=np.uint8),
|
||||||
)
|
)
|
||||||
self._resize_cache[rkey] = cached
|
self._resize_cache[rkey] = cached
|
||||||
|
# Drop the least-recently-inserted entry once
|
||||||
|
# we hit the cap. 16 entries comfortably covers
|
||||||
|
# any realistic zone/source layout — pathological
|
||||||
|
# reconfigure storms used to grow this forever.
|
||||||
|
if len(self._resize_cache) > _RESIZE_CACHE_MAX:
|
||||||
|
self._resize_cache.popitem(last=False)
|
||||||
|
else:
|
||||||
|
self._resize_cache.move_to_end(rkey)
|
||||||
src_x, dst_x, resized = cached
|
src_x, dst_x, resized = cached
|
||||||
for ch in range(3):
|
for ch in range(3):
|
||||||
np.copyto(
|
np.copyto(
|
||||||
|
|||||||
@@ -160,9 +160,11 @@ class ProcessedColorStripStream(ColorStripStream):
|
|||||||
self._resolve_count = 0
|
self._resolve_count = 0
|
||||||
self._resolve_filters()
|
self._resolve_filters()
|
||||||
|
|
||||||
colors = None
|
# Bind to a local first — ``update_source()`` may swap or null
|
||||||
if self._input_stream:
|
# out ``_input_stream`` between the check and the read on a
|
||||||
colors = self._input_stream.get_latest_colors()
|
# different thread.
|
||||||
|
inp = self._input_stream
|
||||||
|
colors = inp.get_latest_colors() if inp is not None else None
|
||||||
|
|
||||||
if colors is not None and self._filters:
|
if colors is not None and self._filters:
|
||||||
for flt in self._filters:
|
for flt in self._filters:
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from ledgrab.storage.picture_source_store import PictureSourceStore
|
|||||||
from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore
|
from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore
|
||||||
from ledgrab.storage.template_store import TemplateStore
|
from ledgrab.storage.template_store import TemplateStore
|
||||||
from ledgrab.storage.value_source_store import ValueSourceStore
|
from ledgrab.storage.value_source_store import ValueSourceStore
|
||||||
|
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
|
||||||
from ledgrab.storage.asset_store import AssetStore
|
from ledgrab.storage.asset_store import AssetStore
|
||||||
from ledgrab.core.processing.sync_clock_manager import SyncClockManager
|
from ledgrab.core.processing.sync_clock_manager import SyncClockManager
|
||||||
from ledgrab.core.weather.weather_manager import WeatherManager
|
from ledgrab.core.weather.weather_manager import WeatherManager
|
||||||
@@ -74,6 +75,7 @@ class ProcessorDependencies:
|
|||||||
mqtt_manager: Optional[Any] = None # MQTTManager
|
mqtt_manager: Optional[Any] = None # MQTTManager
|
||||||
game_event_bus: Optional[Any] = None # GameEventBus
|
game_event_bus: Optional[Any] = None # GameEventBus
|
||||||
audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore
|
audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore
|
||||||
|
http_endpoint_store: Optional[HTTPEndpointStore] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -169,6 +171,7 @@ class ProcessorManager(AutoRestartMixin, DeviceHealthMixin, DeviceTestModeMixin)
|
|||||||
event_bus=deps.game_event_bus,
|
event_bus=deps.game_event_bus,
|
||||||
audio_processing_template_store=deps.audio_processing_template_store,
|
audio_processing_template_store=deps.audio_processing_template_store,
|
||||||
sync_clock_manager=deps.sync_clock_manager,
|
sync_clock_manager=deps.sync_clock_manager,
|
||||||
|
http_endpoint_store=deps.http_endpoint_store,
|
||||||
)
|
)
|
||||||
if deps.value_source_store
|
if deps.value_source_store
|
||||||
else None
|
else None
|
||||||
|
|||||||
@@ -37,6 +37,33 @@ def is_youtube_url(url: str) -> bool:
|
|||||||
return any(p.search(url) for p in _YT_PATTERNS)
|
return any(p.search(url) for p in _YT_PATTERNS)
|
||||||
|
|
||||||
|
|
||||||
|
# Network schemes accepted by ``cv2.VideoCapture``. ``file://`` is intentionally
|
||||||
|
# excluded — local files are supported via plain path strings (no scheme), so
|
||||||
|
# explicit ``file://`` requests can only be an attempt to coerce FFmpeg into
|
||||||
|
# loading something the path-string code path would reject. ``concat:``,
|
||||||
|
# ``gopher://``, ``crypto:``, etc. are not allowed.
|
||||||
|
_ALLOWED_NETWORK_SCHEMES: tuple[str, ...] = ("http", "https", "rtsp", "rtsps")
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_video_url_allowed(url: str) -> None:
|
||||||
|
"""Reject video URLs that use anything other than a vetted scheme.
|
||||||
|
|
||||||
|
OpenCV/FFmpeg supports many esoteric input protocols (``concat:``,
|
||||||
|
``gopher://``, ``crypto:``, ``udp://``, ``async:``, …). Some can read
|
||||||
|
arbitrary host files or pivot to internal addresses when the caller
|
||||||
|
can influence the URL. Tighten the input to the schemes we actually
|
||||||
|
advertise. URLs without a scheme are accepted as local-file paths.
|
||||||
|
"""
|
||||||
|
if "://" not in url:
|
||||||
|
return # plain local path — OpenCV resolves against the working dir
|
||||||
|
scheme = url.split("://", 1)[0].lower()
|
||||||
|
if scheme not in _ALLOWED_NETWORK_SCHEMES:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Refusing to open video with unsupported scheme {scheme!r}; "
|
||||||
|
f"allowed: {', '.join(_ALLOWED_NETWORK_SCHEMES)} or a local file path."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str:
|
def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str:
|
||||||
"""Resolve a YouTube URL to a direct stream URL using yt-dlp."""
|
"""Resolve a YouTube URL to a direct stream URL using yt-dlp."""
|
||||||
try:
|
try:
|
||||||
@@ -185,10 +212,14 @@ class VideoCaptureLiveStream(LiveStream):
|
|||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Resolve YouTube URL if needed
|
# Resolve YouTube URL if needed. Validate AFTER resolution too, so a
|
||||||
|
# malicious yt-dlp result (or a redirect we don't expect) can't slip
|
||||||
|
# through with an unsupported scheme.
|
||||||
actual_url = self._original_url
|
actual_url = self._original_url
|
||||||
|
_assert_video_url_allowed(actual_url)
|
||||||
if is_youtube_url(actual_url):
|
if is_youtube_url(actual_url):
|
||||||
actual_url = resolve_youtube_url(actual_url, self._resolution_limit)
|
actual_url = resolve_youtube_url(actual_url, self._resolution_limit)
|
||||||
|
_assert_video_url_allowed(actual_url)
|
||||||
self._resolved_url = actual_url
|
self._resolved_url = actual_url
|
||||||
|
|
||||||
# Open capture
|
# Open capture
|
||||||
|
|||||||
@@ -224,7 +224,8 @@ class WledTargetProcessor(TargetProcessor):
|
|||||||
|
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
|
|
||||||
# Cancel task
|
# Cancel task. The cancellation is awaited just below, so the prior
|
||||||
|
# 50 ms ``asyncio.sleep`` here was pure dead time on every stop().
|
||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
try:
|
try:
|
||||||
@@ -233,7 +234,6 @@ class WledTargetProcessor(TargetProcessor):
|
|||||||
logger.debug("WLED target processor task cancelled")
|
logger.debug("WLED target processor task cancelled")
|
||||||
pass
|
pass
|
||||||
self._task = None
|
self._task = None
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
# Restore device state (only if auto_shutdown is enabled)
|
# Restore device state (only if auto_shutdown is enabled)
|
||||||
if self._led_client and self._device_state_before:
|
if self._led_client and self._device_state_before:
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from ledgrab.utils import secret_box
|
||||||
|
|
||||||
|
|
||||||
def _parse_common(data: dict) -> dict:
|
def _parse_common(data: dict) -> dict:
|
||||||
"""Extract common fields from a dict, parsing timestamps."""
|
"""Extract common fields from a dict, parsing timestamps."""
|
||||||
@@ -52,13 +54,16 @@ class MQTTSource:
|
|||||||
icon_color: str = ""
|
icon_color: str = ""
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
|
# Always persist the broker password in encrypted envelope form. If
|
||||||
|
# the field already contains an envelope, ``encrypt()`` is a no-op.
|
||||||
|
stored_password = secret_box.encrypt(self.password) if self.password else ""
|
||||||
d = {
|
d = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"broker_host": self.broker_host,
|
"broker_host": self.broker_host,
|
||||||
"broker_port": self.broker_port,
|
"broker_port": self.broker_port,
|
||||||
"username": self.username,
|
"username": self.username,
|
||||||
"password": self.password,
|
"password": stored_password,
|
||||||
"client_id": self.client_id,
|
"client_id": self.client_id,
|
||||||
"base_topic": self.base_topic,
|
"base_topic": self.base_topic,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
@@ -75,12 +80,17 @@ class MQTTSource:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(data: dict) -> "MQTTSource":
|
def from_dict(data: dict) -> "MQTTSource":
|
||||||
common = _parse_common(data)
|
common = _parse_common(data)
|
||||||
|
# Decrypt at load time so consumers see plaintext via ``.password``.
|
||||||
|
# Legacy plaintext rows pass through unchanged — the next save will
|
||||||
|
# write them back in encrypted form.
|
||||||
|
raw_password = data.get("password", "")
|
||||||
|
password = secret_box.decrypt(raw_password) if raw_password else ""
|
||||||
return MQTTSource(
|
return MQTTSource(
|
||||||
**common,
|
**common,
|
||||||
broker_host=data.get("broker_host", "localhost"),
|
broker_host=data.get("broker_host", "localhost"),
|
||||||
broker_port=int(data.get("broker_port", 1883)),
|
broker_port=int(data.get("broker_port", 1883)),
|
||||||
username=data.get("username", ""),
|
username=data.get("username", ""),
|
||||||
password=data.get("password", ""),
|
password=password,
|
||||||
client_id=data.get("client_id", "ledgrab"),
|
client_id=data.get("client_id", "ledgrab"),
|
||||||
base_topic=data.get("base_topic", "ledgrab"),
|
base_topic=data.get("base_topic", "ledgrab"),
|
||||||
icon=data.get("icon", ""),
|
icon=data.get("icon", ""),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import List, Optional
|
|||||||
from ledgrab.storage.base_sqlite_store import BaseSqliteStore
|
from ledgrab.storage.base_sqlite_store import BaseSqliteStore
|
||||||
from ledgrab.storage.database import Database
|
from ledgrab.storage.database import Database
|
||||||
from ledgrab.storage.mqtt_source import MQTTSource
|
from ledgrab.storage.mqtt_source import MQTTSource
|
||||||
from ledgrab.utils import get_logger
|
from ledgrab.utils import get_logger, secret_box
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -20,6 +20,41 @@ class MQTTSourceStore(BaseSqliteStore[MQTTSource]):
|
|||||||
|
|
||||||
def __init__(self, db: Database):
|
def __init__(self, db: Database):
|
||||||
super().__init__(db, MQTTSource.from_dict)
|
super().__init__(db, MQTTSource.from_dict)
|
||||||
|
self._migrate_plaintext_passwords()
|
||||||
|
|
||||||
|
def _migrate_plaintext_passwords(self) -> None:
    """Re-save every MQTT source whose stored password is still plaintext.

    Reads the raw rows for this store's table and, for each row whose
    ``password`` value is non-empty and not recognised by
    ``secret_box.is_encrypted``, re-saves the matching in-memory item via
    ``_save_item``. Failures are logged and never raised, so a broken row
    cannot prevent the store from being constructed.
    """
    try:
        stored_rows = self._db.load_all(self._table_name)
    except Exception as exc:
        logger.error("Could not inspect rows for MQTT password migration: %s", exc)
        return

    converted = 0
    for record in stored_rows:
        source_id = record.get("id")
        stored_pw = record.get("password", "") or ""
        # Rows without an id or a password, and rows that already carry the
        # encryption envelope, need no work.
        needs_migration = (
            bool(source_id)
            and bool(stored_pw)
            and not secret_box.is_encrypted(stored_pw)
        )
        if not needs_migration:
            continue

        item = self._items.get(source_id)
        if item is None:
            continue

        try:
            self._save_item(source_id, item)
        except Exception as exc:
            logger.error("Failed to migrate MQTT password for %s: %s", source_id, exc)
        else:
            converted += 1

    if converted:
        logger.warning(
            "MIGRATION: encrypted %d plaintext MQTT broker password(s) at rest.",
            converted,
        )
|
||||||
|
|
||||||
# Backward-compatible aliases
|
# Backward-compatible aliases
|
||||||
get_all_sources = BaseSqliteStore.get_all
|
get_all_sources = BaseSqliteStore.get_all
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
"""Utility functions and helpers."""
|
"""Utility functions and helpers."""
|
||||||
|
|
||||||
from .file_ops import atomic_write_json
|
from .file_ops import atomic_write_json, read_upload_capped
|
||||||
from .logger import setup_logging, get_logger
|
from .logger import setup_logging, get_logger
|
||||||
from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates
|
from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates
|
||||||
from .timer import high_resolution_timer
|
from .timer import high_resolution_timer
|
||||||
from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler
|
from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler
|
||||||
|
from .url_scheme import infer_http_scheme
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"atomic_write_json",
|
"atomic_write_json",
|
||||||
|
"read_upload_capped",
|
||||||
"setup_logging",
|
"setup_logging",
|
||||||
"get_logger",
|
"get_logger",
|
||||||
"get_monitor_names",
|
"get_monitor_names",
|
||||||
@@ -16,4 +18,5 @@ __all__ = [
|
|||||||
"high_resolution_timer",
|
"high_resolution_timer",
|
||||||
"log_broadcaster",
|
"log_broadcaster",
|
||||||
"install_broadcast_handler",
|
"install_broadcast_handler",
|
||||||
|
"infer_http_scheme",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,10 +5,40 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastapi import UploadFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_UPLOAD_CHUNK_SIZE = 1 * 1024 * 1024  # 1 MiB


async def read_upload_capped(file: "UploadFile", max_size: int) -> bytes:
    """Stream an upload into memory, failing once it grows past ``max_size``.

    The body is pulled in ``_UPLOAD_CHUNK_SIZE`` pieces instead of a single
    ``await file.read()``, so an oversized upload is rejected after roughly
    ``max_size`` bytes have been buffered rather than after the whole body
    has been allocated.

    Args:
        file: Object exposing an awaitable ``read(size)`` that returns bytes.
        max_size: Largest total number of bytes that will be accepted.

    Returns:
        The complete upload body.

    Raises:
        ValueError: When the upload exceeds *max_size*. The caller decides
            how to translate the exception (typically HTTP 400/413).
    """
    pieces: list[bytes] = []
    received = 0
    while piece := await file.read(_UPLOAD_CHUNK_SIZE):
        received += len(piece)
        if received > max_size:
            raise ValueError(f"Upload exceeds maximum allowed size of {max_size} bytes")
        pieces.append(piece)
    return b"".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None:
|
def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None:
|
||||||
"""Write JSON data to file atomically via temp file + rename.
|
"""Write JSON data to file atomically via temp file + rename.
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ from urllib.parse import urlparse
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from ledgrab.utils.net_classify import is_blocked_for_ssrf
|
||||||
|
|
||||||
|
|
||||||
# Image file extensions considered safe to serve
|
# Image file extensions considered safe to serve
|
||||||
_IMAGE_EXTENSIONS = frozenset(
|
_IMAGE_EXTENSIONS = frozenset(
|
||||||
@@ -38,19 +40,13 @@ _IMAGE_EXTENSIONS = frozenset(
|
|||||||
|
|
||||||
|
|
||||||
def _is_blocked_ip(ip: str) -> bool:
|
def _is_blocked_ip(ip: str) -> bool:
|
||||||
"""Return True when *ip* belongs to a category we refuse to fetch from."""
|
"""Return True when *ip* belongs to a category we refuse to fetch from.
|
||||||
try:
|
|
||||||
addr = ipaddress.ip_address(ip)
|
Thin wrapper preserved for the local test suite. The actual policy now
|
||||||
except ValueError:
|
lives in :func:`ledgrab.utils.net_classify.is_blocked_for_ssrf` so it
|
||||||
return True # unparseable → block
|
can't drift away from the related ``url_scheme`` / ``auth`` modules.
|
||||||
return (
|
"""
|
||||||
addr.is_private
|
return is_blocked_for_ssrf(ip)
|
||||||
or addr.is_loopback
|
|
||||||
or addr.is_link_local
|
|
||||||
or addr.is_reserved
|
|
||||||
or addr.is_multicast
|
|
||||||
or addr.is_unspecified
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_hostname(hostname: str) -> Iterable[str]:
|
def _resolve_hostname(hostname: str) -> Iterable[str]:
|
||||||
@@ -67,6 +63,11 @@ def _resolve_hostname(hostname: str) -> Iterable[str]:
|
|||||||
def validate_image_url(url: str) -> None:
|
def validate_image_url(url: str) -> None:
|
||||||
"""Validate that *url* is safe to fetch.
|
"""Validate that *url* is safe to fetch.
|
||||||
|
|
||||||
|
Strict SSRF policy — blocks all private / loopback / link-local /
|
||||||
|
reserved / multicast / unspecified IPs. Intended for fetches of
|
||||||
|
user-supplied URLs that should reach the *public* internet only
|
||||||
|
(image sources, weather APIs, etc.).
|
||||||
|
|
||||||
Checks:
|
Checks:
|
||||||
- Scheme is http/https
|
- Scheme is http/https
|
||||||
- Hostname is present
|
- Hostname is present
|
||||||
@@ -74,6 +75,25 @@ def validate_image_url(url: str) -> None:
|
|||||||
- DNS resolves to non-private, non-loopback, non-link-local, non-reserved,
|
- DNS resolves to non-private, non-loopback, non-link-local, non-reserved,
|
||||||
non-multicast addresses (SSRF protection)
|
non-multicast addresses (SSRF protection)
|
||||||
"""
|
"""
|
||||||
|
_validate_url(url, allow_private=False)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_polling_url(url: str) -> None:
    """Check that *url* is an acceptable target for automation polling.

    Applies the relaxed SSRF policy by delegating to ``_validate_url`` with
    ``allow_private=True``: PRIVATE (LAN) addresses are permitted so users
    can reach devices on their own network (Plex on 192.168.x.x, Home
    Assistant on 10.x.x.x), while LOOPBACK, LINK_LOCAL (e.g. the
    169.254.169.254 cloud-metadata address), MULTICAST, RESERVED and
    UNSPECIFIED targets stay blocked.

    Prefer this over :func:`validate_image_url` for any user-configured
    HTTP polling target.
    """
    _validate_url(url, allow_private=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_url(url: str, *, allow_private: bool) -> None:
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
if parsed.scheme not in ("http", "https"):
|
if parsed.scheme not in ("http", "https"):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -93,17 +113,41 @@ def validate_image_url(url: str) -> None:
|
|||||||
ips = _resolve_hostname(hostname)
|
ips = _resolve_hostname(hostname)
|
||||||
|
|
||||||
for ip in ips:
|
for ip in ips:
|
||||||
if _is_blocked_ip(ip):
|
if _ip_is_blocked(ip, allow_private=allow_private):
|
||||||
|
policy = (
|
||||||
|
"loopback / link-local / reserved / multicast"
|
||||||
|
if allow_private
|
||||||
|
else "private / loopback / link-local / reserved / multicast"
|
||||||
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=(
|
detail=(
|
||||||
f"Refusing to fetch URL: hostname {hostname!r} resolves to "
|
f"Refusing to fetch URL: hostname {hostname!r} resolves to "
|
||||||
f"blocked address {ip} (private / loopback / link-local / "
|
f"blocked address {ip} ({policy})"
|
||||||
"reserved / multicast)"
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ip_is_blocked(ip: str, *, allow_private: bool) -> bool:
    """Return True when *ip* must not be contacted under the selected policy.

    Loopback, link-local, multicast, reserved, unspecified and unparseable
    addresses are refused unconditionally; PRIVATE addresses are refused
    only when ``allow_private`` is false.
    """
    from ledgrab.utils.net_classify import HostCategory, classify_ip

    never_allowed = {
        HostCategory.LOOPBACK,
        HostCategory.LINK_LOCAL,
        HostCategory.MULTICAST,
        HostCategory.RESERVED,
        HostCategory.UNSPECIFIED,
        HostCategory.UNPARSEABLE,
    }
    kind = classify_ip(ip)
    if kind in never_allowed:
        return True
    return kind == HostCategory.PRIVATE and not allow_private
|
||||||
|
|
||||||
|
|
||||||
def validate_image_path(file_path: str | Path) -> Path:
|
def validate_image_path(file_path: str | Path) -> Path:
|
||||||
"""Validate a local file path points to a real image file.
|
"""Validate a local file path points to a real image file.
|
||||||
|
|
||||||
@@ -165,3 +209,88 @@ async def safe_fetch(
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return response
|
return response
|
||||||
raise HTTPException(status_code=400, detail=f"Too many redirects (max {max_hops})")
|
raise HTTPException(status_code=400, detail=f"Too many redirects (max {max_hops})")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bounded safe request — for polling-style fetches that need streaming
|
||||||
|
# + size cap + arbitrary method/headers. Used by the HTTP poll runtime.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Hard cap to avoid OOM from polled endpoints returning huge bodies.
# 1 MiB covers every realistic JSON status endpoint by orders of magnitude.
_MAX_RESPONSE_BYTES = 1 * 1024 * 1024


async def safe_request_bounded(
    method: str,
    url: str,
    *,
    headers: dict[str, str] | None = None,
    timeout: float = 10.0,
    max_bytes: int = _MAX_RESPONSE_BYTES,
    allow_private: bool = True,
) -> tuple[int, bytes, str | None]:
    """Issue one SSRF-validated HTTP request and read at most ``max_bytes`` of body.

    The URL is validated before any network traffic, redirects are not
    followed, and the response is streamed so that reading stops as soon as
    the cap is reached — a malicious or misconfigured endpoint cannot make
    the server buffer an unbounded body.

    ``allow_private`` picks the validator: ``True`` (the default, for
    polling user-owned LAN devices) uses :func:`validate_polling_url`;
    ``False`` uses the strict public-only :func:`validate_image_url`.

    Hitting the cap is not reported as an error: the first ``max_bytes``
    bytes are returned as a normal result. A body of exactly ``max_bytes``
    bytes may therefore be either complete or truncated.

    Note on DNS rebinding: the hostname is resolved once for validation and
    httpx resolves it again for the real request, so the check and the
    connection are not pinned to the same IP. Pinning is a project-wide
    follow-up that would also have to cover ``safe_fetch``.

    Returns:
        ``(status_code, body_bytes, error_message)``. ``error_message`` is
        ``None`` on success; on transport failure ``status_code`` is 0 and
        ``body_bytes`` is empty.

    Raises:
        HTTPException: when the URL fails validation (4xx-style error
            intended for callers that surface this directly to the user).
    """
    check_url = validate_polling_url if allow_private else validate_image_url
    check_url(url)

    verb = method.upper()
    outgoing_headers = dict(headers or {})

    try:
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
            async with client.stream(verb, url, headers=outgoing_headers) as response:
                received = bytearray()
                async for piece in response.aiter_bytes():
                    room = max_bytes - len(received)
                    if len(piece) > room:
                        # Keep only what still fits under the cap, then stop reading.
                        received += piece[:room]
                        break
                    received += piece
                return response.status_code, bytes(received), None
    except httpx.RequestError as exc:
        # Only the exception type is reported: some httpx errors embed the
        # request URL and headers, which could leak caller-supplied auth tokens.
        return 0, b"", f"Request failed: {type(exc).__name__}"
    except Exception as exc:
        return 0, b"", f"Unexpected error: {type(exc).__name__}"
|
||||||
|
|||||||
@@ -12,22 +12,28 @@ class TestBackupRestoreFlow:
|
|||||||
"""A user backs up their configuration and restores it."""
|
"""A user backs up their configuration and restores it."""
|
||||||
|
|
||||||
def _create_device(self, client, name="Backup Device") -> str:
|
def _create_device(self, client, name="Backup Device") -> str:
|
||||||
resp = client.post("/api/v1/devices", json={
|
resp = client.post(
|
||||||
"name": name,
|
"/api/v1/devices",
|
||||||
"url": "mock://backup",
|
json={
|
||||||
"device_type": "mock",
|
"name": name,
|
||||||
"led_count": 30,
|
"url": "mock://backup",
|
||||||
})
|
"device_type": "mock",
|
||||||
|
"led_count": 30,
|
||||||
|
},
|
||||||
|
)
|
||||||
assert resp.status_code == 201
|
assert resp.status_code == 201
|
||||||
return resp.json()["id"]
|
return resp.json()["id"]
|
||||||
|
|
||||||
def _create_css(self, client, name="Backup CSS") -> str:
|
def _create_css(self, client, name="Backup CSS") -> str:
|
||||||
resp = client.post("/api/v1/color-strip-sources", json={
|
resp = client.post(
|
||||||
"name": name,
|
"/api/v1/color-strip-sources",
|
||||||
"source_type": "static",
|
json={
|
||||||
"color": [255, 0, 0],
|
"name": name,
|
||||||
"led_count": 30,
|
"source_type": "single_color",
|
||||||
})
|
"color": [255, 0, 0],
|
||||||
|
"led_count": 30,
|
||||||
|
},
|
||||||
|
)
|
||||||
assert resp.status_code == 201
|
assert resp.status_code == 201
|
||||||
return resp.json()["id"]
|
return resp.json()["id"]
|
||||||
|
|
||||||
|
|||||||
@@ -78,15 +78,27 @@ def test_get_displays(client):
|
|||||||
|
|
||||||
|
|
||||||
def test_openapi_docs(client):
|
def test_openapi_docs(client):
|
||||||
"""Test OpenAPI documentation is available."""
|
"""Test OpenAPI documentation is available to authenticated clients."""
|
||||||
response = client.get("/openapi.json")
|
response = client.get("/openapi.json", headers=AUTH_HEADERS)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["info"]["version"] == __version__
|
assert data["info"]["version"] == __version__
|
||||||
|
|
||||||
|
|
||||||
def test_swagger_ui(client):
|
def test_swagger_ui(client):
|
||||||
"""Test Swagger UI is available."""
|
"""Test Swagger UI is available to authenticated clients."""
|
||||||
response = client.get("/docs")
|
response = client.get("/docs", headers=AUTH_HEADERS)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert "text/html" in response.headers["content-type"]
|
assert "text/html" in response.headers["content-type"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_docs_requires_auth(client):
    """An unauthenticated request for the OpenAPI schema is answered with 401."""
    unauthenticated = client.get("/openapi.json")
    assert unauthenticated.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_swagger_ui_requires_auth(client):
    """An unauthenticated request for the Swagger UI page is answered with 401."""
    unauthenticated = client.get("/docs")
    assert unauthenticated.status_code == 401
|
||||||
|
|||||||
Reference in New Issue
Block a user