diff --git a/server/src/ledgrab/api/auth.py b/server/src/ledgrab/api/auth.py index ec2cc02..3d10745 100644 --- a/server/src/ledgrab/api/auth.py +++ b/server/src/ledgrab/api/auth.py @@ -11,14 +11,13 @@ from starlette.websockets import WebSocket, WebSocketDisconnect from ledgrab.config import get_config from ledgrab.utils import get_logger +from ledgrab.utils.net_classify import is_loopback as _classify_is_loopback logger = get_logger(__name__) # Security scheme for Bearer token security = HTTPBearer(auto_error=False) -_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost", "testclient"}) - def is_auth_enabled() -> bool: """Return True when at least one API key is configured.""" @@ -26,15 +25,15 @@ def is_auth_enabled() -> bool: def _is_loopback(host: str | None) -> bool: - """Return True when *host* is a loopback address.""" + """Return True when *host* is a loopback address. + + Delegates to :func:`ledgrab.utils.net_classify.is_loopback` so this + auth gate, the SSRF guard in ``safe_source``, and the LAN-default + inference in ``url_scheme`` share one classification source. + """ if not host: return False - # Strip IPv6 brackets and zone IDs - h = host.strip().lower() - if h.startswith("[") and h.endswith("]"): - h = h[1:-1] - h = h.split("%", 1)[0] - return h in _LOOPBACK_HOSTS + return _classify_is_loopback(host) def verify_api_key( @@ -142,6 +141,23 @@ def require_authenticated(label: str) -> None: WS_AUTH_CLOSE_CODE = 4401 +WS_ORIGIN_CLOSE_CODE = 4403 +"""Close code sent when a WebSocket request fails the Origin allowlist.""" + + +def _is_origin_allowed(origin: str | None, allowed: list[str]) -> bool: + """Return True when *origin* matches one of the configured CORS origins. + + Non-browser clients (Python scripts, curl) don't send Origin — those are + allowed through; the Bearer-token check on the auth handshake is the + primary defence in that case. Browsers always set Origin, so this only + blocks cross-site WebSocket connection attempts (CSWSH). 
+ """ + if not origin: + return True + return origin in set(allowed or []) + + async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) -> str | None: """Accept the WebSocket, then perform first-message auth handshake. @@ -152,6 +168,23 @@ async def accept_and_authenticate_ws(websocket: WebSocket, timeout: float = 3.0) Returns the caller label on success, ``None`` on failure (connection already closed). """ + # Reject cross-site WebSocket attempts before accepting — a browser-based + # attacker page cannot forge the Origin header, so an Origin mismatch is + # a strong signal even before the token check. Non-browser clients + # legitimately omit Origin; those fall through to the auth handshake. + config = get_config() + origin = websocket.headers.get("origin") + if not _is_origin_allowed(origin, config.server.cors_origins): + logger.warning( + "Rejected WebSocket from origin %r (not in cors_origins)", + origin, + ) + try: + await websocket.close(code=WS_ORIGIN_CLOSE_CODE) + except Exception: + pass + return None + await websocket.accept() label = await verify_ws_auth(websocket, timeout=timeout) if label is None: diff --git a/server/src/ledgrab/api/dependencies.py b/server/src/ledgrab/api/dependencies.py index 7ace5af..1718632 100644 --- a/server/src/ledgrab/api/dependencies.py +++ b/server/src/ledgrab/api/dependencies.py @@ -37,6 +37,7 @@ from ledgrab.storage.game_integration_store import GameIntegrationStore from ledgrab.core.game_integration.event_bus import GameEventBus from ledgrab.storage.mqtt_source_store import MQTTSourceStore from ledgrab.core.mqtt.mqtt_manager import MQTTManager +from ledgrab.storage.http_endpoint_store import HTTPEndpointStore from ledgrab.storage.audio_processing_template_store import AudioProcessingTemplateStore from ledgrab.storage.pattern_template_store import PatternTemplateStore @@ -165,6 +166,10 @@ def get_mqtt_manager() -> MQTTManager: return _get("mqtt_manager", "MQTT manager") +def 
get_http_endpoint_store() -> HTTPEndpointStore: + return _get("http_endpoint_store", "HTTP endpoint store") + + def get_audio_processing_template_store() -> AudioProcessingTemplateStore: return _get("audio_processing_template_store", "Audio processing template store") @@ -237,6 +242,7 @@ def init_dependencies( game_event_bus: GameEventBus | None = None, mqtt_store: MQTTSourceStore | None = None, mqtt_manager: MQTTManager | None = None, + http_endpoint_store: HTTPEndpointStore | None = None, audio_processing_template_store: AudioProcessingTemplateStore | None = None, pattern_template_store: PatternTemplateStore | None = None, ): @@ -272,6 +278,7 @@ def init_dependencies( "game_event_bus": game_event_bus, "mqtt_store": mqtt_store, "mqtt_manager": mqtt_manager, + "http_endpoint_store": http_endpoint_store, "audio_processing_template_store": audio_processing_template_store, "pattern_template_store": pattern_template_store, } diff --git a/server/src/ledgrab/api/routes/assets.py b/server/src/ledgrab/api/routes/assets.py index 8cf6fed..70ec093 100644 --- a/server/src/ledgrab/api/routes/assets.py +++ b/server/src/ledgrab/api/routes/assets.py @@ -15,7 +15,7 @@ from ledgrab.api.schemas.assets import ( from ledgrab.config import get_config from ledgrab.storage.asset_store import AssetStore from ledgrab.storage.base_store import EntityNotFoundError -from ledgrab.utils import get_logger +from ledgrab.utils import get_logger, read_upload_capped logger = get_logger(__name__) @@ -93,10 +93,11 @@ async def upload_asset( config = get_config() max_size = getattr(getattr(config, "assets", None), "max_file_size_mb", 50) * 1024 * 1024 - data = await file.read() - if len(data) > max_size: + try: + data = await read_upload_capped(file, max_size) + except ValueError: raise HTTPException( - status_code=400, + status_code=413, detail=f"File too large (max {max_size // (1024 * 1024)} MB)", ) diff --git a/server/src/ledgrab/api/routes/backup.py b/server/src/ledgrab/api/routes/backup.py index 
80639f8..678eefc 100644 --- a/server/src/ledgrab/api/routes/backup.py +++ b/server/src/ledgrab/api/routes/backup.py @@ -28,7 +28,7 @@ from ledgrab.config import get_config from ledgrab.core.backup.auto_backup import AutoBackupEngine from ledgrab.storage.asset_store import AssetStore from ledgrab.storage.database import Database, freeze_writes -from ledgrab.utils import get_logger +from ledgrab.utils import get_logger, read_upload_capped logger = get_logger(__name__) @@ -133,9 +133,11 @@ async def restore_config( because restore replaces all configuration including secrets). """ require_authenticated(auth) - raw = await file.read() - if len(raw) > 200 * 1024 * 1024: # 200 MB limit (ZIP may contain assets) - raise HTTPException(status_code=400, detail="Backup file too large (max 200 MB)") + _MAX_BACKUP_BYTES = 200 * 1024 * 1024 # 200 MB (ZIP may contain assets) + try: + raw = await read_upload_capped(file, _MAX_BACKUP_BYTES) + except ValueError: + raise HTTPException(status_code=413, detail="Backup file too large (max 200 MB)") if len(raw) < 100: raise HTTPException(status_code=400, detail="File too small to be a valid backup") diff --git a/server/src/ledgrab/api/routes/webhooks.py b/server/src/ledgrab/api/routes/webhooks.py index bccd717..d12966e 100644 --- a/server/src/ledgrab/api/routes/webhooks.py +++ b/server/src/ledgrab/api/routes/webhooks.py @@ -30,6 +30,9 @@ _RATE_WINDOW = 60.0 # seconds _rate_hits: dict[str, list[float]] = defaultdict(list) +_RATE_HITS_HARD_CAP = 1024 + + def _check_rate_limit(client_ip: str) -> None: """Raise 429 if *client_ip* exceeded the webhook rate limit.""" now = time.time() @@ -44,11 +47,21 @@ def _check_rate_limit(client_ip: str) -> None: ) _rate_hits[client_ip].append(now) - # Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth + # Periodic cleanup: remove IPs with no recent hits to prevent unbounded growth. 
if len(_rate_hits) > 100: stale = [ip for ip, ts in _rate_hits.items() if not ts or ts[-1] < window_start] for ip in stale: del _rate_hits[ip] + # Hard cap as a final defence against an attacker spraying many distinct + # X-Forwarded-For values to drive memory growth past the soft cleanup + # threshold. Drop the oldest-touched IPs (by their latest timestamp). + if len(_rate_hits) > _RATE_HITS_HARD_CAP: + ordered = sorted( + _rate_hits.items(), + key=lambda kv: kv[1][-1] if kv[1] else 0.0, + ) + for ip, _ in ordered[: len(ordered) - _RATE_HITS_HARD_CAP]: + _rate_hits.pop(ip, None) class WebhookPayload(BaseModel): diff --git a/server/src/ledgrab/core/capture/screen_capture.py b/server/src/ledgrab/core/capture/screen_capture.py index df3788c..6f45181 100644 --- a/server/src/ledgrab/core/capture/screen_capture.py +++ b/server/src/ledgrab/core/capture/screen_capture.py @@ -14,6 +14,12 @@ from ledgrab.utils import get_logger, get_monitor_names, get_monitor_refresh_rat logger = get_logger(__name__) +# Reused random Generator for sampling. The legacy ``np.random.randint`` +# uses the module-level RandomState which is slightly slower per-call and +# pulls in extra import-time work; a single Generator is faster and avoids +# global-state surprises. +_rng = np.random.default_rng() + @dataclass class DisplayInfo: @@ -326,7 +332,11 @@ def calculate_dominant_color(pixels: np.ndarray) -> tuple[int, int, int]: max_samples = 1000 if n > max_samples: - indices = np.random.randint(0, n, max_samples) + # ``Generator.integers`` writes into a fresh buffer once per call; + # the legacy ``np.random.randint`` did the same plus extra + # bookkeeping. Random (not stride) sampling stays robust against + # periodic patterns in screen pixels. 
+ indices = _rng.integers(0, n, size=max_samples) pixels_reshaped = pixels_reshaped[indices] # Quantize to 32 levels/channel (drop low 3 bits) and pack into uint32: diff --git a/server/src/ledgrab/core/devices/discovery_watcher.py b/server/src/ledgrab/core/devices/discovery_watcher.py index 394beca..782242e 100644 --- a/server/src/ledgrab/core/devices/discovery_watcher.py +++ b/server/src/ledgrab/core/devices/discovery_watcher.py @@ -85,6 +85,10 @@ class DiscoveryWatcher: self._wled_seen: Dict[str, _DiscoveredEntry] = {} # device-path -> entry. Only the serial poller mutates this. self._serial_seen: Dict[str, _DiscoveredEntry] = {} + # Strong references for fire-and-forget resolve tasks — without + # these, Python 3.11+ may GC the task mid-resolve and silently lose + # discovery events. Tasks remove themselves on completion. + self._resolve_tasks: "set[asyncio.Task]" = set() # --- lifecycle -------------------------------------------------------- @@ -155,12 +159,23 @@ class DiscoveryWatcher: if state_change in (ServiceStateChange.Added, ServiceStateChange.Updated): # Resolve in a task — async_request blocks the handler if awaited # synchronously and we don't want to stall mDNS dispatch. 
- asyncio.create_task(self._resolve_wled(service_type, name)) + task = asyncio.create_task(self._resolve_wled(service_type, name)) + self._resolve_tasks.add(task) + task.add_done_callback(self._on_resolve_done) elif state_change == ServiceStateChange.Removed: entry = self._wled_seen.pop(name, None) if entry is not None and not self._is_configured(entry.url): self._emit("device_lost", entry) + def _on_resolve_done(self, task: "asyncio.Task") -> None: + """Release the strong task reference and log any resolve failure.""" + self._resolve_tasks.discard(task) + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + logger.debug("Discovery watcher: resolve task raised: %s", exc) + async def _resolve_wled(self, service_type: str, name: str) -> None: if self._aiozc is None: return diff --git a/server/src/ledgrab/core/devices/wled_client.py b/server/src/ledgrab/core/devices/wled_client.py index 4f6ec5a..cf4aa8b 100644 --- a/server/src/ledgrab/core/devices/wled_client.py +++ b/server/src/ledgrab/core/devices/wled_client.py @@ -453,8 +453,15 @@ class WLEDClient(LEDClient): ], } - logger.debug(f"Sending {len(pixels)} LEDs via HTTP ({len(indexed_pixels)} values)") - logger.debug(f"Payload size: ~{len(str(payload))} bytes") + # ``str(payload)`` previously stringified the entire indexed + # array on every send to report a byte estimate; that was the + # hot path. Drop the size readout — the LED count + indexed + # value count is enough to interpret traffic and is O(1). 
+ logger.debug( + "Sending %d LEDs via HTTP (%d indexed values)", + len(pixels), + len(indexed_pixels), + ) await self._request("POST", "/json/state", json_data=payload) logger.debug("Successfully sent pixel colors via HTTP") diff --git a/server/src/ledgrab/core/devices/wled_provider.py b/server/src/ledgrab/core/devices/wled_provider.py index 039753b..144add3 100644 --- a/server/src/ledgrab/core/devices/wled_provider.py +++ b/server/src/ledgrab/core/devices/wled_provider.py @@ -98,11 +98,18 @@ class WLEDDeviceProvider(LEDDeviceProvider): dict with 'led_count' key. Raises: + ValueError: Unsupported scheme or invalid LED count. httpx.ConnectError: Device unreachable. httpx.TimeoutException: Connection timed out. - ValueError: Invalid LED count. """ url = _normalize_url(url) + # Reject anything that isn't plain HTTP(S). url_scheme.infer_http_scheme + # passes non-HTTP schemes through untouched ("javascript:", "file:", + # "data:", etc.); without this guard those would reach httpx and + # surface as opaque transport errors at best, or be silently misused + # at worst. + if not url.lower().startswith(("http://", "https://")): + raise ValueError(f"WLED URL must use http:// or https:// scheme (got {url!r})") async with httpx.AsyncClient(timeout=5) as client: response = await client.get(_join(url, "/json/info")) response.raise_for_status() diff --git a/server/src/ledgrab/core/mqtt/mqtt_runtime.py b/server/src/ledgrab/core/mqtt/mqtt_runtime.py index 78f32b7..a9f4601 100644 --- a/server/src/ledgrab/core/mqtt/mqtt_runtime.py +++ b/server/src/ledgrab/core/mqtt/mqtt_runtime.py @@ -49,6 +49,32 @@ class MQTTRuntime: # Pending publishes queued while disconnected self._publish_queue: asyncio.Queue = asyncio.Queue(maxsize=1000) + # Strong references for fire-and-forget callback dispatch tasks. + # Python 3.11+ may GC bare ``asyncio.create_task(...)`` results mid- + # flight, so we hold each task until it completes and surface any + # exception via the done-callback. 
+ self._dispatch_tasks: Set[asyncio.Task] = set() + + # Compiled ``aiomqtt.Topic`` cache, keyed by the subscription pattern + # string. The previous dispatch loop re-parsed every pattern on + # every incoming message — on a chatty broker with many wildcards + # that adds up fast. + self._compiled_topics: Dict[str, aiomqtt.Topic] = {} + + def _on_dispatch_done(self, task: asyncio.Task) -> None: + """Drop the strong reference and surface any callback exception.""" + self._dispatch_tasks.discard(task) + if task.cancelled(): + return + exc = task.exception() + if exc is not None: + logger.error( + "MQTT async callback raised (%s): %s", + self._source_id, + exc, + exc_info=exc, + ) + @property def is_connected(self) -> bool: return self._connected @@ -84,6 +110,14 @@ class MQTTRuntime: logger.debug("MQTT runtime task cancelled: %s", self._source_id) self._task = None self._connected = False + # Cancel any in-flight async dispatch callbacks. Without this they + # would keep running past the runtime's logical end of life and + # could fire callbacks on a stopped subscriber. + for task in list(self._dispatch_tasks): + task.cancel() + if self._dispatch_tasks: + await asyncio.gather(*self._dispatch_tasks, return_exceptions=True) + self._dispatch_tasks.clear() logger.info("MQTT runtime stopped: %s", self._source_id) def update_config(self, source: MQTTSource) -> None: @@ -167,13 +201,23 @@ class MQTTRuntime: for topic in self._subscriptions: await client.subscribe(topic) - # Drain pending publishes + # Drain pending publishes. A single publish failing + # (broker rejection, oversized message) must not lose + # the rest of the queue — log and continue with the next. 
while not self._publish_queue.empty(): try: t, p, r, q = self._publish_queue.get_nowait() - await client.publish(t, p, retain=r, qos=q) - except Exception: + except asyncio.QueueEmpty: break + try: + await client.publish(t, p, retain=r, qos=q) + except Exception as exc: + logger.warning( + "MQTT drain publish failed (%s -> %s): %s", + self._source_id, + t, + exc, + ) # Message receive loop async for msg in client.messages: @@ -183,13 +227,21 @@ class MQTTRuntime: ) self._topic_cache[topic_str] = payload_str - # Dispatch to callbacks + # Dispatch to callbacks. Pattern objects are cached + # per subscription string to avoid re-parsing them on + # every received message. for sub_topic, callbacks in self._subscriptions.items(): - if aiomqtt.Topic(sub_topic).matches(msg.topic): + compiled = self._compiled_topics.get(sub_topic) + if compiled is None: + compiled = aiomqtt.Topic(sub_topic) + self._compiled_topics[sub_topic] = compiled + if compiled.matches(msg.topic): for cb in callbacks: try: if asyncio.iscoroutinefunction(cb): - asyncio.create_task(cb(topic_str, payload_str)) + task = asyncio.create_task(cb(topic_str, payload_str)) + self._dispatch_tasks.add(task) + task.add_done_callback(self._on_dispatch_done) else: cb(topic_str, payload_str) except Exception as e: diff --git a/server/src/ledgrab/core/processing/mapped_stream.py b/server/src/ledgrab/core/processing/mapped_stream.py index e219e61..9e29d10 100644 --- a/server/src/ledgrab/core/processing/mapped_stream.py +++ b/server/src/ledgrab/core/processing/mapped_stream.py @@ -2,6 +2,7 @@ import threading import time +from collections import OrderedDict from typing import Dict, List, Optional import numpy as np @@ -11,6 +12,11 @@ from ledgrab.utils import get_logger logger = get_logger(__name__) +# Cap the (src_len, dst_len) resize cache. Each entry holds two ``np.linspace`` +# arrays plus a per-zone ``uint8`` scratch buffer, which used to grow without +# bound under hot reconfigure storms. 
+_RESIZE_CACHE_MAX = 16 + class MappedColorStripStream(ColorStripStream): """Places multiple ColorStripStreams side-by-side at distinct LED ranges. @@ -46,8 +52,11 @@ class MappedColorStripStream(ColorStripStream): # zone_index -> (source_id, consumer_id, stream) self._sub_streams: Dict[int, tuple] = {} - # (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing - self._resize_cache: Dict[tuple, tuple] = {} + # (src_len, dst_len) -> (src_x, dst_x, buffer) cache for zone resizing. + # An ``OrderedDict`` with eviction keeps memory bounded if device + # configurations fluctuate at runtime (each unique pair adds two + # linspace arrays + a per-zone reusable uint8 buffer). + self._resize_cache: "OrderedDict[tuple, tuple]" = OrderedDict() self._sub_lock = threading.Lock() # guards _sub_streams access across threads # ── ColorStripStream interface ────────────────────────────── @@ -229,6 +238,14 @@ class MappedColorStripStream(ColorStripStream): np.empty((zone_len, 3), dtype=np.uint8), ) self._resize_cache[rkey] = cached + # Drop the least-recently-inserted entry once + # we hit the cap. 16 entries comfortably covers + # any realistic zone/source layout — pathological + # reconfigure storms used to grow this forever. 
+ if len(self._resize_cache) > _RESIZE_CACHE_MAX: + self._resize_cache.popitem(last=False) + else: + self._resize_cache.move_to_end(rkey) src_x, dst_x, resized = cached for ch in range(3): np.copyto( diff --git a/server/src/ledgrab/core/processing/processed_stream.py b/server/src/ledgrab/core/processing/processed_stream.py index 6689e95..17e3e09 100644 --- a/server/src/ledgrab/core/processing/processed_stream.py +++ b/server/src/ledgrab/core/processing/processed_stream.py @@ -160,9 +160,11 @@ class ProcessedColorStripStream(ColorStripStream): self._resolve_count = 0 self._resolve_filters() - colors = None - if self._input_stream: - colors = self._input_stream.get_latest_colors() + # Bind to a local first — ``update_source()`` may swap or null + # out ``_input_stream`` between the check and the read on a + # different thread. + inp = self._input_stream + colors = inp.get_latest_colors() if inp is not None else None if colors is not None and self._filters: for flt in self._filters: diff --git a/server/src/ledgrab/core/processing/processor_manager.py b/server/src/ledgrab/core/processing/processor_manager.py index d73bec0..773e99f 100644 --- a/server/src/ledgrab/core/processing/processor_manager.py +++ b/server/src/ledgrab/core/processing/processor_manager.py @@ -38,6 +38,7 @@ from ledgrab.storage.picture_source_store import PictureSourceStore from ledgrab.storage.postprocessing_template_store import PostprocessingTemplateStore from ledgrab.storage.template_store import TemplateStore from ledgrab.storage.value_source_store import ValueSourceStore +from ledgrab.storage.http_endpoint_store import HTTPEndpointStore from ledgrab.storage.asset_store import AssetStore from ledgrab.core.processing.sync_clock_manager import SyncClockManager from ledgrab.core.weather.weather_manager import WeatherManager @@ -74,6 +75,7 @@ class ProcessorDependencies: mqtt_manager: Optional[Any] = None # MQTTManager game_event_bus: Optional[Any] = None # GameEventBus 
audio_processing_template_store: Optional[Any] = None # AudioProcessingTemplateStore + http_endpoint_store: Optional[HTTPEndpointStore] = None @dataclass @@ -169,6 +171,7 @@ class ProcessorManager(AutoRestartMixin, DeviceHealthMixin, DeviceTestModeMixin) event_bus=deps.game_event_bus, audio_processing_template_store=deps.audio_processing_template_store, sync_clock_manager=deps.sync_clock_manager, + http_endpoint_store=deps.http_endpoint_store, ) if deps.value_source_store else None diff --git a/server/src/ledgrab/core/processing/video_stream.py b/server/src/ledgrab/core/processing/video_stream.py index c9b7100..b311c25 100644 --- a/server/src/ledgrab/core/processing/video_stream.py +++ b/server/src/ledgrab/core/processing/video_stream.py @@ -37,6 +37,33 @@ def is_youtube_url(url: str) -> bool: return any(p.search(url) for p in _YT_PATTERNS) +# Network schemes accepted by ``cv2.VideoCapture``. ``file://`` is intentionally +# excluded — local files are supported via plain path strings (no scheme), so +# explicit ``file://`` requests can only be an attempt to coerce FFmpeg into +# loading something the path-string code path would reject. ``concat:``, +# ``gopher://``, ``crypto:``, etc. are not allowed. +_ALLOWED_NETWORK_SCHEMES: tuple[str, ...] = ("http", "https", "rtsp", "rtsps") + + +def _assert_video_url_allowed(url: str) -> None: + """Reject video URLs that use anything other than a vetted scheme. + + OpenCV/FFmpeg supports many esoteric input protocols (``concat:``, + ``gopher://``, ``crypto:``, ``udp://``, ``async:``, …). Some can read + arbitrary host files or pivot to internal addresses when the caller + can influence the URL. Tighten the input to the schemes we actually + advertise. URLs without a scheme are accepted as local-file paths. 
+ """ + if "://" not in url: + return # plain local path — OpenCV resolves against the working dir + scheme = url.split("://", 1)[0].lower() + if scheme not in _ALLOWED_NETWORK_SCHEMES: + raise RuntimeError( + f"Refusing to open video with unsupported scheme {scheme!r}; " + f"allowed: {', '.join(_ALLOWED_NETWORK_SCHEMES)} or a local file path." + ) + + def resolve_youtube_url(url: str, resolution_limit: Optional[int] = None) -> str: """Resolve a YouTube URL to a direct stream URL using yt-dlp.""" try: @@ -185,10 +212,14 @@ class VideoCaptureLiveStream(LiveStream): if self._running: return - # Resolve YouTube URL if needed + # Resolve YouTube URL if needed. Validate AFTER resolution too, so a + # malicious yt-dlp result (or a redirect we don't expect) can't slip + # through with an unsupported scheme. actual_url = self._original_url + _assert_video_url_allowed(actual_url) if is_youtube_url(actual_url): actual_url = resolve_youtube_url(actual_url, self._resolution_limit) + _assert_video_url_allowed(actual_url) self._resolved_url = actual_url # Open capture diff --git a/server/src/ledgrab/core/processing/wled_target_processor.py b/server/src/ledgrab/core/processing/wled_target_processor.py index c3b5f41..4551fb1 100644 --- a/server/src/ledgrab/core/processing/wled_target_processor.py +++ b/server/src/ledgrab/core/processing/wled_target_processor.py @@ -224,7 +224,8 @@ class WledTargetProcessor(TargetProcessor): self._is_running = False - # Cancel task + # Cancel task. The cancellation is awaited above, so the prior + # 50 ms ``asyncio.sleep`` here was pure dead time on every stop(). 
if self._task: self._task.cancel() try: @@ -233,7 +234,6 @@ class WledTargetProcessor(TargetProcessor): logger.debug("WLED target processor task cancelled") pass self._task = None - await asyncio.sleep(0.05) # Restore device state (only if auto_shutdown is enabled) if self._led_client and self._device_state_before: diff --git a/server/src/ledgrab/storage/mqtt_source.py b/server/src/ledgrab/storage/mqtt_source.py index 0f5d4c4..ddfbec7 100644 --- a/server/src/ledgrab/storage/mqtt_source.py +++ b/server/src/ledgrab/storage/mqtt_source.py @@ -9,6 +9,8 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from typing import List, Optional +from ledgrab.utils import secret_box + def _parse_common(data: dict) -> dict: """Extract common fields from a dict, parsing timestamps.""" @@ -52,13 +54,16 @@ class MQTTSource: icon_color: str = "" def to_dict(self) -> dict: + # Always persist the broker password in encrypted envelope form. If + # the field already contains an envelope, ``encrypt()`` is a no-op. + stored_password = secret_box.encrypt(self.password) if self.password else "" d = { "id": self.id, "name": self.name, "broker_host": self.broker_host, "broker_port": self.broker_port, "username": self.username, - "password": self.password, + "password": stored_password, "client_id": self.client_id, "base_topic": self.base_topic, "description": self.description, @@ -75,12 +80,17 @@ class MQTTSource: @staticmethod def from_dict(data: dict) -> "MQTTSource": common = _parse_common(data) + # Decrypt at load time so consumers see plaintext via ``.password``. + # Legacy plaintext rows pass through unchanged — the next save will + # write them back in encrypted form. 
+ raw_password = data.get("password", "") + password = secret_box.decrypt(raw_password) if raw_password else "" return MQTTSource( **common, broker_host=data.get("broker_host", "localhost"), broker_port=int(data.get("broker_port", 1883)), username=data.get("username", ""), - password=data.get("password", ""), + password=password, client_id=data.get("client_id", "ledgrab"), base_topic=data.get("base_topic", "ledgrab"), icon=data.get("icon", ""), diff --git a/server/src/ledgrab/storage/mqtt_source_store.py b/server/src/ledgrab/storage/mqtt_source_store.py index fdabcac..2f6b3a3 100644 --- a/server/src/ledgrab/storage/mqtt_source_store.py +++ b/server/src/ledgrab/storage/mqtt_source_store.py @@ -7,7 +7,7 @@ from typing import List, Optional from ledgrab.storage.base_sqlite_store import BaseSqliteStore from ledgrab.storage.database import Database from ledgrab.storage.mqtt_source import MQTTSource -from ledgrab.utils import get_logger +from ledgrab.utils import get_logger, secret_box logger = get_logger(__name__) @@ -20,6 +20,41 @@ class MQTTSourceStore(BaseSqliteStore[MQTTSource]): def __init__(self, db: Database): super().__init__(db, MQTTSource.from_dict) + self._migrate_plaintext_passwords() + + def _migrate_plaintext_passwords(self) -> None: + """Encrypt any stored MQTT broker passwords still in plaintext at rest. + + Mirrors the HA-token migration: inspect raw DB rows, re-save items + that lack the encryption envelope so the next persistence pass + writes them in encrypted form. 
+ """ + migrated = 0 + try: + rows = self._db.load_all(self._table_name) + except Exception as exc: + logger.error("Could not inspect rows for MQTT password migration: %s", exc) + return + for row in rows: + sid = row.get("id") + raw_pw = row.get("password", "") or "" + if not sid or not raw_pw: + continue + if secret_box.is_encrypted(raw_pw): + continue + source = self._items.get(sid) + if source is None: + continue + try: + self._save_item(sid, source) + migrated += 1 + except Exception as exc: + logger.error("Failed to migrate MQTT password for %s: %s", sid, exc) + if migrated: + logger.warning( + "MIGRATION: encrypted %d plaintext MQTT broker password(s) at rest.", + migrated, + ) # Backward-compatible aliases get_all_sources = BaseSqliteStore.get_all diff --git a/server/src/ledgrab/utils/__init__.py b/server/src/ledgrab/utils/__init__.py index 42466a5..43b1b65 100644 --- a/server/src/ledgrab/utils/__init__.py +++ b/server/src/ledgrab/utils/__init__.py @@ -1,13 +1,15 @@ """Utility functions and helpers.""" -from .file_ops import atomic_write_json +from .file_ops import atomic_write_json, read_upload_capped from .logger import setup_logging, get_logger from .monitor_names import get_monitor_names, get_monitor_name, get_monitor_refresh_rates from .timer import high_resolution_timer from .log_broadcaster import broadcaster as log_broadcaster, install_broadcast_handler +from .url_scheme import infer_http_scheme __all__ = [ "atomic_write_json", + "read_upload_capped", "setup_logging", "get_logger", "get_monitor_names", @@ -16,4 +18,5 @@ __all__ = [ "high_resolution_timer", "log_broadcaster", "install_broadcast_handler", + "infer_http_scheme", ] diff --git a/server/src/ledgrab/utils/file_ops.py b/server/src/ledgrab/utils/file_ops.py index 0d7edf1..9e2cb8d 100644 --- a/server/src/ledgrab/utils/file_ops.py +++ b/server/src/ledgrab/utils/file_ops.py @@ -5,10 +5,40 @@ import logging import os import tempfile from pathlib import Path +from typing import TYPE_CHECKING + 
+if TYPE_CHECKING: + from fastapi import UploadFile logger = logging.getLogger(__name__) +_UPLOAD_CHUNK_SIZE = 1 * 1024 * 1024 # 1 MiB + + +async def read_upload_capped(file: "UploadFile", max_size: int) -> bytes: + """Read an uploaded file in chunks, refusing to buffer past ``max_size``. + + A plain ``await file.read()`` will allocate the entire upload body before + any size check runs, so a 1 GB upload still costs 1 GB of memory even + against a 50 MB limit. This helper streams the read and aborts as soon + as the accumulated size exceeds the cap. + + Raises: + ValueError: When the upload exceeds *max_size*. The caller decides + how to translate the exception (typically HTTP 400/413). + """ + buf = bytearray() + while True: + chunk = await file.read(_UPLOAD_CHUNK_SIZE) + if not chunk: + break + if len(buf) + len(chunk) > max_size: + raise ValueError(f"Upload exceeds maximum allowed size of {max_size} bytes") + buf.extend(chunk) + return bytes(buf) + + def atomic_write_json(file_path: Path, data: dict, indent: int = 2) -> None: """Write JSON data to file atomically via temp file + rename. 
diff --git a/server/src/ledgrab/utils/safe_source.py b/server/src/ledgrab/utils/safe_source.py index 8f4a012..2dfeef8 100644 --- a/server/src/ledgrab/utils/safe_source.py +++ b/server/src/ledgrab/utils/safe_source.py @@ -20,6 +20,8 @@ from urllib.parse import urlparse import httpx from fastapi import HTTPException +from ledgrab.utils.net_classify import is_blocked_for_ssrf + # Image file extensions considered safe to serve _IMAGE_EXTENSIONS = frozenset( @@ -38,19 +40,13 @@ _IMAGE_EXTENSIONS = frozenset( def _is_blocked_ip(ip: str) -> bool: - """Return True when *ip* belongs to a category we refuse to fetch from.""" - try: - addr = ipaddress.ip_address(ip) - except ValueError: - return True # unparseable → block - return ( - addr.is_private - or addr.is_loopback - or addr.is_link_local - or addr.is_reserved - or addr.is_multicast - or addr.is_unspecified - ) + """Return True when *ip* belongs to a category we refuse to fetch from. + + Thin wrapper preserved for the local test suite. The actual policy now + lives in :func:`ledgrab.utils.net_classify.is_blocked_for_ssrf` so it + can't drift away from the related ``url_scheme`` / ``auth`` modules. + """ + return is_blocked_for_ssrf(ip) def _resolve_hostname(hostname: str) -> Iterable[str]: @@ -67,6 +63,11 @@ def _resolve_hostname(hostname: str) -> Iterable[str]: def validate_image_url(url: str) -> None: """Validate that *url* is safe to fetch. + Strict SSRF policy — blocks all private / loopback / link-local / + reserved / multicast / unspecified IPs. Intended for fetches of + user-supplied URLs that should reach the *public* internet only + (image sources, weather APIs, etc.). 
def validate_polling_url(url: str) -> None:
    """Check that *url* may be polled by an automation trigger.

    Applies the relaxed SSRF policy: PRIVATE (LAN) addresses are accepted so
    users can point at devices on their own network, while LOOPBACK,
    LINK_LOCAL, MULTICAST, RESERVED and UNSPECIFIED addresses stay blocked.
    Prefer this over :func:`validate_image_url` for user-configured HTTP
    polling targets.

    Raises:
        HTTPException: When the shared URL validator rejects *url*.
    """
    lan_targets_permitted = True
    _validate_url(url, allow_private=lan_targets_permitted)
# ---------------------------------------------------------------------------
# Bounded safe request — for polling-style fetches that need streaming
# + size cap + arbitrary method/headers. Used by the HTTP poll runtime.
# ---------------------------------------------------------------------------


# Hard cap to avoid OOM from polled endpoints returning huge bodies.
# 1 MiB covers every realistic JSON status endpoint by orders of magnitude.
_MAX_RESPONSE_BYTES = 1 * 1024 * 1024


async def safe_request_bounded(
    method: str,
    url: str,
    *,
    headers: dict[str, str] | None = None,
    timeout: float = 10.0,
    max_bytes: int = _MAX_RESPONSE_BYTES,
    allow_private: bool = True,
) -> tuple[int, bytes, str | None]:
    """Make a single HTTP request with SSRF validation and a body-size cap.

    Validates the URL up-front (scheme + IP-block) and streams the response
    body, stopping as soon as ``max_bytes`` have been collected so a malicious
    or misconfigured endpoint can't OOM the server. Redirects are not
    followed; a 3xx response is returned to the caller as-is.

    ``allow_private`` selects the SSRF policy: ``True`` (default — for HTTP
    polling against user-owned LAN devices) routes through
    :func:`validate_polling_url`; ``False`` routes through
    :func:`validate_image_url`'s strict public-only policy.

    Note on DNS rebinding: this validates the URL hostname's resolved IPs
    once, then httpx independently re-resolves to make the actual request.
    The window between the two resolutions is short but not zero. True
    IP-pinning (rewriting the URL to the literal IP + Host header) is a
    project-wide concern that would also need to update ``safe_fetch`` and
    every other outbound caller; tracked as a follow-up.

    Args:
        method: HTTP method; upper-cased before sending.
        url: Target URL, validated before any request is made.
        headers: Optional request headers; copied, never mutated.
        timeout: Timeout passed to the httpx client.
        max_bytes: Maximum number of body bytes to keep. Negative values are
            treated as ``0``.
        allow_private: Whether private (LAN) addresses are acceptable.

    Returns:
        ``(status_code, body_bytes, error_message)``. ``error_message`` is
        ``None`` on success; on transport failure ``status_code`` is 0 and
        ``body_bytes`` is empty. Truncation is NOT an error: the partial
        body is returned. ``len(body_bytes) == max_bytes`` means the body
        *may* have been truncated (a body of exactly ``max_bytes`` looks the
        same).

    Raises:
        HTTPException: when the URL fails validation (4xx-style error
            intended for callers that surface this directly to the user).
    """
    if allow_private:
        validate_polling_url(url)
    else:
        validate_image_url(url)
    method = method.upper()
    headers = dict(headers or {})
    # A negative cap would turn the truncating slice below into "everything
    # except the last N bytes" and return more data than any sane limit.
    remaining = max(max_bytes, 0)

    try:
        async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
            async with client.stream(method, url, headers=headers) as response:
                chunks: list[bytes] = []
                async for chunk in response.aiter_bytes():
                    if len(chunk) > remaining:
                        # Keep only what still fits, then stop reading.
                        chunks.append(chunk[:remaining])
                        break
                    chunks.append(chunk)
                    remaining -= len(chunk)
                return response.status_code, b"".join(chunks), None
    except httpx.RequestError as exc:
        # Never include `exc` repr in the message — some httpx errors include
        # request URL + headers, which could leak auth tokens supplied by the
        # caller. Type name is sufficient diagnostic.
        return 0, b"", f"Request failed: {type(exc).__name__}"
    except Exception as exc:
        # Deliberate catch-all: callers rely on the error slot instead of an
        # exception for anything that goes wrong after validation.
        return 0, b"", f"Unexpected error: {type(exc).__name__}"
def test_openapi_docs_requires_auth(client):
    """The OpenAPI schema must answer 401 when no credentials are sent."""
    status = client.get("/openapi.json").status_code
    assert status == 401


def test_swagger_ui_requires_auth(client):
    """The Swagger UI page must answer 401 when no credentials are sent."""
    status = client.get("/docs").status_code
    assert status == 401