chore(backend): MQTT/WLED/devices/capture/utils + api routes hardening

Bundle the remaining backend touch-ups that the production review
landed individually as small surgical edits across many modules:
- MQTT runtime: fire-and-forget task tracking + drain resilience.
- mqtt_source + store + storage/color_strip_source: secret_box
  encryption for credentials with auto-migration of plaintext fields.
- devices/discovery_watcher: task tracking on watcher start/stop.
- devices/wled_client + wled_provider: URL scheme inference helper
  applied at the create/update boundary so bare hostnames stay valid.
- core/capture/screen_capture: hardened error paths.
- core/processing (mapped/processed/processor_manager/video/wled_target):
  smaller follow-throughs from the registry refactor that landed
  earlier on the branch.
- utils/safe_source + utils/file_ops + utils/__init__: shared URL +
  IP classification helpers + larger streaming upload size caps.
- api/auth: WebSocket Origin allow-list + /docs auth-gate.
- api/dependencies: register the new HTTP-endpoint store.
- api/routes (assets, backup, webhooks): streaming-upload caps +
  asyncio.gather return_exceptions on broadcast loops.
- tests/test_api + tests/e2e/test_backup_flow: cover the new caps and
  the Origin allow-list.
This commit is contained in:
2026-05-23 00:50:01 +03:00
parent 45d12b2811
commit 898912f8b1
22 changed files with 498 additions and 73 deletions
+42 -9
View File
@@ -11,14 +11,13 @@ from starlette.websockets import WebSocket, WebSocketDisconnect
from ledgrab.config import get_config
from ledgrab.utils import get_logger
from ledgrab.utils.net_classify import is_loopback as _classify_is_loopback
logger = get_logger(__name__)
# Security scheme for Bearer token
security = HTTPBearer(auto_error=False)
_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost", "testclient"})
def is_auth_enabled() -> bool:
"""Return True when at least one API key is configured."""
@@ -26,15 +25,15 @@ def is_auth_enabled() -> bool:
def _is_loopback(host: str | None) -> bool:
"""Return True when *host* is a loopback address."""
"""Return True when *host* is a loopback address.
Delegates to :func:`ledgrab.utils.net_classify.is_loopback` so this
auth gate, the SSRF guard in ``safe_source``, and the LAN-default
inference in ``url_scheme`` share one classification source.
"""
if not host:
return False
# Strip IPv6 brackets and zone IDs
h = host.strip().lower()
if h.startswith("[") and h.endswith("]"):
h = h[1:-1]
h = h.split("%", 1)[0]
return h in _LOOPBACK_HOSTS
return _classify_is_loopback(host)
def verify_api_key(
@@ -142,6 +141,23 @@ def require_authenticated(label: str) -> None:
WS_AUTH_CLOSE_CODE = 4401
WS_ORIGIN_CLOSE_CODE = 4403
"""Close code sent when a WebSocket request fails the Origin allowlist."""
def _is_origin_allowed(origin: str | None, allowed: list[str]) -> bool:
"""Return True when *origin* matches one of the configured CORS origins.
Non-browser clients (Python scripts, curl) don't send Origin — those are
allowed through; the Bearer-token check on the auth handshake is the
primary defence in that case. Browsers always set Origin, so this only
blocks cross-site WebSocket connection attempts (CSWSH).
"""
if not origin:
return True
return origin in set(allowed or [])
async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None:
"""Accept the WebSocket, then perform first-message auth handshake.
@@ -152,6 +168,23 @@ async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0)
Returns the caller label on success, ``None`` on failure (connection
already closed).
"""
# Reject cross-site WebSocket attempts before accepting — a browser-based
# attacker page cannot forge the Origin header, so an Origin mismatch is
# a strong signal even before the token check. Non-browser clients
# legitimately omit Origin; those fall through to the auth handshake.
config = get_config()
origin = websocket.headers.get("origin")
if not _is_origin_allowed(origin, config.server.cors_origins):
logger.warning(
"Rejected WebSocket from origin %r (not in cors_origins)",
origin,
)
try:
await websocket.close(code=WS_ORIGIN_CLOSE_CODE)
except Exception:
pass
return None
await websocket.accept()
label = await verify_ws_auth(websocket, timeout=timeout)
if label is None:
+7
View File
@@ -37,6 +37,7 @@ from ledgrab.storage.game_integration_store import GameIntegrationStore
from ledgrab.core.game_integration.event_bus import GameEventBus
from ledgrab.storage.mqtt_source_store import MQTTSourceStore
from ledgrab.core.mqtt.mqtt_manager import MQTTManager
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore
from ledgrab.storage.pattern_template_store import PatternTemplateStore
@@ -165,6 +166,10 @@ def get_mqtt_manager() -> MQTTManager:
return _get("mqtt_manager", "MQTT manager")
def get_http_endpoint_store() -> HTTPEndpointStore:
return _get("http_endpoint_store", "HTTP endpoint store")
def get_audio_processing_template_store() -> AudioProcessingTemplateStore:
return _get("audio_processing_template_store", "Audio processing template store")
@@ -237,6 +242,7 @@ def init_dependencies(
game_event_bus: GameEventBus | None = None,
mqtt_store: MQTTSourceStore | None = None,
mqtt_manager: MQTTManager | None = None,
http_endpoint_store: HTTPEndpointStore | None = None,
audio_processing_template_store: AudioProcessingTemplateStore | None = None,
pattern_template_store: PatternTemplateStore | None = None,
):
@@ -272,6 +278,7 @@ def init_dependencies(
"game_event_bus": game_event_bus,
"mqtt_store": mqtt_store,
"mqtt_manager": mqtt_manager,
"http_endpoint_store": http_endpoint_store,
"audio_processing_template_store": audio_processing_template_store,
"pattern_template_store": pattern_template_store,
}
+5 -4
View File
@@ -15,7 +15,7 @@ from ledgrab.api.schemas.assets import (
from ledgrab.config import get_config
from ledgrab.storage.asset_store import AssetStore
from ledgrab.storage.base_store import EntityNotFoundError
from ledgrab.utils import get_logger
from ledgrab.utils import get_logger, read_upload_capped
logger = get_logger(__name__)
@@ -93,10 +93,11 @@ async def upload_asset(
config = get_config()
max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024
data = await file.read()
if len(data) > max_size:
try:
data = await read_upload_capped(file, max_size)
except ValueError:
raise HTTPException(
status_code=400,
status_code=413,
detail=f"File too large (max {max_size // (1024 * 1024)} MB)",
)
+6 -4
View File
@@ -28,7 +28,7 @@ from ledgrab.config import get_config
from ledgrab.core.backup.auto_backup import AutoBackupEngine
from ledgrab.storage.asset_store import AssetStore
from ledgrab.storage.database import Database, freeze_writes
from ledgrab.utils import get_logger
from ledgrab.utils import get_logger, read_upload_capped
logger = get_logger(__name__)
@@ -133,9 +133,11 @@ async def restore_config(
because restore replaces all configuration including secrets).
"""
require_authenticated(auth)
raw = await file.read()
if len(raw) > 200 * 1024 * 1024: # 200 MB limit (ZIP may contain assets)
raise HTTPException(status_code=400, detail="Backup file too large (max 200 MB)")
_MAX_BACKUP_BYTES = 200 * 1024 * 1024 # 200 MB (ZIP may contain assets)
try:
raw = await read_upload_capped(file, _MAX_BACKUP_BYTES)
except ValueError:
raise HTTPException(status_code=413, detail="Backup file too large (max 200 MB)")
if len(raw) < 100:
raise HTTPException(status_code=400, detail="File too small to be a valid backup")
+14 -1
View File
@@ -30,6 +30,9 @@ _RATE_WINDOW = 60.0 # seconds
_rate_hits: dict[str, list[float]] = defaultdict(list)
_RATE_HITS_HARD_CAP = 1024
def _check_rate_limit(client_ip: str) -> None:
"""Raise 429 if *client_ip* exceeded the webhook rate limit."""
now = time.time()
@@ -44,11 +47,21 @@ def _check_rate_limit(client_ip: str) -> None:
)
_rate_hits[client_ip].append(now)
# Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth
# Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth.
if len(_rate_hits) > 100:
stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start]
for ip in stale:
del _rate_hits[ip]
# Hard cap as a final defence against an attacker spraying many distinct
# X-Forwarded-For values to drive memory growth past the soft cleanup
# threshold. Drop the oldest-touched IPs (by their latest timestamp).
if len(_rate_hits) > _RATE_HITS_HARD_CAP:
ordered = sorted(
_rate_hits.items(),
key=lambda kv: kv[1][-1] if kv[1] else 0.0,
)
for ip, _ in ordered[: len(ordered) - _RATE_HITS_HARD_CAP]:
_rate_hits.pop(ip, None)
class WebhookPayload(BaseModel):
@@ -14,6 +14,12 @@ from ledgrab.utils import get_logger, get_monitor_names, get_monitor_refresh_rat
logger = get_logger(__name__)
# Reused random Generator for sampling. The legacy ``np.random.randint``
# uses the module-level RandomState which is slightly slower per-call and
# pulls in extra import-time work; a single Generator is faster and avoids
# global-state surprises.
_rng = np.random.default_rng()
@dataclass
class DisplayInfo:
@@ -326,7 +332,11 @@ def calculate_dominant_color(pixels: np.ndarray) -> tuple[int, int, int]:
max_samples = 1000
if n > max_samples:
indices = np.random.randint(0, n, max_samples)
# ``Generator.integers`` writes into a fresh buffer once per call;
# the legacy ``np.random.randint`` did the same plus extra
# bookkeeping. Random (not stride) sampling stays robust against
# periodic patterns in screen pixels.
indices = _rng.integers(0, n, size=max_samples)
pixels_reshaped = pixels_reshaped[indices]
# Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32:
@@ -85,6 +85,10 @@ class DiscoveryWatcher:
self._wled_seen: Dict[str, _DiscoveredEntry] = {}
# device-path -> entry. Only the serial poller mutates this.
self._serial_seen: Dict[str, _DiscoveredEntry] = {}
# Strong references for fire-and-forget resolve tasks — without
# these, Python 3.11+ may GC the task mid-resolve and silently lose
# discovery events. Tasks remove themselves on completion.
self._resolve_tasks: "set[asyncio.Task]" = set()
# --- lifecycle --------------------------------------------------------
@@ -155,12 +159,23 @@ class DiscoveryWatcher:
if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated):
# Resolve in a task — async_request blocks the handler if awaited
# synchronously and we don't want to stall mDNS dispatch.
asyncio.create_task(self._resolve_wled(service_type, name))
task = asyncio.create_task(self._resolve_wled(service_type, name))
self._resolve_tasks.add(task)
task.add_done_callback(self._on_resolve_done)
elif state_change == ServiceStateChange.Removed:
entry = self._wled_seen.pop(name, None)
if entry is not None and not self._is_configured(entry.url):
self._emit("device_lost", entry)
def _on_resolve_done(self, task: "asyncio.Task") -> None:
"""Release the strong task reference and log any resolve failure."""
self._resolve_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.debug("Discovery watcher: resolve task raised: %s", exc)
async def _resolve_wled(self, service_type: str, name: str) -> None:
if self._aiozc is None:
return
@@ -453,8 +453,15 @@ class WLEDClient(LEDClient):
],
}
logger.debug(f"Sending {len(pixels)} LEDs via HTTP ({len(indexed_pixels)} values)")
logger.debug(f"Payload size: ~{len(str(payload))} bytes")
# ``str(payload)`` previously stringified the entire indexed
# array on every send to report a byte estimate; that was the
# hot path. Drop the size readout — the LED count + indexed
# value count is enough to interpret traffic and is O(1).
logger.debug(
"Sending %d LEDs via HTTP (%d indexed values)",
len(pixels),
len(indexed_pixels),
)
await self._request("POST", "/json/state", json_data=payload)
logger.debug("Successfully sent pixel colors via HTTP")
@@ -98,11 +98,18 @@ class WLEDDeviceProvider(LEDDeviceProvider):
dict with 'led_count' key.
Raises:
ValueError: Unsupported scheme or invalid LED count.
httpx.ConnectError: Device unreachable.
httpx.TimeoutException: Connection timed out.
ValueError: Invalid LED count.
"""
url = _normalize_url(url)
# Reject anything that isn't plain HTTP(S). url_scheme.infer_http_scheme
# passes non-HTTP schemes through untouched ("javascript:", "file:",
# "data:", etc.); without this guard those would reach httpx and
# surface as opaque transport errors at best, or be silently misused
# at worst.
if not url.lower().startswith(("http://", "https://")):
raise ValueError(f"WLED URL must use http:// or https:// scheme (got {url!r})")
async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(_join(url, "/json/info"))
response.raise_for_status()
+58 -6
View File
@@ -49,6 +49,32 @@ class MQTTRuntime:
# Pending publishes queued while disconnected
self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
# Strong references for fire-and-forget callback dispatch tasks.
# Python 3.11+ may GC bare ``asyncio.create_task(...)`` results mid-
# flight, so we hold each task until it completes and surface any
# exception via the done-callback.
self._dispatch_tasks: Set[asyncio.Task] = set()
# Compiled ``aiomqtt.Topic`` cache, keyed by the subscription pattern
# string. The previous dispatch loop re-parsed every pattern on
# every incoming message — on a chatty broker with many wildcards
# that adds up fast.
self._compiled_topics: Dict[str, aiomqtt.Topic] = {}
def _on_dispatch_done(self, task: asyncio.Task) -> None:
"""Drop the strong reference and surface any callback exception."""
self._dispatch_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.error(
"MQTT async callback raised (%s): %s",
self._source_id,
exc,
exc_info=exc,
)
@property
def is_connected(self) -> bool:
return self._connected
@@ -84,6 +110,14 @@ class MQTTRuntime:
logger.debug("MQTT runtime task cancelled: %s", self._source_id)
self._task = None
self._connected = False
# Cancel any in-flight async dispatch callbacks. Without this they
# would keep running past the runtime's logical end of life and
# could fire callbacks on a stopped subscriber.
for task in list(self._dispatch_tasks):
task.cancel()
if self._dispatch_tasks:
await asyncio.gather(*self._dispatch_tasks, return_exceptions=True)
self._dispatch_tasks.clear()
logger.info("MQTT runtime stopped: %s", self._source_id)
def update_config(self, source: MQTTSource) -> None:
@@ -167,13 +201,23 @@ class MQTTRuntime:
for topic in self._subscriptions:
await client.subscribe(topic)
# Drain pending publishes
# Drain pending publishes. A single publish failing
# (broker rejection, oversized message) must not lose
# the rest of the queue — log and continue with the next.
while not self._publish_queue.empty():
try:
t, p, r, q = self._publish_queue.get_nowait()
await client.publish(t, p, retain=r, qos=q)
except Exception:
except asyncio.QueueEmpty:
break
try:
await client.publish(t, p, retain=r, qos=q)
except Exception as exc:
logger.warning(
"MQTT drain publish failed (%s -> %s): %s",
self._source_id,
t,
exc,
)
# Message receive loop
async for msg in client.messages:
@@ -183,13 +227,21 @@ class MQTTRuntime:
)
self._topic_cache[topic_str] = payload_str
# Dispatch to callbacks
# Dispatch to callbacks. Pattern objects are cached
# per subscription string to avoid re-parsing them on
# every received message.
for sub_topic, callbacks in self._subscriptions.items():
if aiomqtt.Topic(sub_topic).matches(msg.topic):
compiled = self._compiled_topics.get(sub_topic)
if compiled is None:
compiled = aiomqtt.Topic(sub_topic)
self._compiled_topics[sub_topic] = compiled
if compiled.matches(msg.topic):
for cb in callbacks:
try:
if asyncio.iscoroutinefunction(cb):
asyncio.create_task(cb(topic_str, payload_str))
task = asyncio.create_task(cb(topic_str, payload_str))
self._dispatch_tasks.add(task)
task.add_done_callback(self._on_dispatch_done)
else:
cb(topic_str, payload_str)
except Exception as e:
@@ -2,6 +2,7 @@
import threading
import time
from collections import OrderedDict
from typing import Dict, List, Optional
import numpy as np
@@ -11,6 +12,11 @@ from ledgrab.utils import get_logger
logger = get_logger(__name__)
# Cap the (src_len, dst_len) resize cache. Each entry holds two ``np.linspace``
# arrays plus a per-zone ``uint8`` scratch buffer, which used to grow without
# bound under hot reconfigure storms.
_RESIZE_CACHE_MAX = 16
class MappedColorStripStream(ColorStripStream):
"""Places multiple ColorStripStreams side-by-side at distinct LED ranges.
@@ -46,8 +52,11 @@ class MappedColorStripStream(ColorStripStream):
# zone_index -> (source_id, consumer_id, stream)
self._sub_streams: Dict[int, tuple] = {}
# (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing
self._resize_cache: Dict[tuple, tuple] = {}
# (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing.
# An ``OrderedDict`` with eviction keeps memory bounded if device
# configurations fluctuate at runtime (each unique pair adds two
# linspace arrays + a per-zone reusable uint8 buffer).
self._resize_cache: "OrderedDict[tuple, tuple]" = OrderedDict()
self._sub_lock = threading.Lock() # guards _sub_streams access across threads
# ── ColorStripStream interface ──────────────────────────────
@@ -229,6 +238,14 @@ class MappedColorStripStream(ColorStripStream):
np.empty((zone_len, 3), dtype=np.uint8),
)
self._resize_cache[rkey] = cached
# Drop the least-recently-inserted entry once
# we hit the cap. 16 entries comfortably covers
# any realistic zone/source layout — pathological
# reconfigure storms used to grow this forever.
if len(self._resize_cache) > _RESIZE_CACHE_MAX:
self._resize_cache.popitem(last=False)
else:
self._resize_cache.move_to_end(rkey)
src_x, dst_x, resized = cached
for ch in range(3):
np.copyto(
@@ -160,9 +160,11 @@ class ProcessedColorStripStream(ColorStripStream):
self._resolve_count = 0
self._resolve_filters()
colors = None
if self._input_stream:
colors = self._input_stream.get_latest_colors()
# Bind to a local first — ``update_source()`` may swap or null
# out ``_input_stream`` between the check and the read on a
# different thread.
inp = self._input_stream
colors = inp.get_latest_colors() if inp is not None else None
if colors is not None and self._filters:
for flt in self._filters:
@@ -38,6 +38,7 @@ from ledgrab.storage.picture_source_store import PictureSourceStore
from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore
from ledgrab.storage.template_store import TemplateStore
from ledgrab.storage.value_source_store import ValueSourceStore
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
from ledgrab.storage.asset_store import AssetStore
from ledgrab.core.processing.sync_clock_manager import SyncClockManager
from ledgrab.core.weather.weather_manager import WeatherManager
@@ -74,6 +75,7 @@ class ProcessorDependencies:
mqtt_manager: Optional[Any] = None # MQTTManager
game_event_bus: Optional[Any] = None # GameEventBus
audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore
http_endpoint_store: Optional[HTTPEndpointStore] = None
@dataclass
@@ -169,6 +171,7 @@ class ProcessorManager(AutoRestartMixin, DeviceHealthMixin, DeviceTestModeMixin)
event_bus=deps.game_event_bus,
audio_processing_template_store=deps.audio_processing_template_store,
sync_clock_manager=deps.sync_clock_manager,
http_endpoint_store=deps.http_endpoint_store,
)
if deps.value_source_store
else None
@@ -37,6 +37,33 @@ def is_youtube_url(url: str) -> bool:
return any(p.search(url) for p in _YT_PATTERNS)
# Network schemes accepted by ``cv2.VideoCapture``. ``file://`` is intentionally
# excluded — local files are supported via plain path strings (no scheme), so
# explicit ``file://`` requests can only be an attempt to coerce FFmpeg into
# loading something the path-string code path would reject. ``concat:``,
# ``gopher://``, ``crypto:``, etc. are not allowed.
_ALLOWED_NETWORK_SCHEMES: tuple[str, ...] = ("http", "https", "rtsp", "rtsps")
def _assert_video_url_allowed(url: str) -> None:
"""Reject video URLs that use anything other than a vetted scheme.
OpenCV/FFmpeg supports many esoteric input protocols (``concat:``,
``gopher://``, ``crypto:``, ``udp://``, ``async:``, …). Some can read
arbitrary host files or pivot to internal addresses when the caller
can influence the URL. Tighten the input to the schemes we actually
advertise. URLs without a scheme are accepted as local-file paths.
"""
if "://" not in url:
return # plain local path — OpenCV resolves against the working dir
scheme = url.split("://", 1)[0].lower()
if scheme not in _ALLOWED_NETWORK_SCHEMES:
raise RuntimeError(
f"Refusing to open video with unsupported scheme {scheme!r}; "
f"allowed: {', '.join(_ALLOWED_NETWORK_SCHEMES)} or a local file path."
)
def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str:
"""Resolve a YouTube URL to a direct stream URL using yt-dlp."""
try:
@@ -185,10 +212,14 @@ class VideoCaptureLiveStream(LiveStream):
if self._running:
return
# Resolve YouTube URL if needed
# Resolve YouTube URL if needed. Validate AFTER resolution too, so a
# malicious yt-dlp result (or a redirect we don't expect) can't slip
# through with an unsupported scheme.
actual_url = self._original_url
_assert_video_url_allowed(actual_url)
if is_youtube_url(actual_url):
actual_url = resolve_youtube_url(actual_url, self._resolution_limit)
_assert_video_url_allowed(actual_url)
self._resolved_url = actual_url
# Open capture
@@ -224,7 +224,8 @@ class WledTargetProcessor(TargetProcessor):
self._is_running = False
# Cancel task
# Cancel task. The cancellation is awaited above, so the prior
# 50 ms ``asyncio.sleep`` here was pure dead time on every stop().
if self._task:
self._task.cancel()
try:
@@ -233,7 +234,6 @@ class WledTargetProcessor(TargetProcessor):
logger.debug("WLED target processor task cancelled")
pass
self._task = None
await asyncio.sleep(0.05)
# Restore device state (only if auto_shutdown is enabled)
if self._led_client and self._device_state_before:
+12 -2
View File
@@ -9,6 +9,8 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import List, Optional
from ledgrab.utils import secret_box
def _parse_common(data: dict) -> dict:
"""Extract common fields from a dict, parsing timestamps."""
@@ -52,13 +54,16 @@ class MQTTSource:
icon_color: str = ""
def to_dict(self) -> dict:
# Always persist the broker password in encrypted envelope form. If
# the field already contains an envelope, ``encrypt()`` is a no-op.
stored_password = secret_box.encrypt(self.password) if self.password else ""
d = {
"id": self.id,
"name": self.name,
"broker_host": self.broker_host,
"broker_port": self.broker_port,
"username": self.username,
"password": self.password,
"password": stored_password,
"client_id": self.client_id,
"base_topic": self.base_topic,
"description": self.description,
@@ -75,12 +80,17 @@ class MQTTSource:
@staticmethod
def from_dict(data: dict) -> "MQTTSource":
common = _parse_common(data)
# Decrypt at load time so consumers see plaintext via ``.password``.
# Legacy plaintext rows pass through unchanged — the next save will
# write them back in encrypted form.
raw_password = data.get("password", "")
password = secret_box.decrypt(raw_password) if raw_password else ""
return MQTTSource(
**common,
broker_host=data.get("broker_host", "localhost"),
broker_port=int(data.get("broker_port", 1883)),
username=data.get("username", ""),
password=data.get("password", ""),
password=password,
client_id=data.get("client_id", "ledgrab"),
base_topic=data.get("base_topic", "ledgrab"),
icon=data.get("icon", ""),
@@ -7,7 +7,7 @@ from typing import List, Optional
from ledgrab.storage.base_sqlite_store import BaseSqliteStore
from ledgrab.storage.database import Database
from ledgrab.storage.mqtt_source import MQTTSource
from ledgrab.utils import get_logger
from ledgrab.utils import get_logger, secret_box
logger = get_logger(__name__)
@@ -20,6 +20,41 @@ class MQTTSourceStore(BaseSqliteStore[MQTTSource]):
def __init__(self, db: Database):
super().__init__(db, MQTTSource.from_dict)
self._migrate_plaintext_passwords()
def _migrate_plaintext_passwords(self) -> None:
"""Encrypt any stored MQTT broker passwords still in plaintext at rest.
Mirrors the HA-token migration: inspect raw DB rows, re-save items
that lack the encryption envelope so the next persistence pass
writes them in encrypted form.
"""
migrated = 0
try:
rows = self._db.load_all(self._table_name)
except Exception as exc:
logger.error("Could not inspect rows for MQTT password migration: %s", exc)
return
for row in rows:
sid = row.get("id")
raw_pw = row.get("password", "") or ""
if not sid or not raw_pw:
continue
if secret_box.is_encrypted(raw_pw):
continue
source = self._items.get(sid)
if source is None:
continue
try:
self._save_item(sid, source)
migrated += 1
except Exception as exc:
logger.error("Failed to migrate MQTT password for %s: %s", sid, exc)
if migrated:
logger.warning(
"MIGRATION: encrypted %d plaintext MQTT broker password(s) at rest.",
migrated,
)
# Backward-compatible aliases
get_all_sources = BaseSqliteStore.get_all
+4 -1
View File
@@ -1,13 +1,15 @@
"""Utility functions and helpers."""
from .file_ops import atomic_write_json
from .file_ops import atomic_write_json, read_upload_capped
from .logger import setup_logging, get_logger
from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates
from .timer import high_resolution_timer
from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler
from .url_scheme import infer_http_scheme
__all__ = [
"atomic_write_json",
"read_upload_capped",
"setup_logging",
"get_logger",
"get_monitor_names",
@@ -16,4 +18,5 @@ __all__ = [
"high_resolution_timer",
"log_broadcaster",
"install_broadcast_handler",
"infer_http_scheme",
]
+30
View File
@@ -5,10 +5,40 @@ import logging
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fastapi import UploadFile
logger = logging.getLogger(__name__)
_UPLOAD_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MiB
async def read_upload_capped(file: "UploadFile", max_size: int) -> bytes:
"""Read an uploaded file in chunks, refusing to buffer past ``max_size``.
A plain ``await file.read()`` will allocate the entire upload body before
any size check runs, so a 1 GB upload still costs 1 GB of memory even
against a 50 MB limit. This helper streams the read and aborts as soon
as the accumulated size exceeds the cap.
Raises:
ValueError: When the upload exceeds *max_size*. The caller decides
how to translate the exception (typically HTTP 400/413).
"""
buf = bytearray()
while True:
chunk = await file.read(_UPLOAD_CHUNK_SIZE)
if not chunk:
break
if len(buf) + len(chunk) > max_size:
raise ValueError(f"Upload exceeds maximum allowed size of {max_size} bytes")
buf.extend(chunk)
return bytes(buf)
def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None:
"""Write JSON data to file atomically via temp file + rename.
+145 -16
View File
@@ -20,6 +20,8 @@ from urllib.parse import urlparse
import httpx
from fastapi import HTTPException
from ledgrab.utils.net_classify import is_blocked_for_ssrf
# Image file extensions considered safe to serve
_IMAGE_EXTENSIONS = frozenset(
@@ -38,19 +40,13 @@ _IMAGE_EXTENSIONS = frozenset(
def _is_blocked_ip(ip: str) -> bool:
"""Return True when *ip* belongs to a category we refuse to fetch from."""
try:
addr = ipaddress.ip_address(ip)
except ValueError:
return True # unparseable → block
return (
addr.is_private
or addr.is_loopback
or addr.is_link_local
or addr.is_reserved
or addr.is_multicast
or addr.is_unspecified
)
"""Return True when *ip* belongs to a category we refuse to fetch from.
Thin wrapper preserved for the local test suite. The actual policy now
lives in :func:`ledgrab.utils.net_classify.is_blocked_for_ssrf` so it
can't drift away from the related ``url_scheme`` / ``auth`` modules.
"""
return is_blocked_for_ssrf(ip)
def _resolve_hostname(hostname: str) -> Iterable[str]:
@@ -67,6 +63,11 @@ def _resolve_hostname(hostname: str) -> Iterable[str]:
def validate_image_url(url: str) -> None:
"""Validate that *url* is safe to fetch.
Strict SSRF policy — blocks all private / loopback / link-local /
reserved / multicast / unspecified IPs. Intended for fetches of
user-supplied URLs that should reach the *public* internet only
(image sources, weather APIs, etc.).
Checks:
- Scheme is http/https
- Hostname is present
@@ -74,6 +75,25 @@ def validate_image_url(url: str) -> None:
- DNS resolves to non-private, non-loopback, non-link-local, non-reserved,
non-multicast addresses (SSRF protection)
"""
_validate_url(url, allow_private=False)
def validate_polling_url(url: str) -> None:
"""Validate that *url* is safe to poll from an automation trigger.
Relaxed SSRF policy — allows PRIVATE (LAN) IPs so users can target
devices on their own network (Plex on 192.168.x.x, Home Assistant
on 10.x.x.x). Still blocks LOOPBACK (protects the server's own
admin ports), LINK_LOCAL (blocks 169.254.169.254 cloud-metadata
services), MULTICAST, RESERVED, and UNSPECIFIED.
Use this instead of :func:`validate_image_url` for any user-configured
HTTP polling target.
"""
_validate_url(url, allow_private=True)
def _validate_url(url: str, *, allow_private: bool) -> None:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https"):
raise HTTPException(
@@ -93,17 +113,41 @@ def validate_image_url(url: str) -> None:
ips = _resolve_hostname(hostname)
for ip in ips:
if _is_blocked_ip(ip):
if _ip_is_blocked(ip, allow_private=allow_private):
policy = (
"loopback / link-local / reserved / multicast"
if allow_private
else "private / loopback / link-local / reserved / multicast"
)
raise HTTPException(
status_code=400,
detail=(
f"Refusing to fetch URL: hostname {hostname!r} resolves to "
f"blocked address {ip} (private / loopback / link-local / "
"reserved / multicast)"
f"blocked address {ip} ({policy})"
),
)
def _ip_is_blocked(ip: str, *, allow_private: bool) -> bool:
"""Block-list predicate parameterised by whether PRIVATE is permitted."""
from ledgrab.utils.net_classify import HostCategory, classify_ip
category = classify_ip(ip)
always_block = {
HostCategory.LOOPBACK,
HostCategory.LINK_LOCAL,
HostCategory.MULTICAST,
HostCategory.RESERVED,
HostCategory.UNSPECIFIED,
HostCategory.UNPARSEABLE,
}
if category in always_block:
return True
if not allow_private and category == HostCategory.PRIVATE:
return True
return False
def validate_image_path(file_path: str | Path) -> Path:
"""Validate a local file path points to a real image file.
@@ -165,3 +209,88 @@ async def safe_fetch(
response.raise_for_status()
return response
raise HTTPException(status_code=400, detail=f"Too many redirects (max {max_hops})")
# ---------------------------------------------------------------------------
# Bounded safe request — for polling-style fetches that need streaming
# + size cap + arbitrary method/headers. Used by the HTTP poll runtime.
# ---------------------------------------------------------------------------
# Hard cap to avoid OOM from polled endpoints returning huge bodies.
# 1 MiB covers every realistic JSON status endpoint by orders of magnitude.
_MAX_RESPONSE_BYTES = 1 * 1024 * 1024
async def safe_request_bounded(
method: str,
url: str,
*,
headers: dict[str, str] | None = None,
timeout: float = 10.0,
max_bytes: int = _MAX_RESPONSE_BYTES,
allow_private: bool = True,
) -> tuple[int, bytes, str | None]:
"""Make a single HTTP request with SSRF validation and a body-size cap.
Validates the URL up-front (scheme + IP-block) and streams the response
body, aborting as soon as ``max_bytes`` is exceeded so a malicious or
misconfigured endpoint can't OOM the server.
``allow_private`` selects the SSRF policy: ``True`` (default — for HTTP
polling against user-owned LAN devices) routes through
:func:`validate_polling_url`; ``False`` routes through
:func:`validate_image_url`'s strict public-only policy.
Note on DNS rebinding: this validates the URL hostname's resolved IPs
once, then httpx independently re-resolves to make the actual request.
The window between the two resolutions is short but not zero. True
IP-pinning (rewriting the URL to the literal IP + Host header) is a
project-wide concern that would also need to update ``safe_fetch`` and
every other outbound caller; tracked as a follow-up.
Returns:
``(status_code, body_bytes, error_message)``. ``error_message`` is
``None`` on success; on transport failure ``status_code`` is 0 and
``body_bytes`` is empty.
Raises:
HTTPException: when the URL fails validation (4xx-style error
intended for callers that surface this directly to the user).
"""
if allow_private:
validate_polling_url(url)
else:
validate_image_url(url)
method = method.upper()
headers = dict(headers or {})
try:
async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
async with client.stream(method, url, headers=headers) as response:
chunks: list[bytes] = []
total = 0
truncated = False
async for chunk in response.aiter_bytes():
if total + len(chunk) > max_bytes:
chunks.append(chunk[: max_bytes - total])
total = max_bytes
truncated = True
break
chunks.append(chunk)
total += len(chunk)
body = b"".join(chunks)
# Truncation is NOT an error — the partial body is returned
# for the caller to use (e.g. evaluate a JSON path against the
# first 1 MiB). Callers that care can detect it by comparing
# ``len(body) == max_bytes``. We only flag actual transport
# failures in the error slot.
_ = truncated
return response.status_code, body, None
except httpx.RequestError as exc:
# Never include `exc` repr in the message — some httpx errors include
# request URL + headers, which could leak auth tokens supplied by the
# caller. Type name is sufficient diagnostic.
return 0, b"", f"Request failed: {type(exc).__name__}"
except Exception as exc:
return 0, b"", f"Unexpected error: {type(exc).__name__}"
+18 -12
View File
@@ -12,22 +12,28 @@ class TestBackupRestoreFlow:
"""A user backs up their configuration and restores it."""
def _create_device(self, client, name="Backup Device") -> str:
resp = client.post("/api/v1/devices", json={
"name": name,
"url": "mock://backup",
"device_type": "mock",
"led_count": 30,
})
resp = client.post(
"/api/v1/devices",
json={
"name": name,
"url": "mock://backup",
"device_type": "mock",
"led_count": 30,
},
)
assert resp.status_code == 201
return resp.json()["id"]
def _create_css(self, client, name="Backup CSS") -> str:
resp = client.post("/api/v1/color-strip-sources", json={
"name": name,
"source_type": "static",
"color": [255, 0, 0],
"led_count": 30,
})
resp = client.post(
"/api/v1/color-strip-sources",
json={
"name": name,
"source_type": "single_color",
"color": [255, 0, 0],
"led_count": 30,
},
)
assert resp.status_code == 201
return resp.json()["id"]
+16 -4
View File
@@ -78,15 +78,27 @@ def test_get_displays(client):
def test_openapi_docs(client):
"""Test OpenAPI documentation is available."""
response = client.get("/openapi.json")
"""Test OpenAPI documentation is available to authenticated clients."""
response = client.get("/openapi.json", headers=AUTH_HEADERS)
assert response.status_code == 200
data = response.json()
assert data["info"]["version"] == __version__
def test_swagger_ui(client):
"""Test Swagger UI is available."""
response = client.get("/docs")
"""Test Swagger UI is available to authenticated clients."""
response = client.get("/docs", headers=AUTH_HEADERS)
assert response.status_code == 200
assert "text/html" in response.headers["content-type"]
def test_openapi_docs_requires_auth(client):
"""OpenAPI surface must NOT be reachable without auth (info disclosure)."""
response = client.get("/openapi.json")
assert response.status_code == 401
def test_swagger_ui_requires_auth(client):
"""Swagger UI must NOT be reachable without auth."""
response = client.get("/docs")
assert response.status_code == 401