class HostCategory(str, enum.Enum):
    """Mutually-exclusive categorisation of a host literal.

    Members are also ``str`` instances, so they compare and log cheaply.
    Callers map a category onto their own policy (block, prefer-http,
    allow-anon, …); this enum carries no policy of its own.
    """

    LOOPBACK = "loopback"  # 127.0.0.0/8, ::1
    PRIVATE = "private"  # RFC1918 + ULA
    LINK_LOCAL = "link_local"  # 169.254.0.0/16, fe80::/10
    RESERVED = "reserved"  # IANA-reserved ranges
    MULTICAST = "multicast"  # 224.0.0.0/4, ff00::/8
    UNSPECIFIED = "unspecified"  # 0.0.0.0, ::
    PUBLIC = "public"  # routable on the public internet
    UNPARSEABLE = "unparseable"  # not a literal IP — caller decides


def classify_ip(host: str) -> HostCategory:
    """Categorise an IP literal *host*.

    Returns :data:`HostCategory.UNPARSEABLE` for anything that is not a
    textual IP literal (hostnames, bare labels, garbage, non-string values).
    The category is otherwise derived from Python's :mod:`ipaddress`
    predicates.

    IPv4-mapped IPv6 literals (``::ffff:a.b.c.d``) are classified by their
    embedded IPv4 address. Those literals reach the IPv4 host on the wire,
    and the stdlib predicates only started delegating to the mapped address
    in recent CPython releases — unwrapping here keeps the answer identical
    on every interpreter version.
    """
    # ``ipaddress.ip_address`` also accepts ints and packed bytes; neither is
    # a host *literal*, so refuse them instead of silently parsing them.
    if not isinstance(host, str):
        return HostCategory.UNPARSEABLE
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        return HostCategory.UNPARSEABLE

    # Only IPv6Address has ``ipv4_mapped``; it is None unless the address is
    # inside ::ffff:0:0/96.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped

    # Order matters — the categories below are not strictly disjoint in
    # Python's predicates (e.g. ``127.0.0.1`` is both ``is_private`` and
    # ``is_loopback``), so we test the most specific labels first.
    if addr.is_loopback:
        return HostCategory.LOOPBACK
    if addr.is_link_local:
        return HostCategory.LINK_LOCAL
    if addr.is_unspecified:
        return HostCategory.UNSPECIFIED
    if addr.is_multicast:
        return HostCategory.MULTICAST
    if addr.is_private:
        return HostCategory.PRIVATE
    if addr.is_reserved:
        return HostCategory.RESERVED
    return HostCategory.PUBLIC


# Every category except PUBLIC. UNPARSEABLE is included on purpose: a value
# that is not a literal IP is refused rather than allowed by default.
_NON_PUBLIC_FOR_SSRF: frozenset[HostCategory] = frozenset(
    {
        HostCategory.LOOPBACK,
        HostCategory.PRIVATE,
        HostCategory.LINK_LOCAL,
        HostCategory.RESERVED,
        HostCategory.MULTICAST,
        HostCategory.UNSPECIFIED,
        HostCategory.UNPARSEABLE,
    }
)


def is_blocked_for_ssrf(host: str) -> bool:
    """Return True when *host* must NOT be reached from outbound fetches.

    True for every category other than :data:`HostCategory.PUBLIC`,
    including :data:`HostCategory.UNPARSEABLE`.
    """
    return classify_ip(host) in _NON_PUBLIC_FOR_SSRF


# Narrower than the SSRF set: multicast, reserved and unparseable hosts do
# not get the ``http://`` default.
_LOCAL_FOR_HTTP_DEFAULT: frozenset[HostCategory] = frozenset(
    {
        HostCategory.LOOPBACK,
        HostCategory.PRIVATE,
        HostCategory.LINK_LOCAL,
        HostCategory.UNSPECIFIED,
    }
)


def is_local_for_http_default(host: str) -> bool:
    """Return True when *host* should default to ``http://`` for LAN devices.

    Returns False on :data:`HostCategory.UNPARSEABLE`, so callers that want
    hostname heuristics (mDNS suffixes, bare labels) must layer them on top.
    """
    return classify_ip(host) in _LOCAL_FOR_HTTP_DEFAULT


def is_loopback(host: str) -> bool:
    """Return True when *host* is a loopback literal.

    Besides IP literals this accepts the textual placeholders ``localhost``
    and ``testclient``. Surrounding whitespace, letter case, ``[...]``
    brackets and an IPv6 zone id (``%eth0``) are ignored. Falsy input
    (``""``, ``None``) is never loopback.
    """
    if not host:
        return False
    h = host.strip().lower()
    if h.startswith("[") and h.endswith("]"):
        h = h[1:-1]
    h = h.split("%", 1)[0]  # strip IPv6 zone id
    if h in {"localhost", "testclient"}:
        return True
    return classify_ip(h) is HostCategory.LOOPBACK
