chore(backend): MQTT/WLED/devices/capture/utils + api routes hardening

Bundle the remaining backend touch-ups that the production review
landed individually as small surgical edits across many modules:
- MQTT runtime: fire-and-forget task tracking + drain resilience.
- mqtt_source + store + storage/color_strip_source: secret_box
  encryption for credentials with auto-migration of plaintext fields.
- devices/discovery_watcher: task tracking on watcher start/stop.
- devices/wled_client + wled_provider: URL scheme inference helper
  applied at the create/update boundary so bare hostnames stay valid.
- core/capture/screen_capture: hardened error paths.
- core/processing (mapped/processed/processor_manager/video/wled_target):
  smaller follow-throughs from the registry refactor that landed
  earlier on the branch.
- utils/safe_source + utils/file_ops + utils/__init__: shared URL +
  IP classification helpers + larger streaming upload size caps.
- api/auth: WebSocket Origin allow-list + /docs auth-gate.
- api/dependencies: register the new HTTP-endpoint store.
- api/routes (assets, backup, webhooks): streaming-upload caps +
  asyncio.gather return_exceptions on broadcast loops.
- tests/test_api + tests/e2e/test_backup_flow: cover the new caps and
  the Origin allow-list.
This commit is contained in:
2026-05-23 00:50:01 +03:00
parent 45d12b2811
commit 898912f8b1
22 changed files with 498 additions and 73 deletions
+42 -9
View File
@@ -11,14 +11,13 @@ from starlette.websockets import WebSocket, WebSocketDisconnect
from ledgrab.config import get_config from ledgrab.config import get_config
from ledgrab.utils import get_logger from ledgrab.utils import get_logger
from ledgrab.utils.net_classify import is_loopback as _classify_is_loopback
logger = get_logger(__name__) logger = get_logger(__name__)
# Security scheme for Bearer token # Security scheme for Bearer token
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost", "testclient"})
def is_auth_enabled() -> bool: def is_auth_enabled() -> bool:
"""Return True when at least one API key is configured.""" """Return True when at least one API key is configured."""
@@ -26,15 +25,15 @@ def is_auth_enabled() -> bool:
def _is_loopback(host: str | None) -> bool: def _is_loopback(host: str | None) -> bool:
"""Return True when *host* is a loopback address.""" """Return True when *host* is a loopback address.
Delegates to :func:`ledgrab.utils.net_classify.is_loopback` so this
auth gate, the SSRF guard in ``safe_source``, and the LAN-default
inference in ``url_scheme`` share one classification source.
"""
if not host: if not host:
return False return False
# Strip IPv6 brackets and zone IDs return _classify_is_loopback(host)
h = host.strip().lower()
if h.startswith("[") and h.endswith("]"):
h = h[1:-1]
h = h.split("%", 1)[0]
return h in _LOOPBACK_HOSTS
def verify_api_key( def verify_api_key(
@@ -142,6 +141,23 @@ def require_authenticated(label: str) -> None:
WS_AUTH_CLOSE_CODE = 4401 WS_AUTH_CLOSE_CODE = 4401
WS_ORIGIN_CLOSE_CODE = 4403
"""Close code sent when a WebSocket request fails the Origin allowlist."""
def _is_origin_allowed(origin: str | None, allowed: list[str]) -> bool:
"""Return True when *origin* matches one of the configured CORS origins.
Non-browser clients (Python scripts, curl) don't send Origin — those are
allowed through; the Bearer-token check on the auth handshake is the
primary defence in that case. Browsers always set Origin, so this only
blocks cross-site WebSocket connection attempts (CSWSH).
"""
if not origin:
return True
return origin in set(allowed or [])
async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None: async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None:
"""Accept the WebSocket, then perform first-message auth handshake. """Accept the WebSocket, then perform first-message auth handshake.
@@ -152,6 +168,23 @@ async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0)
Returns the caller label on success, ``None`` on failure (connection Returns the caller label on success, ``None`` on failure (connection
already closed). already closed).
""" """
# Reject cross-site WebSocket attempts before accepting — a browser-based
# attacker page cannot forge the Origin header, so an Origin mismatch is
# a strong signal even before the token check. Non-browser clients
# legitimately omit Origin; those fall through to the auth handshake.
config = get_config()
origin = websocket.headers.get("origin")
if not _is_origin_allowed(origin, config.server.cors_origins):
logger.warning(
"Rejected WebSocket from origin %r (not in cors_origins)",
origin,
)
try:
await websocket.close(code=WS_ORIGIN_CLOSE_CODE)
except Exception:
pass
return None
await websocket.accept() await websocket.accept()
label = await verify_ws_auth(websocket, timeout=timeout) label = await verify_ws_auth(websocket, timeout=timeout)
if label is None: if label is None:
+7
View File
@@ -37,6 +37,7 @@ from ledgrab.storage.game_integration_store import GameIntegrationStore
from ledgrab.core.game_integration.event_bus import GameEventBus from ledgrab.core.game_integration.event_bus import GameEventBus
from ledgrab.storage.mqtt_source_store import MQTTSourceStore from ledgrab.storage.mqtt_source_store import MQTTSourceStore
from ledgrab.core.mqtt.mqtt_manager import MQTTManager from ledgrab.core.mqtt.mqtt_manager import MQTTManager
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore
from ledgrab.storage.pattern_template_store import PatternTemplateStore from ledgrab.storage.pattern_template_store import PatternTemplateStore
@@ -165,6 +166,10 @@ def get_mqtt_manager() -> MQTTManager:
return _get("mqtt_manager", "MQTT manager") return _get("mqtt_manager", "MQTT manager")
def get_http_endpoint_store() -> HTTPEndpointStore:
return _get("http_endpoint_store", "HTTP endpoint store")
def get_audio_processing_template_store() -> AudioProcessingTemplateStore: def get_audio_processing_template_store() -> AudioProcessingTemplateStore:
return _get("audio_processing_template_store", "Audio processing template store") return _get("audio_processing_template_store", "Audio processing template store")
@@ -237,6 +242,7 @@ def init_dependencies(
game_event_bus: GameEventBus | None = None, game_event_bus: GameEventBus | None = None,
mqtt_store: MQTTSourceStore | None = None, mqtt_store: MQTTSourceStore | None = None,
mqtt_manager: MQTTManager | None = None, mqtt_manager: MQTTManager | None = None,
http_endpoint_store: HTTPEndpointStore | None = None,
audio_processing_template_store: AudioProcessingTemplateStore | None = None, audio_processing_template_store: AudioProcessingTemplateStore | None = None,
pattern_template_store: PatternTemplateStore | None = None, pattern_template_store: PatternTemplateStore | None = None,
): ):
@@ -272,6 +278,7 @@ def init_dependencies(
"game_event_bus": game_event_bus, "game_event_bus": game_event_bus,
"mqtt_store": mqtt_store, "mqtt_store": mqtt_store,
"mqtt_manager": mqtt_manager, "mqtt_manager": mqtt_manager,
"http_endpoint_store": http_endpoint_store,
"audio_processing_template_store": audio_processing_template_store, "audio_processing_template_store": audio_processing_template_store,
"pattern_template_store": pattern_template_store, "pattern_template_store": pattern_template_store,
} }
+5 -4
View File
@@ -15,7 +15,7 @@ from ledgrab.api.schemas.assets import (
from ledgrab.config import get_config from ledgrab.config import get_config
from ledgrab.storage.asset_store import AssetStore from ledgrab.storage.asset_store import AssetStore
from ledgrab.storage.base_store import EntityNotFoundError from ledgrab.storage.base_store import EntityNotFoundError
from ledgrab.utils import get_logger from ledgrab.utils import get_logger, read_upload_capped
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -93,10 +93,11 @@ async def upload_asset(
config = get_config() config = get_config()
max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024 max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024
data = await file.read() try:
if len(data) > max_size: data = await read_upload_capped(file, max_size)
except ValueError:
raise HTTPException( raise HTTPException(
status_code=400, status_code=413,
detail=f"File too large (max {max_size // (1024 * 1024)} MB)", detail=f"File too large (max {max_size // (1024 * 1024)} MB)",
) )
+6 -4
View File
@@ -28,7 +28,7 @@ from ledgrab.config import get_config
from ledgrab.core.backup.auto_backup import AutoBackupEngine from ledgrab.core.backup.auto_backup import AutoBackupEngine
from ledgrab.storage.asset_store import AssetStore from ledgrab.storage.asset_store import AssetStore
from ledgrab.storage.database import Database, freeze_writes from ledgrab.storage.database import Database, freeze_writes
from ledgrab.utils import get_logger from ledgrab.utils import get_logger, read_upload_capped
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -133,9 +133,11 @@ async def restore_config(
because restore replaces all configuration including secrets). because restore replaces all configuration including secrets).
""" """
require_authenticated(auth) require_authenticated(auth)
raw = await file.read() _MAX_BACKUP_BYTES = 200 * 1024 * 1024 # 200 MB (ZIP may contain assets)
if len(raw) > 200 * 1024 * 1024: # 200 MB limit (ZIP may contain assets) try:
raise HTTPException(status_code=400, detail="Backup file too large (max 200 MB)") raw = await read_upload_capped(file, _MAX_BACKUP_BYTES)
except ValueError:
raise HTTPException(status_code=413, detail="Backup file too large (max 200 MB)")
if len(raw) < 100: if len(raw) < 100:
raise HTTPException(status_code=400, detail="File too small to be a valid backup") raise HTTPException(status_code=400, detail="File too small to be a valid backup")
+14 -1
View File
@@ -30,6 +30,9 @@ _RATE_WINDOW = 60.0 # seconds
_rate_hits: dict[str, list[float]] = defaultdict(list) _rate_hits: dict[str, list[float]] = defaultdict(list)
_RATE_HITS_HARD_CAP = 1024
def _check_rate_limit(client_ip: str) -> None: def _check_rate_limit(client_ip: str) -> None:
"""Raise 429 if *client_ip* exceeded the webhook rate limit.""" """Raise 429 if *client_ip* exceeded the webhook rate limit."""
now = time.time() now = time.time()
@@ -44,11 +47,21 @@ def _check_rate_limit(client_ip: str) -> None:
) )
_rate_hits[client_ip].append(now) _rate_hits[client_ip].append(now)
# Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth # Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth.
if len(_rate_hits) > 100: if len(_rate_hits) > 100:
stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start] stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start]
for ip in stale: for ip in stale:
del _rate_hits[ip] del _rate_hits[ip]
# Hard cap as a final defence against an attacker spraying many distinct
# X-Forwarded-For values to drive memory growth past the soft cleanup
# threshold. Drop the oldest-touched IPs (by their latest timestamp).
if len(_rate_hits) > _RATE_HITS_HARD_CAP:
ordered = sorted(
_rate_hits.items(),
key=lambda kv: kv[1][-1] if kv[1] else 0.0,
)
for ip, _ in ordered[: len(ordered) - _RATE_HITS_HARD_CAP]:
_rate_hits.pop(ip, None)
class WebhookPayload(BaseModel): class WebhookPayload(BaseModel):
@@ -14,6 +14,12 @@ from ledgrab.utils import get_logger, get_monitor_names, get_monitor_refresh_rat
logger = get_logger(__name__) logger = get_logger(__name__)
# Reused random Generator for sampling. The legacy ``np.random.randint``
# uses the module-level RandomState which is slightly slower per-call and
# pulls in extra import-time work; a single Generator is faster and avoids
# global-state surprises.
_rng = np.random.default_rng()
@dataclass @dataclass
class DisplayInfo: class DisplayInfo:
@@ -326,7 +332,11 @@ def calculate_dominant_color(pixels: np.ndarray) -> tuple[int, int, int]:
max_samples = 1000 max_samples = 1000
if n > max_samples: if n > max_samples:
indices = np.random.randint(0, n, max_samples) # ``Generator.integers`` writes into a fresh buffer once per call;
# the legacy ``np.random.randint`` did the same plus extra
# bookkeeping. Random (not stride) sampling stays robust against
# periodic patterns in screen pixels.
indices = _rng.integers(0, n, size=max_samples)
pixels_reshaped = pixels_reshaped[indices] pixels_reshaped = pixels_reshaped[indices]
# Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32: # Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32:
@@ -85,6 +85,10 @@ class DiscoveryWatcher:
self._wled_seen: Dict[str, _DiscoveredEntry] = {} self._wled_seen: Dict[str, _DiscoveredEntry] = {}
# device-path -> entry. Only the serial poller mutates this. # device-path -> entry. Only the serial poller mutates this.
self._serial_seen: Dict[str, _DiscoveredEntry] = {} self._serial_seen: Dict[str, _DiscoveredEntry] = {}
# Strong references for fire-and-forget resolve tasks — without
# these, Python 3.11+ may GC the task mid-resolve and silently lose
# discovery events. Tasks remove themselves on completion.
self._resolve_tasks: "set[asyncio.Task]" = set()
# --- lifecycle -------------------------------------------------------- # --- lifecycle --------------------------------------------------------
@@ -155,12 +159,23 @@ class DiscoveryWatcher:
if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated): if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated):
# Resolve in a task — async_request blocks the handler if awaited # Resolve in a task — async_request blocks the handler if awaited
# synchronously and we don't want to stall mDNS dispatch. # synchronously and we don't want to stall mDNS dispatch.
asyncio.create_task(self._resolve_wled(service_type, name)) task = asyncio.create_task(self._resolve_wled(service_type, name))
self._resolve_tasks.add(task)
task.add_done_callback(self._on_resolve_done)
elif state_change == ServiceStateChange.Removed: elif state_change == ServiceStateChange.Removed:
entry = self._wled_seen.pop(name, None) entry = self._wled_seen.pop(name, None)
if entry is not None and not self._is_configured(entry.url): if entry is not None and not self._is_configured(entry.url):
self._emit("device_lost", entry) self._emit("device_lost", entry)
def _on_resolve_done(self, task: "asyncio.Task") -> None:
"""Release the strong task reference and log any resolve failure."""
self._resolve_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.debug("Discovery watcher: resolve task raised: %s", exc)
async def _resolve_wled(self, service_type: str, name: str) -> None: async def _resolve_wled(self, service_type: str, name: str) -> None:
if self._aiozc is None: if self._aiozc is None:
return return
@@ -453,8 +453,15 @@ class WLEDClient(LEDClient):
], ],
} }
logger.debug(f"Sending {len(pixels)} LEDs via HTTP ({len(indexed_pixels)} values)") # ``str(payload)`` previously stringified the entire indexed
logger.debug(f"Payload size: ~{len(str(payload))} bytes") # array on every send to report a byte estimate; that was the
# hot path. Drop the size readout — the LED count + indexed
# value count is enough to interpret traffic and is O(1).
logger.debug(
"Sending %d LEDs via HTTP (%d indexed values)",
len(pixels),
len(indexed_pixels),
)
await self._request("POST", "/json/state", json_data=payload) await self._request("POST", "/json/state", json_data=payload)
logger.debug("Successfully sent pixel colors via HTTP") logger.debug("Successfully sent pixel colors via HTTP")
@@ -98,11 +98,18 @@ class WLEDDeviceProvider(LEDDeviceProvider):
dict with 'led_count' key. dict with 'led_count' key.
Raises: Raises:
ValueError: Unsupported scheme or invalid LED count.
httpx.ConnectError: Device unreachable. httpx.ConnectError: Device unreachable.
httpx.TimeoutException: Connection timed out. httpx.TimeoutException: Connection timed out.
ValueError: Invalid LED count.
""" """
url = _normalize_url(url) url = _normalize_url(url)
# Reject anything that isn't plain HTTP(S). url_scheme.infer_http_scheme
# passes non-HTTP schemes through untouched ("javascript:", "file:",
# "data:", etc.); without this guard those would reach httpx and
# surface as opaque transport errors at best, or be silently misused
# at worst.
if not url.lower().startswith(("http://", "https://")):
raise ValueError(f"WLED URL must use http:// or https:// scheme (got {url!r})")
async with httpx.AsyncClient(timeout=5) as client: async with httpx.AsyncClient(timeout=5) as client:
response = await client.get(_join(url, "/json/info")) response = await client.get(_join(url, "/json/info"))
response.raise_for_status() response.raise_for_status()
+58 -6
View File
@@ -49,6 +49,32 @@ class MQTTRuntime:
# Pending publishes queued while disconnected # Pending publishes queued while disconnected
self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000) self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000)
# Strong references for fire-and-forget callback dispatch tasks.
# Python 3.11+ may GC bare ``asyncio.create_task(...)`` results mid-
# flight, so we hold each task until it completes and surface any
# exception via the done-callback.
self._dispatch_tasks: Set[asyncio.Task] = set()
# Compiled ``aiomqtt.Topic`` cache, keyed by the subscription pattern
# string. The previous dispatch loop re-parsed every pattern on
# every incoming message — on a chatty broker with many wildcards
# that adds up fast.
self._compiled_topics: Dict[str, aiomqtt.Topic] = {}
def _on_dispatch_done(self, task: asyncio.Task) -> None:
"""Drop the strong reference and surface any callback exception."""
self._dispatch_tasks.discard(task)
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.error(
"MQTT async callback raised (%s): %s",
self._source_id,
exc,
exc_info=exc,
)
@property @property
def is_connected(self) -> bool: def is_connected(self) -> bool:
return self._connected return self._connected
@@ -84,6 +110,14 @@ class MQTTRuntime:
logger.debug("MQTT runtime task cancelled: %s", self._source_id) logger.debug("MQTT runtime task cancelled: %s", self._source_id)
self._task = None self._task = None
self._connected = False self._connected = False
# Cancel any in-flight async dispatch callbacks. Without this they
# would keep running past the runtime's logical end of life and
# could fire callbacks on a stopped subscriber.
for task in list(self._dispatch_tasks):
task.cancel()
if self._dispatch_tasks:
await asyncio.gather(*self._dispatch_tasks, return_exceptions=True)
self._dispatch_tasks.clear()
logger.info("MQTT runtime stopped: %s", self._source_id) logger.info("MQTT runtime stopped: %s", self._source_id)
def update_config(self, source: MQTTSource) -> None: def update_config(self, source: MQTTSource) -> None:
@@ -167,13 +201,23 @@ class MQTTRuntime:
for topic in self._subscriptions: for topic in self._subscriptions:
await client.subscribe(topic) await client.subscribe(topic)
# Drain pending publishes # Drain pending publishes. A single publish failing
# (broker rejection, oversized message) must not lose
# the rest of the queue — log and continue with the next.
while not self._publish_queue.empty(): while not self._publish_queue.empty():
try: try:
t, p, r, q = self._publish_queue.get_nowait() t, p, r, q = self._publish_queue.get_nowait()
await client.publish(t, p, retain=r, qos=q) except asyncio.QueueEmpty:
except Exception:
break break
try:
await client.publish(t, p, retain=r, qos=q)
except Exception as exc:
logger.warning(
"MQTT drain publish failed (%s -> %s): %s",
self._source_id,
t,
exc,
)
# Message receive loop # Message receive loop
async for msg in client.messages: async for msg in client.messages:
@@ -183,13 +227,21 @@ class MQTTRuntime:
) )
self._topic_cache[topic_str] = payload_str self._topic_cache[topic_str] = payload_str
# Dispatch to callbacks # Dispatch to callbacks. Pattern objects are cached
# per subscription string to avoid re-parsing them on
# every received message.
for sub_topic, callbacks in self._subscriptions.items(): for sub_topic, callbacks in self._subscriptions.items():
if aiomqtt.Topic(sub_topic).matches(msg.topic): compiled = self._compiled_topics.get(sub_topic)
if compiled is None:
compiled = aiomqtt.Topic(sub_topic)
self._compiled_topics[sub_topic] = compiled
if compiled.matches(msg.topic):
for cb in callbacks: for cb in callbacks:
try: try:
if asyncio.iscoroutinefunction(cb): if asyncio.iscoroutinefunction(cb):
asyncio.create_task(cb(topic_str, payload_str)) task = asyncio.create_task(cb(topic_str, payload_str))
self._dispatch_tasks.add(task)
task.add_done_callback(self._on_dispatch_done)
else: else:
cb(topic_str, payload_str) cb(topic_str, payload_str)
except Exception as e: except Exception as e:
@@ -2,6 +2,7 @@
import threading import threading
import time import time
from collections import OrderedDict
from typing import Dict, List, Optional from typing import Dict, List, Optional
import numpy as np import numpy as np
@@ -11,6 +12,11 @@ from ledgrab.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
# Cap the (src_len, dst_len) resize cache. Each entry holds two ``np.linspace``
# arrays plus a per-zone ``uint8`` scratch buffer, which used to grow without
# bound under hot reconfigure storms.
_RESIZE_CACHE_MAX = 16
class MappedColorStripStream(ColorStripStream): class MappedColorStripStream(ColorStripStream):
"""Places multiple ColorStripStreams side-by-side at distinct LED ranges. """Places multiple ColorStripStreams side-by-side at distinct LED ranges.
@@ -46,8 +52,11 @@ class MappedColorStripStream(ColorStripStream):
# zone_index -> (source_id, consumer_id, stream) # zone_index -> (source_id, consumer_id, stream)
self._sub_streams: Dict[int, tuple] = {} self._sub_streams: Dict[int, tuple] = {}
# (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing # (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing.
self._resize_cache: Dict[tuple, tuple] = {} # An ``OrderedDict`` with eviction keeps memory bounded if device
# configurations fluctuate at runtime (each unique pair adds two
# linspace arrays + a per-zone reusable uint8 buffer).
self._resize_cache: "OrderedDict[tuple, tuple]" = OrderedDict()
self._sub_lock = threading.Lock() # guards _sub_streams access across threads self._sub_lock = threading.Lock() # guards _sub_streams access across threads
# ── ColorStripStream interface ────────────────────────────── # ── ColorStripStream interface ──────────────────────────────
@@ -229,6 +238,14 @@ class MappedColorStripStream(ColorStripStream):
np.empty((zone_len, 3), dtype=np.uint8), np.empty((zone_len, 3), dtype=np.uint8),
) )
self._resize_cache[rkey] = cached self._resize_cache[rkey] = cached
# Drop the least-recently-inserted entry once
# we hit the cap. 16 entries comfortably covers
# any realistic zone/source layout — pathological
# reconfigure storms used to grow this forever.
if len(self._resize_cache) > _RESIZE_CACHE_MAX:
self._resize_cache.popitem(last=False)
else:
self._resize_cache.move_to_end(rkey)
src_x, dst_x, resized = cached src_x, dst_x, resized = cached
for ch in range(3): for ch in range(3):
np.copyto( np.copyto(
@@ -160,9 +160,11 @@ class ProcessedColorStripStream(ColorStripStream):
self._resolve_count = 0 self._resolve_count = 0
self._resolve_filters() self._resolve_filters()
colors = None # Bind to a local first — ``update_source()`` may swap or null
if self._input_stream: # out ``_input_stream`` between the check and the read on a
colors = self._input_stream.get_latest_colors() # different thread.
inp = self._input_stream
colors = inp.get_latest_colors() if inp is not None else None
if colors is not None and self._filters: if colors is not None and self._filters:
for flt in self._filters: for flt in self._filters:
@@ -38,6 +38,7 @@ from ledgrab.storage.picture_source_store import PictureSourceStore
from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore
from ledgrab.storage.template_store import TemplateStore from ledgrab.storage.template_store import TemplateStore
from ledgrab.storage.value_source_store import ValueSourceStore from ledgrab.storage.value_source_store import ValueSourceStore
from ledgrab.storage.http_endpoint_store import HTTPEndpointStore
from ledgrab.storage.asset_store import AssetStore from ledgrab.storage.asset_store import AssetStore
from ledgrab.core.processing.sync_clock_manager import SyncClockManager from ledgrab.core.processing.sync_clock_manager import SyncClockManager
from ledgrab.core.weather.weather_manager import WeatherManager from ledgrab.core.weather.weather_manager import WeatherManager
@@ -74,6 +75,7 @@ class ProcessorDependencies:
mqtt_manager: Optional[Any] = None # MQTTManager mqtt_manager: Optional[Any] = None # MQTTManager
game_event_bus: Optional[Any] = None # GameEventBus game_event_bus: Optional[Any] = None # GameEventBus
audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore
http_endpoint_store: Optional[HTTPEndpointStore] = None
@dataclass @dataclass
@@ -169,6 +171,7 @@ class ProcessorManager(AutoRestartMixin, DeviceHealthMixin, DeviceTestModeMixin)
event_bus=deps.game_event_bus, event_bus=deps.game_event_bus,
audio_processing_template_store=deps.audio_processing_template_store, audio_processing_template_store=deps.audio_processing_template_store,
sync_clock_manager=deps.sync_clock_manager, sync_clock_manager=deps.sync_clock_manager,
http_endpoint_store=deps.http_endpoint_store,
) )
if deps.value_source_store if deps.value_source_store
else None else None
@@ -37,6 +37,33 @@ def is_youtube_url(url: str) -> bool:
return any(p.search(url) for p in _YT_PATTERNS) return any(p.search(url) for p in _YT_PATTERNS)
# Network schemes accepted by ``cv2.VideoCapture``. ``file://`` is intentionally
# excluded — local files are supported via plain path strings (no scheme), so
# explicit ``file://`` requests can only be an attempt to coerce FFmpeg into
# loading something the path-string code path would reject. ``concat:``,
# ``gopher://``, ``crypto:``, etc. are not allowed.
_ALLOWED_NETWORK_SCHEMES: tuple[str, ...] = ("http", "https", "rtsp", "rtsps")
def _assert_video_url_allowed(url: str) -> None:
"""Reject video URLs that use anything other than a vetted scheme.
OpenCV/FFmpeg supports many esoteric input protocols (``concat:``,
``gopher://``, ``crypto:``, ``udp://``, ``async:``, …). Some can read
arbitrary host files or pivot to internal addresses when the caller
can influence the URL. Tighten the input to the schemes we actually
advertise. URLs without a scheme are accepted as local-file paths.
"""
if "://" not in url:
return # plain local path — OpenCV resolves against the working dir
scheme = url.split("://", 1)[0].lower()
if scheme not in _ALLOWED_NETWORK_SCHEMES:
raise RuntimeError(
f"Refusing to open video with unsupported scheme {scheme!r}; "
f"allowed: {', '.join(_ALLOWED_NETWORK_SCHEMES)} or a local file path."
)
def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str: def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str:
"""Resolve a YouTube URL to a direct stream URL using yt-dlp.""" """Resolve a YouTube URL to a direct stream URL using yt-dlp."""
try: try:
@@ -185,10 +212,14 @@ class VideoCaptureLiveStream(LiveStream):
if self._running: if self._running:
return return
# Resolve YouTube URL if needed # Resolve YouTube URL if needed. Validate AFTER resolution too, so a
# malicious yt-dlp result (or a redirect we don't expect) can't slip
# through with an unsupported scheme.
actual_url = self._original_url actual_url = self._original_url
_assert_video_url_allowed(actual_url)
if is_youtube_url(actual_url): if is_youtube_url(actual_url):
actual_url = resolve_youtube_url(actual_url, self._resolution_limit) actual_url = resolve_youtube_url(actual_url, self._resolution_limit)
_assert_video_url_allowed(actual_url)
self._resolved_url = actual_url self._resolved_url = actual_url
# Open capture # Open capture
@@ -224,7 +224,8 @@ class WledTargetProcessor(TargetProcessor):
self._is_running = False self._is_running = False
# Cancel task # Cancel task. The cancellation is awaited above, so the prior
# 50 ms ``asyncio.sleep`` here was pure dead time on every stop().
if self._task: if self._task:
self._task.cancel() self._task.cancel()
try: try:
@@ -233,7 +234,6 @@ class WledTargetProcessor(TargetProcessor):
logger.debug("WLED target processor task cancelled") logger.debug("WLED target processor task cancelled")
pass pass
self._task = None self._task = None
await asyncio.sleep(0.05)
# Restore device state (only if auto_shutdown is enabled) # Restore device state (only if auto_shutdown is enabled)
if self._led_client and self._device_state_before: if self._led_client and self._device_state_before:
+12 -2
View File
@@ -9,6 +9,8 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from ledgrab.utils import secret_box
def _parse_common(data: dict) -> dict: def _parse_common(data: dict) -> dict:
"""Extract common fields from a dict, parsing timestamps.""" """Extract common fields from a dict, parsing timestamps."""
@@ -52,13 +54,16 @@ class MQTTSource:
icon_color: str = "" icon_color: str = ""
def to_dict(self) -> dict: def to_dict(self) -> dict:
# Always persist the broker password in encrypted envelope form. If
# the field already contains an envelope, ``encrypt()`` is a no-op.
stored_password = secret_box.encrypt(self.password) if self.password else ""
d = { d = {
"id": self.id, "id": self.id,
"name": self.name, "name": self.name,
"broker_host": self.broker_host, "broker_host": self.broker_host,
"broker_port": self.broker_port, "broker_port": self.broker_port,
"username": self.username, "username": self.username,
"password": self.password, "password": stored_password,
"client_id": self.client_id, "client_id": self.client_id,
"base_topic": self.base_topic, "base_topic": self.base_topic,
"description": self.description, "description": self.description,
@@ -75,12 +80,17 @@ class MQTTSource:
@staticmethod @staticmethod
def from_dict(data: dict) -> "MQTTSource": def from_dict(data: dict) -> "MQTTSource":
common = _parse_common(data) common = _parse_common(data)
# Decrypt at load time so consumers see plaintext via ``.password``.
# Legacy plaintext rows pass through unchanged — the next save will
# write them back in encrypted form.
raw_password = data.get("password", "")
password = secret_box.decrypt(raw_password) if raw_password else ""
return MQTTSource( return MQTTSource(
**common, **common,
broker_host=data.get("broker_host", "localhost"), broker_host=data.get("broker_host", "localhost"),
broker_port=int(data.get("broker_port", 1883)), broker_port=int(data.get("broker_port", 1883)),
username=data.get("username", ""), username=data.get("username", ""),
password=data.get("password", ""), password=password,
client_id=data.get("client_id", "ledgrab"), client_id=data.get("client_id", "ledgrab"),
base_topic=data.get("base_topic", "ledgrab"), base_topic=data.get("base_topic", "ledgrab"),
icon=data.get("icon", ""), icon=data.get("icon", ""),
@@ -7,7 +7,7 @@ from typing import List, Optional
from ledgrab.storage.base_sqlite_store import BaseSqliteStore from ledgrab.storage.base_sqlite_store import BaseSqliteStore
from ledgrab.storage.database import Database from ledgrab.storage.database import Database
from ledgrab.storage.mqtt_source import MQTTSource from ledgrab.storage.mqtt_source import MQTTSource
from ledgrab.utils import get_logger from ledgrab.utils import get_logger, secret_box
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -20,6 +20,41 @@ class MQTTSourceStore(BaseSqliteStore[MQTTSource]):
def __init__(self, db: Database): def __init__(self, db: Database):
super().__init__(db, MQTTSource.from_dict) super().__init__(db, MQTTSource.from_dict)
self._migrate_plaintext_passwords()
def _migrate_plaintext_passwords(self) -> None:
"""Encrypt any stored MQTT broker passwords still in plaintext at rest.
Mirrors the HA-token migration: inspect raw DB rows, re-save items
that lack the encryption envelope so the next persistence pass
writes them in encrypted form.
"""
migrated = 0
try:
rows = self._db.load_all(self._table_name)
except Exception as exc:
logger.error("Could not inspect rows for MQTT password migration: %s", exc)
return
for row in rows:
sid = row.get("id")
raw_pw = row.get("password", "") or ""
if not sid or not raw_pw:
continue
if secret_box.is_encrypted(raw_pw):
continue
source = self._items.get(sid)
if source is None:
continue
try:
self._save_item(sid, source)
migrated += 1
except Exception as exc:
logger.error("Failed to migrate MQTT password for %s: %s", sid, exc)
if migrated:
logger.warning(
"MIGRATION: encrypted %d plaintext MQTT broker password(s) at rest.",
migrated,
)
# Backward-compatible aliases # Backward-compatible aliases
get_all_sources = BaseSqliteStore.get_all get_all_sources = BaseSqliteStore.get_all
+4 -1
View File
@@ -1,13 +1,15 @@
"""Utility functions and helpers.""" """Utility functions and helpers."""
from .file_ops import atomic_write_json from .file_ops import atomic_write_json, read_upload_capped
from .logger import setup_logging, get_logger from .logger import setup_logging, get_logger
from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates
from .timer import high_resolution_timer from .timer import high_resolution_timer
from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler
from .url_scheme import infer_http_scheme
__all__ = [ __all__ = [
"atomic_write_json", "atomic_write_json",
"read_upload_capped",
"setup_logging", "setup_logging",
"get_logger", "get_logger",
"get_monitor_names", "get_monitor_names",
@@ -16,4 +18,5 @@ __all__ = [
"high_resolution_timer", "high_resolution_timer",
"log_broadcaster", "log_broadcaster",
"install_broadcast_handler", "install_broadcast_handler",
"infer_http_scheme",
] ]
+30
View File
@@ -5,10 +5,40 @@ import logging
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from fastapi import UploadFile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_UPLOAD_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MiB
async def read_upload_capped(file: "UploadFile", max_size: int) -> bytes:
"""Read an uploaded file in chunks, refusing to buffer past ``max_size``.
A plain ``await file.read()`` will allocate the entire upload body before
any size check runs, so a 1 GB upload still costs 1 GB of memory even
against a 50 MB limit. This helper streams the read and aborts as soon
as the accumulated size exceeds the cap.
Raises:
ValueError: When the upload exceeds *max_size*. The caller decides
how to translate the exception (typically HTTP 400/413).
"""
buf = bytearray()
while True:
chunk = await file.read(_UPLOAD_CHUNK_SIZE)
if not chunk:
break
if len(buf) + len(chunk) > max_size:
raise ValueError(f"Upload exceeds maximum allowed size of {max_size} bytes")
buf.extend(chunk)
return bytes(buf)
def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None: def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None:
"""Write JSON data to file atomically via temp file + rename. """Write JSON data to file atomically via temp file + rename.
+145 -16
View File
@@ -20,6 +20,8 @@ from urllib.parse import urlparse
import httpx import httpx
from fastapi import HTTPException from fastapi import HTTPException
from ledgrab.utils.net_classify import is_blocked_for_ssrf
# Image file extensions considered safe to serve # Image file extensions considered safe to serve
_IMAGE_EXTENSIONS = frozenset( _IMAGE_EXTENSIONS = frozenset(
@@ -38,19 +40,13 @@ _IMAGE_EXTENSIONS = frozenset(
def _is_blocked_ip(ip: str) -> bool: def _is_blocked_ip(ip: str) -> bool:
"""Return True when *ip* belongs to a category we refuse to fetch from.""" """Return True when *ip* belongs to a category we refuse to fetch from.
try:
addr = ipaddress.ip_address(ip) Thin wrapper preserved for the local test suite. The actual policy now
except ValueError: lives in :func:`ledgrab.utils.net_classify.is_blocked_for_ssrf` so it
return True # unparseable → block can't drift away from the related ``url_scheme`` / ``auth`` modules.
return ( """
addr.is_private return is_blocked_for_ssrf(ip)
or addr.is_loopback
or addr.is_link_local
or addr.is_reserved
or addr.is_multicast
or addr.is_unspecified
)
def _resolve_hostname(hostname: str) -> Iterable[str]: def _resolve_hostname(hostname: str) -> Iterable[str]:
@@ -67,6 +63,11 @@ def _resolve_hostname(hostname: str) -> Iterable[str]:
def validate_image_url(url: str) -> None: def validate_image_url(url: str) -> None:
"""Validate that *url* is safe to fetch. """Validate that *url* is safe to fetch.
Strict SSRF policy — blocks all private / loopback / link-local /
reserved / multicast / unspecified IPs. Intended for fetches of
user-supplied URLs that should reach the *public* internet only
(image sources, weather APIs, etc.).
Checks: Checks:
- Scheme is http/https - Scheme is http/https
- Hostname is present - Hostname is present
@@ -74,6 +75,25 @@ def validate_image_url(url: str) -> None:
- DNS resolves to non-private, non-loopback, non-link-local, non-reserved, - DNS resolves to non-private, non-loopback, non-link-local, non-reserved,
non-multicast addresses (SSRF protection) non-multicast addresses (SSRF protection)
""" """
_validate_url(url, allow_private=False)
def validate_polling_url(url: str) -> None:
"""Validate that *url* is safe to poll from an automation trigger.
Relaxed SSRF policy — allows PRIVATE (LAN) IPs so users can target
devices on their own network (Plex on 192.168.x.x, Home Assistant
on 10.x.x.x). Still blocks LOOPBACK (protects the server's own
admin ports), LINK_LOCAL (blocks 169.254.169.254 cloud-metadata
services), MULTICAST, RESERVED, and UNSPECIFIED.
Use this instead of :func:`validate_image_url` for any user-configured
HTTP polling target.
"""
_validate_url(url, allow_private=True)
def _validate_url(url: str, *, allow_private: bool) -> None:
parsed = urlparse(url) parsed = urlparse(url)
if parsed.scheme not in ("http", "https"): if parsed.scheme not in ("http", "https"):
raise HTTPException( raise HTTPException(
@@ -93,17 +113,41 @@ def validate_image_url(url: str) -> None:
ips = _resolve_hostname(hostname) ips = _resolve_hostname(hostname)
for ip in ips: for ip in ips:
if _is_blocked_ip(ip): if _ip_is_blocked(ip, allow_private=allow_private):
policy = (
"loopback / link-local / reserved / multicast"
if allow_private
else "private / loopback / link-local / reserved / multicast"
)
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=( detail=(
f"Refusing to fetch URL: hostname {hostname!r} resolves to " f"Refusing to fetch URL: hostname {hostname!r} resolves to "
f"blocked address {ip} (private / loopback / link-local / " f"blocked address {ip} ({policy})"
"reserved / multicast)"
), ),
) )
def _ip_is_blocked(ip: str, *, allow_private: bool) -> bool:
"""Block-list predicate parameterised by whether PRIVATE is permitted."""
from ledgrab.utils.net_classify import HostCategory, classify_ip
category = classify_ip(ip)
always_block = {
HostCategory.LOOPBACK,
HostCategory.LINK_LOCAL,
HostCategory.MULTICAST,
HostCategory.RESERVED,
HostCategory.UNSPECIFIED,
HostCategory.UNPARSEABLE,
}
if category in always_block:
return True
if not allow_private and category == HostCategory.PRIVATE:
return True
return False
def validate_image_path(file_path: str | Path) -> Path: def validate_image_path(file_path: str | Path) -> Path:
"""Validate a local file path points to a real image file. """Validate a local file path points to a real image file.
@@ -165,3 +209,88 @@ async def safe_fetch(
response.raise_for_status() response.raise_for_status()
return response return response
raise HTTPException(status_code=400, detail=f"Too many redirects (max {max_hops})") raise HTTPException(status_code=400, detail=f"Too many redirects (max {max_hops})")
# ---------------------------------------------------------------------------
# Bounded safe request — for polling-style fetches that need streaming
# + size cap + arbitrary method/headers. Used by the HTTP poll runtime.
# ---------------------------------------------------------------------------
# Hard cap to avoid OOM from polled endpoints returning huge bodies.
# 1 MiB covers every realistic JSON status endpoint by orders of magnitude.
_MAX_RESPONSE_BYTES = 1 * 1024 * 1024
async def safe_request_bounded(
method: str,
url: str,
*,
headers: dict[str, str] | None = None,
timeout: float = 10.0,
max_bytes: int = _MAX_RESPONSE_BYTES,
allow_private: bool = True,
) -> tuple[int, bytes, str | None]:
"""Make a single HTTP request with SSRF validation and a body-size cap.
Validates the URL up-front (scheme + IP-block) and streams the response
body, aborting as soon as ``max_bytes`` is exceeded so a malicious or
misconfigured endpoint can't OOM the server.
``allow_private`` selects the SSRF policy: ``True`` (default — for HTTP
polling against user-owned LAN devices) routes through
:func:`validate_polling_url`; ``False`` routes through
:func:`validate_image_url`'s strict public-only policy.
Note on DNS rebinding: this validates the URL hostname's resolved IPs
once, then httpx independently re-resolves to make the actual request.
The window between the two resolutions is short but not zero. True
IP-pinning (rewriting the URL to the literal IP + Host header) is a
project-wide concern that would also need to update ``safe_fetch`` and
every other outbound caller; tracked as a follow-up.
Returns:
``(status_code, body_bytes, error_message)``. ``error_message`` is
``None`` on success; on transport failure ``status_code`` is 0 and
``body_bytes`` is empty.
Raises:
HTTPException: when the URL fails validation (4xx-style error
intended for callers that surface this directly to the user).
"""
if allow_private:
validate_polling_url(url)
else:
validate_image_url(url)
method = method.upper()
headers = dict(headers or {})
try:
async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
async with client.stream(method, url, headers=headers) as response:
chunks: list[bytes] = []
total = 0
truncated = False
async for chunk in response.aiter_bytes():
if total + len(chunk) > max_bytes:
chunks.append(chunk[: max_bytes - total])
total = max_bytes
truncated = True
break
chunks.append(chunk)
total += len(chunk)
body = b"".join(chunks)
# Truncation is NOT an error — the partial body is returned
# for the caller to use (e.g. evaluate a JSON path against the
# first 1 MiB). Callers that care can detect it by comparing
# ``len(body) == max_bytes``. We only flag actual transport
# failures in the error slot.
_ = truncated
return response.status_code, body, None
except httpx.RequestError as exc:
# Never include `exc` repr in the message — some httpx errors include
# request URL + headers, which could leak auth tokens supplied by the
# caller. Type name is sufficient diagnostic.
return 0, b"", f"Request failed: {type(exc).__name__}"
except Exception as exc:
return 0, b"", f"Unexpected error: {type(exc).__name__}"
+18 -12
View File
@@ -12,22 +12,28 @@ class TestBackupRestoreFlow:
"""A user backs up their configuration and restores it.""" """A user backs up their configuration and restores it."""
def _create_device(self, client, name="Backup Device") -> str: def _create_device(self, client, name="Backup Device") -> str:
resp = client.post("/api/v1/devices", json={ resp = client.post(
"name": name, "/api/v1/devices",
"url": "mock://backup", json={
"device_type": "mock", "name": name,
"led_count": 30, "url": "mock://backup",
}) "device_type": "mock",
"led_count": 30,
},
)
assert resp.status_code == 201 assert resp.status_code == 201
return resp.json()["id"] return resp.json()["id"]
def _create_css(self, client, name="Backup CSS") -> str: def _create_css(self, client, name="Backup CSS") -> str:
resp = client.post("/api/v1/color-strip-sources", json={ resp = client.post(
"name": name, "/api/v1/color-strip-sources",
"source_type": "static", json={
"color": [255, 0, 0], "name": name,
"led_count": 30, "source_type": "single_color",
}) "color": [255, 0, 0],
"led_count": 30,
},
)
assert resp.status_code == 201 assert resp.status_code == 201
return resp.json()["id"] return resp.json()["id"]
+16 -4
View File
@@ -78,15 +78,27 @@ def test_get_displays(client):
def test_openapi_docs(client): def test_openapi_docs(client):
"""Test OpenAPI documentation is available.""" """Test OpenAPI documentation is available to authenticated clients."""
response = client.get("/openapi.json") response = client.get("/openapi.json", headers=AUTH_HEADERS)
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["info"]["version"] == __version__ assert data["info"]["version"] == __version__
def test_swagger_ui(client): def test_swagger_ui(client):
"""Test Swagger UI is available.""" """Test Swagger UI is available to authenticated clients."""
response = client.get("/docs") response = client.get("/docs", headers=AUTH_HEADERS)
assert response.status_code == 200 assert response.status_code == 200
assert "text/html" in response.headers["content-type"] assert "text/html" in response.headers["content-type"]
def test_openapi_docs_requires_auth(client):
"""OpenAPI surface must NOT be reachable without auth (info disclosure)."""
response = client.get("/openapi.json")
assert response.status_code == 401
def test_swagger_ui_requires_auth(client):
"""Swagger UI must NOT be reachable without auth."""
response = client.get("/docs")
assert response.status_code == 401