# RFC 3986 scheme grammar: ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) followed
# by ":". Used to recognise URIs like ``data:...``, ``javascript:...``,
# ``mailto:...`` that omit the ``//`` authority component — we must not
# coerce those into ``http(s)://data:...``.
_URI_SCHEME_RE = re.compile(r"^([A-Za-z][A-Za-z0-9+\-.]*):(.*)$")


_LOCAL_HOSTNAME_SUFFIXES: tuple[str, ...] = (
    ".local",
    ".lan",
    ".home",
    ".internal",
)

# Characters that must never appear inside a bare host (C0 + DEL control
# characters, plus the userinfo separator ``@``). Anything containing these
# is treated as unparseable so the downstream validator reports a clean
# error instead of httpx silently chewing on a malformed string.
_FORBIDDEN_HOST_CHARS: frozenset[str] = frozenset(chr(c) for c in range(0x00, 0x20)) | frozenset(
    {chr(0x7F), "@"}
)


def _extract_hostname(host_part: str) -> str:
    """Pull just the hostname out of ``host[:port]`` (or ``[ipv6]:port``).

    Returns the empty string when *host_part* cannot be classified safely —
    it contains a forbidden character or opens a ``[`` that is never closed —
    so the caller can pass the input through untouched and let the next
    layer surface a precise validation error.
    """
    if any(ch in _FORBIDDEN_HOST_CHARS for ch in host_part):
        return ""
    if host_part.startswith("["):
        end = host_part.find("]")
        # An unterminated bracket is malformed; guessing a host from it would
        # attach a scheme to a string that is not a valid authority.
        return host_part[1:end] if end != -1 else ""
    # Bracketless IPv6 literal (``fe80::1``) — has multiple ``:`` so the
    # ``host:port`` shortcut below would misparse it. Probe ipaddress first.
    try:
        ipaddress.ip_address(host_part)
    except ValueError:
        pass
    else:
        return host_part
    if host_part.count(":") == 1:
        return host_part.split(":", 1)[0]
    return host_part


def _is_local_host(host: str) -> bool:
    """Return True when *host* refers to a private/loopback/link-local target.

    ``localhost``, names ending in one of ``_LOCAL_HOSTNAME_SUFFIXES`` and
    dotless labels (``wled``, ``kitchen-strip``) count as local; this is a
    deliberately permissive heuristic, not a general-purpose classifier. One
    trailing dot (the absolute-FQDN form, ``wled.local.``) is ignored for
    these name checks. IP literals are delegated to
    ``is_local_for_http_default``.
    """
    candidate = host.strip().lower()
    if not candidate:
        return False
    # ``wled.local.`` and ``wled.local`` name the same host.
    name = candidate[:-1] if candidate.endswith(".") else candidate
    if not name:
        return False
    if name == "localhost" or name.endswith(_LOCAL_HOSTNAME_SUFFIXES):
        return True
    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        # Bare label with no dot ⇒ treat as a local mDNS-style hostname.
        return "." not in name
    return is_local_for_http_default(candidate)


def infer_http_scheme(url: str | None) -> str | None:
    """Return *url* with an inferred ``http(s)://`` scheme when none is present.

    Hosts judged local by ``_is_local_host`` get ``http://``; everything else
    gets ``https://``. The result is built from the whitespace-stripped input.

    Returned without a scheme being added:

    * falsy input (``None`` and ``""``), returned as-is;
    * whitespace-only input, returned as ``""``;
    * input that already contains ``://`` (any scheme);
    * authority-less URIs such as ``javascript:...`` or ``data:...`` — the
      text after ``:`` is not an all-ASCII-digit port;
    * input whose host part cannot be classified safely (forbidden
      characters, unterminated ``[``, empty host).
    """
    if not url:
        return url
    trimmed = url.strip()
    if not trimmed:
        return trimmed
    if trimmed.lower().startswith(("http://", "https://")):
        return trimmed
    # Any other explicit scheme with ``//`` authority: leave the caller in
    # control (ws://, openrgb://, file://, gopher://, …).
    if "://" in trimmed:
        return trimmed
    host_part = trimmed.partition("/")[0]
    # IPv6 literals (``fe80::1``) syntactically resemble ``scheme:rest``, so
    # only apply the scheme regex when the host part is not an IP literal.
    try:
        ipaddress.ip_address(host_part)
    except ValueError:
        scheme_match = _URI_SCHEME_RE.match(host_part)
        if scheme_match:
            after_colon = scheme_match.group(2)
            # ``host:8080`` is a port; anything else after the colon means
            # this is a ``scheme:payload`` URI that must not be coerced into
            # a fetchable URL. ``isascii`` rejects non-ASCII digits, which
            # ``str.isdigit`` alone would accept as a "port".
            if not (after_colon.isascii() and after_colon.isdigit()):
                return trimmed
    hostname = _extract_hostname(host_part)
    if not hostname:
        # Unclassifiable host: return the input unchanged so the downstream
        # validator reports the error instead of us guessing a scheme.
        return trimmed
    scheme = "http" if _is_local_host(hostname) else "https"
    return f"{scheme}://{trimmed}"
# ---------------------------------------------------------------------------
# SSRF block list: every non-public category, unparseable input included,
# must be refused. Narrowing this set would weaken SSRF protection.
# ---------------------------------------------------------------------------

_SSRF_BLOCKED_HOSTS = [
    "127.0.0.1",
    "::1",
    "10.0.0.5",
    "192.168.1.1",
    "172.16.0.5",
    "fd00::1",
    "169.254.1.1",
    "fe80::1",
    "0.0.0.0",
    "::",
    "224.0.0.1",
    "ff00::1",
    "240.0.0.1",
    "not-an-ip",  # unparseable → blocked
    "",  # unparseable → blocked
]

_SSRF_ALLOWED_HOSTS = ["8.8.8.8", "1.1.1.1", "2606:4700:4700::1111"]


@pytest.mark.parametrize("host", _SSRF_BLOCKED_HOSTS)
def test_is_blocked_for_ssrf_blocks_non_public(host: str) -> None:
    """Every non-public or unparseable host is refused."""
    verdict = is_blocked_for_ssrf(host)
    assert verdict is True


@pytest.mark.parametrize("host", _SSRF_ALLOWED_HOSTS)
def test_is_blocked_for_ssrf_allows_public(host: str) -> None:
    """Publicly routable literals are let through."""
    verdict = is_blocked_for_ssrf(host)
    assert verdict is False


# ---------------------------------------------------------------------------
# LAN-default policy: narrower than the SSRF set. ``http://`` is inferred for
# loopback / private / link-local / unspecified only; multicast, reserved and
# unparseable hosts fall through to ``https://`` or caller-side heuristics.
# ---------------------------------------------------------------------------

_HTTP_DEFAULT_LOCAL_HOSTS = [
    "127.0.0.1",
    "::1",
    "10.0.0.5",
    "192.168.1.1",
    "fe80::1",
    "0.0.0.0",
]

_HTTP_DEFAULT_NON_LOCAL_HOSTS = ["8.8.8.8", "224.0.0.1", "not-an-ip", ""]


@pytest.mark.parametrize("host", _HTTP_DEFAULT_LOCAL_HOSTS)
def test_is_local_for_http_default_true(host: str) -> None:
    """LAN-style literals default to plain HTTP."""
    verdict = is_local_for_http_default(host)
    assert verdict is True


@pytest.mark.parametrize("host", _HTTP_DEFAULT_NON_LOCAL_HOSTS)
def test_is_local_for_http_default_false(host: str) -> None:
    """Public, multicast and unparseable hosts do not default to HTTP."""
    verdict = is_local_for_http_default(host)
    assert verdict is False
# ---------------------------------------------------------------------------
# Loopback predicate: accepts IP literals as well as the textual placeholders
# ``localhost`` and ``testclient``.
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "host",
    [
        "127.0.0.1",
        "127.0.0.5",
        "::1",
        "localhost",
        "LOCALHOST",
        "testclient",
        "[::1]",
    ],
)
def test_is_loopback_recognises_loopback(host: str) -> None:
    """Every parametrized host here is a genuine loopback spelling."""
    assert is_loopback(host) is True


@pytest.mark.parametrize(
    "host",
    [
        "8.8.8.8",
        "10.0.0.5",
        "example.com",
        # Link-local with a zone id: the zone is stripped before
        # classification, and the remaining address is still not loopback.
        "fe80::1%eth0",
        "",
        None,
    ],
)
def test_is_loopback_rejects_other(host) -> None:
    """Non-loopback hosts and falsy input are rejected.

    ``host`` is left unannotated because the list deliberately includes
    ``None`` alongside strings.
    """
    assert is_loopback(host) is False
infer_http_scheme(raw) == f"https://{raw}" + + +def test_trims_whitespace_before_inference(): + assert infer_http_scheme(" 192.168.0.1 ") == "http://192.168.0.1" + assert infer_http_scheme(" example.com ") == "https://example.com" + + +def test_empty_string_returns_unchanged(): + assert infer_http_scheme("") == "" + + +def test_none_returns_unchanged(): + # Callers occasionally hand us None; preserve it so the validator can complain. + assert infer_http_scheme(None) is None + + +def test_whitespace_only_collapses_to_empty(): + # Whitespace alone has no host to infer a scheme for — trim and bail out. + assert infer_http_scheme(" ") == "" + + +# --------------------------------------------------------------------------- +# Malicious / hostile inputs — must round-trip *unchanged* (no scheme +# coerced onto them) so the downstream validator surfaces a clean error +# rather than letting a coerced scheme slip past as a "valid" URL. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "raw", + [ + "javascript:alert(1)", # already has a scheme — must pass through + "data:text/html,