diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml index 56b8a5d..fcbe1c1 100644 --- a/.gitea/workflows/build.yml +++ b/.gitea/workflows/build.yml @@ -1,13 +1,75 @@ -name: Build Docker Image +name: Build and Test on: + push: + branches: [master, main] + pull_request: + branches: [master, main] workflow_dispatch: jobs: - build: + test-backend: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Build Docker image - run: docker build -t notify-bridge:dev . + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install core + server + dev deps + run: | + python -m pip install --upgrade pip + python -m pip install -e ./packages/core + python -m pip install -e "./packages/server[dev]" + + - name: Run pytest (server) + run: | + cd packages/server + pytest -q --maxfail=1 + + test-frontend: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Node + uses: actions/setup-node@v4 + with: + node-version: "22" + cache: "npm" + cache-dependency-path: frontend/package-lock.json + + - name: Install deps + run: | + cd frontend + npm ci + + - name: Svelte check + run: | + cd frontend + npm run check || echo "::warning::svelte-check reported warnings" + + - name: Build + run: | + cd frontend + npm run build + + build-image: + needs: [test-backend, test-frontend] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build Docker image (no push) + uses: docker/build-push-action@v5 + with: + context: . 
+ push: false + tags: notify-bridge:ci-${{ gitea.sha }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml index a6891fa..a907977 100644 --- a/.gitea/workflows/release.yml +++ b/.gitea/workflows/release.yml @@ -10,7 +10,22 @@ env: IMAGE_NAME: alexei.dolgolyov/notify-bridge jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install + test + run: | + python -m pip install --upgrade pip + python -m pip install -e ./packages/core + python -m pip install -e "./packages/server[dev]" + cd packages/server && pytest -q --maxfail=1 + release: + needs: test runs-on: ubuntu-latest steps: - name: Checkout repo @@ -44,16 +59,28 @@ jobs: - name: Build and push Docker image uses: docker/build-push-action@v5 + id: docker_build with: context: . push: true tags: | ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }} ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }} + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ gitea.sha }} ${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }} cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max + - name: Vulnerability scan (trivy) + uses: aquasecurity/trivy-action@master + continue-on-error: true + with: + image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }} + format: table + exit-code: 0 + severity: HIGH,CRITICAL + ignore-unfixed: true + - name: Trigger redeploy webhook if: steps.version.outputs.is_pre == 'false' continue-on-error: true diff --git a/docker-compose.yml b/docker-compose.yml index 4a9d7b3..339e904 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,18 +10,38 @@ services: volumes: - 
notify-bridge-data:/data environment: + # REQUIRED — any 32+ byte random string. `openssl rand -hex 32` is one way. - NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)} - - NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-*} - # Homelab target: allow outbound requests to RFC1918 / link-local addresses. - # The SSRF guard otherwise rejects 10.*/172.16.*/192.168.*/169.254.* hosts, - # which breaks tracking of Immich / Gitea / etc. running on the same LAN. - - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1 + # Comma-separated list of allowed browser origins. Wildcard `*` is + # rejected on startup because credentials are enabled. + - NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-http://localhost:8420} + # Trusted proxy IPs whose X-Forwarded-For / X-Forwarded-Proto we honor. + # Set this to your reverse proxy's IP (e.g. 172.17.0.1 for the default + # docker bridge, or `*` only if the container is NOT reachable from the + # public internet). + - NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS=${NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS:-127.0.0.1} + # Opt-in SSRF bypass for private/loopback/link-local hosts (homelab + # scenario — tracking an Immich/Gitea instance on the same LAN). DO NOT + # enable on a publicly exposed instance. + # - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1 healthcheck: - test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/health')"] + # Use /api/ready (not /api/health) so the container is only reported + # healthy after migrations and the scheduler finish booting. 
+ test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/ready', timeout=3)"] interval: 30s timeout: 5s retries: 3 - start_period: 10s + start_period: 30s + read_only: true + tmpfs: + - /tmp + security_opt: + - no-new-privileges:true + cap_drop: + - ALL + mem_limit: 512m + cpus: 1.0 + pids_limit: 256 volumes: notify-bridge-data: diff --git a/frontend/src/lib/i18n/en.json b/frontend/src/lib/i18n/en.json index 0ff5703..ba41a39 100644 --- a/frontend/src/lib/i18n/en.json +++ b/frontend/src/lib/i18n/en.json @@ -55,7 +55,8 @@ "passwordTooShort": "Password must be at least 8 characters", "or": "or", "loginFailed": "Login failed", - "setupFailed": "Setup failed" + "setupFailed": "Setup failed", + "backendUnreachable": "Cannot reach the server. Check that it's running and try again." }, "dashboard": { "title": "Dashboard", diff --git a/frontend/src/lib/i18n/ru.json b/frontend/src/lib/i18n/ru.json index d1dde07..ee439a0 100644 --- a/frontend/src/lib/i18n/ru.json +++ b/frontend/src/lib/i18n/ru.json @@ -55,7 +55,8 @@ "passwordTooShort": "Пароль должен быть не менее 8 символов", "or": "или", "loginFailed": "Ошибка входа", - "setupFailed": "Ошибка настройки" + "setupFailed": "Ошибка настройки", + "backendUnreachable": "Не удалось подключиться к серверу. Убедитесь, что он запущен, и повторите попытку." }, "dashboard": { "title": "Главная", diff --git a/frontend/src/routes/login/+page.svelte b/frontend/src/routes/login/+page.svelte index 1528fe9..bff7ec2 100644 --- a/frontend/src/routes/login/+page.svelte +++ b/frontend/src/routes/login/+page.svelte @@ -15,13 +15,32 @@ let submitting = $state(false); let mounted = $state(false); + let backendDown = $state(false); + onMount(async () => { initTheme(); mounted = true; + // If the user already appears signed in (access token in storage), skip + // the login form — but only after /auth/me confirms the token, so a + // stale token falls through to the form instead of redirecting. 
+ const { isAuthenticated } = await import('$lib/api'); + if (isAuthenticated()) { + try { + await api('/auth/me'); + goto('/'); + return; + } catch { + // Token was stale; fall through to the login form. + } + } try { const res = await api<{ needs_setup: boolean }>('/auth/needs-setup'); if (res.needs_setup) goto('/setup'); - } catch { /* ignore */ } + } catch { + // The backend is unreachable — surface that distinctly so the user + // doesn't blame the login form for a network/backend problem. + backendDown = true; + } }); async function handleSubmit(e: SubmitEvent) { @@ -62,7 +81,12 @@

{t('auth.signInTitle')}

- {#if error} + {#if backendDown} +
+ + {t('auth.backendUnreachable')} +
+ {:else if error}
{error} diff --git a/packages/core/src/notify_bridge_core/notifications/discord/client.py b/packages/core/src/notify_bridge_core/notifications/discord/client.py index 20dbc2e..cd0ea24 100644 --- a/packages/core/src/notify_bridge_core/notifications/discord/client.py +++ b/packages/core/src/notify_bridge_core/notifications/discord/client.py @@ -52,22 +52,46 @@ class DiscordClient: return {"success": True} + _MAX_RETRIES = 3 + _MAX_RETRY_AFTER = 60.0 + async def _post(self, url: str, payload: dict) -> dict[str, Any]: - try: - async with self._session.post( - url, json=payload, headers={"Content-Type": "application/json"} - ) as resp: - if resp.status == 429: - retry_after = float(resp.headers.get("Retry-After", "2")) - _LOGGER.warning("Discord rate limited, retrying after %.1fs", retry_after) - await asyncio.sleep(retry_after) - return await self._post(url, payload) - if 200 <= resp.status < 300: - return {"success": True} - body = await resp.text() - return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"} - except aiohttp.ClientError as e: - return {"success": False, "error": str(e)} + """POST with bounded 429 retry. + + We cap retries at _MAX_RETRIES and the ``Retry-After`` header at + _MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot + pin the dispatch task indefinitely. 
+ """ + for attempt in range(self._MAX_RETRIES + 1): + try: + async with self._session.post( + url, + json=payload, + headers={"Content-Type": "application/json"}, + allow_redirects=False, + ) as resp: + if resp.status == 429 and attempt < self._MAX_RETRIES: + try: + retry_after = float(resp.headers.get("Retry-After", "2")) + except (TypeError, ValueError): + retry_after = 2.0 + retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER)) + _LOGGER.warning( + "Discord rate limited, retrying after %.1fs (attempt %d/%d)", + retry_after, attempt + 1, self._MAX_RETRIES, + ) + await asyncio.sleep(retry_after) + continue + if 200 <= resp.status < 300: + return {"success": True} + body = await resp.text() + return { + "success": False, + "error": f"HTTP {resp.status}: {body[:200]}", + } + except aiohttp.ClientError as e: + return {"success": False, "error": str(e)} + return {"success": False, "error": "Rate limited (retries exhausted)"} def _split_message(text: str, limit: int) -> list[str]: diff --git a/packages/core/src/notify_bridge_core/notifications/dispatcher.py b/packages/core/src/notify_bridge_core/notifications/dispatcher.py index c7aecc0..71c0091 100644 --- a/packages/core/src/notify_bridge_core/notifications/dispatcher.py +++ b/packages/core/src/notify_bridge_core/notifications/dispatcher.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio +import contextlib import logging import uuid from dataclasses import dataclass, field -from typing import Any +from typing import Any, AsyncIterator import aiohttp @@ -14,7 +15,7 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var from notify_bridge_core.models.events import ServiceEvent from notify_bridge_core.templates.context import build_template_context from notify_bridge_core.templates.renderer import render_template -from .ssrf import UnsafeURLError, validate_outbound_url +from .ssrf import UnsafeURLError, avalidate_outbound_url _HTTP_TIMEOUT = 
aiohttp.ClientTimeout(total=30) @@ -84,9 +85,28 @@ class NotificationDispatcher: *, url_cache: TelegramFileCache | None = None, asset_cache: TelegramFileCache | None = None, + session: aiohttp.ClientSession | None = None, ) -> None: self._url_cache = url_cache self._asset_cache = asset_cache + # Optional shared session owned by the caller; when supplied we reuse + # its connection pool instead of opening a fresh per-dispatch session + # (saves a TLS handshake per outbound call). + self._shared_session = session + + @contextlib.asynccontextmanager + async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]: + """Yield an aiohttp session, reusing the shared one if provided. + + When a shared session was passed in ``__init__`` we yield it without + closing (the caller owns its lifetime). Otherwise we open a + short-lived session via ``_new_session()`` and close it on exit. + """ + if self._shared_session is not None and not self._shared_session.closed: + yield self._shared_session + return + async with _new_session() as session: + yield session async def dispatch( self, @@ -308,7 +328,7 @@ class NotificationDispatcher: media_assets.append(asset) results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: # Preload all asset bytes once so (a) TelegramClient can skip its # own download and (b) we know exact upload sizes in time for the # oversize warning in the rendered text. 
@@ -378,13 +398,13 @@ class NotificationDispatcher: return {"success": False, "error": "No receivers configured"} results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: for receiver in target.receivers: if not isinstance(receiver, WebhookReceiver) or not receiver.url: results.append({"success": False, "error": "Invalid webhook receiver"}) continue try: - validate_outbound_url(receiver.url) + await avalidate_outbound_url(receiver.url) except UnsafeURLError as err: results.append({"success": False, "error": f"Unsafe URL: {err}"}) continue @@ -452,14 +472,14 @@ class NotificationDispatcher: username = target.config.get("username") results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: client = DiscordClient(session) for receiver in target.receivers: if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url: results.append({"success": False, "error": "Invalid discord receiver"}) continue try: - validate_outbound_url(receiver.webhook_url) + await avalidate_outbound_url(receiver.webhook_url) except UnsafeURLError as err: results.append({"success": False, "error": f"Unsafe URL: {err}"}) continue @@ -478,14 +498,14 @@ class NotificationDispatcher: username = target.config.get("username") results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: client = SlackClient(session) for receiver in target.receivers: if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url: results.append({"success": False, "error": "Invalid slack receiver"}) continue try: - validate_outbound_url(receiver.webhook_url) + await avalidate_outbound_url(receiver.webhook_url) except UnsafeURLError as err: results.append({"success": False, "error": f"Unsafe URL: {err}"}) continue @@ -504,14 +524,14 @@ class NotificationDispatcher: if not target.receivers: return {"success": False, "error": 
"No receivers configured"} try: - validate_outbound_url(server_url) + await avalidate_outbound_url(server_url) except UnsafeURLError as err: return {"success": False, "error": f"Unsafe ntfy server_url: {err}"} title = f"{event.event_type.value}: {event.collection_name}" results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: client = NtfyClient(session) for receiver in target.receivers: if not isinstance(receiver, NtfyReceiver) or not receiver.topic: @@ -535,7 +555,7 @@ class NotificationDispatcher: if not homeserver or not access_token: return {"success": False, "error": "Missing Matrix homeserver_url or access_token"} try: - validate_outbound_url(homeserver) + await avalidate_outbound_url(homeserver) except UnsafeURLError as err: return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"} @@ -543,7 +563,7 @@ class NotificationDispatcher: return {"success": False, "error": "No receivers configured"} results: list[dict[str, Any]] = [] - async with _new_session() as session: + async with self._session_ctx() as session: client = MatrixClient(session, homeserver, access_token) for receiver in target.receivers: if not isinstance(receiver, MatrixReceiver) or not receiver.room_id: diff --git a/packages/core/src/notify_bridge_core/notifications/matrix/client.py b/packages/core/src/notify_bridge_core/notifications/matrix/client.py index f7c06f3..224239f 100644 --- a/packages/core/src/notify_bridge_core/notifications/matrix/client.py +++ b/packages/core/src/notify_bridge_core/notifications/matrix/client.py @@ -68,7 +68,9 @@ class MatrixClient: } try: - async with self._session.put(url, json=body, headers=headers) as resp: + async with self._session.put( + url, json=body, headers=headers, allow_redirects=False, + ) as resp: if 200 <= resp.status < 300: return {"success": True} resp_body = await resp.text() diff --git a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py 
b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py index 41717bf..d5be0f9 100644 --- a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py +++ b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py @@ -51,7 +51,9 @@ class NtfyClient: headers["Authorization"] = f"Bearer {auth_token}" try: - async with self._session.post(url, json=payload, headers=headers) as resp: + async with self._session.post( + url, json=payload, headers=headers, allow_redirects=False, + ) as resp: if 200 <= resp.status < 300: return {"success": True} body = await resp.text() diff --git a/packages/core/src/notify_bridge_core/notifications/slack/client.py b/packages/core/src/notify_bridge_core/notifications/slack/client.py index ea985ca..681b286 100644 --- a/packages/core/src/notify_bridge_core/notifications/slack/client.py +++ b/packages/core/src/notify_bridge_core/notifications/slack/client.py @@ -38,6 +38,7 @@ class SlackClient: webhook_url, json=payload, headers={"Content-Type": "application/json"}, + allow_redirects=False, ) as resp: if resp.status == 429: _LOGGER.warning("Slack rate limited") diff --git a/packages/core/src/notify_bridge_core/notifications/ssrf.py b/packages/core/src/notify_bridge_core/notifications/ssrf.py index e0902bb..66aea5a 100644 --- a/packages/core/src/notify_bridge_core/notifications/ssrf.py +++ b/packages/core/src/notify_bridge_core/notifications/ssrf.py @@ -12,14 +12,25 @@ development against localhost services. from __future__ import annotations +import asyncio import ipaddress +import logging import os import socket from urllib.parse import urlparse +_LOGGER = logging.getLogger(__name__) + _ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1" _ALLOWED_SCHEMES = {"http", "https"} +if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner + _LOGGER.warning( + "SSRF guard: private-URL bypass ENABLED " + "(NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1). 
Requests to RFC1918 / " + "loopback / link-local hosts will be permitted." + ) + class UnsafeURLError(ValueError): """Raised when a URL targets a disallowed network destination.""" @@ -36,13 +47,7 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: ) -def validate_outbound_url(url: str) -> str: - """Validate ``url`` is safe to fetch; returns the URL on success. - - Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP - is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``) - private addresses are permitted but the scheme check still applies. - """ +def _check_scheme_host(url: str) -> tuple[str, str]: if not isinstance(url, str) or not url: raise UnsafeURLError("URL is empty") parsed = urlparse(url) @@ -51,6 +56,31 @@ def validate_outbound_url(url: str) -> str: host = parsed.hostname if not host: raise UnsafeURLError("URL has no host") + return parsed.scheme, host + + +def _check_resolved_addresses(host: str, infos: list[tuple]) -> None: + for info in infos: + sockaddr = info[4] + try: + ip = ipaddress.ip_address(sockaddr[0]) + except ValueError: + continue + if _is_blocked_ip(ip): + raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}") + + +def validate_outbound_url(url: str) -> str: + """Validate ``url`` is safe to fetch; returns the URL on success. + + Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP + is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``) + private addresses are permitted but the scheme check still applies. + + Synchronous; uses blocking ``socket.getaddrinfo``. Prefer + :func:`avalidate_outbound_url` from async code paths. + """ + _, host = _check_scheme_host(url) if _ALLOW_PRIVATE: return url @@ -64,17 +94,37 @@ def validate_outbound_url(url: str) -> str: except ValueError: pass - # Hostname — resolve and reject if any resolution is in a blocked range. 
try: infos = socket.getaddrinfo(host, None) except socket.gaierror as exc: raise UnsafeURLError(f"DNS resolution failed for {host}") from exc - for info in infos: - sockaddr = info[4] - try: - ip = ipaddress.ip_address(sockaddr[0]) - except ValueError: - continue - if _is_blocked_ip(ip): - raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}") + _check_resolved_addresses(host, infos) + return url + + +async def avalidate_outbound_url(url: str) -> str: + """Async variant that resolves DNS via the running loop's resolver. + + Use this from ``async def`` code paths to avoid blocking the event + loop on DNS lookups. + """ + _, host = _check_scheme_host(url) + + if _ALLOW_PRIVATE: + return url + + try: + ip = ipaddress.ip_address(host) + if _is_blocked_ip(ip): + raise UnsafeURLError(f"Host {host} is in a blocked range") + return url + except ValueError: + pass + + loop = asyncio.get_running_loop() + try: + infos = await loop.getaddrinfo(host, None) + except socket.gaierror as exc: + raise UnsafeURLError(f"DNS resolution failed for {host}") from exc + _check_resolved_addresses(host, infos) return url diff --git a/packages/core/src/notify_bridge_core/notifications/webhook/client.py b/packages/core/src/notify_bridge_core/notifications/webhook/client.py index c5ef57f..0f3cafd 100644 --- a/packages/core/src/notify_bridge_core/notifications/webhook/client.py +++ b/packages/core/src/notify_bridge_core/notifications/webhook/client.py @@ -7,7 +7,7 @@ from typing import Any import aiohttp -from ..ssrf import UnsafeURLError, validate_outbound_url +from ..ssrf import UnsafeURLError, avalidate_outbound_url _LOGGER = logging.getLogger(__name__) @@ -24,7 +24,7 @@ class WebhookClient: async def send(self, payload: dict[str, Any]) -> dict[str, Any]: try: - validate_outbound_url(self._url) + await avalidate_outbound_url(self._url) except UnsafeURLError as err: return {"success": False, "error": f"Unsafe URL: {err}"} try: @@ -33,6 +33,7 @@ class WebhookClient: json=payload, 
headers={"Content-Type": "application/json", **self._headers}, timeout=_DEFAULT_TIMEOUT, + allow_redirects=False, ) as response: if 200 <= response.status < 300: return {"success": True, "status_code": response.status} diff --git a/packages/core/src/notify_bridge_core/providers/nut/client.py b/packages/core/src/notify_bridge_core/providers/nut/client.py index c1a8dbe..3d0d9d4 100644 --- a/packages/core/src/notify_bridge_core/providers/nut/client.py +++ b/packages/core/src/notify_bridge_core/providers/nut/client.py @@ -12,6 +12,7 @@ _LOGGER = logging.getLogger(__name__) _DEFAULT_PORT = 3493 _READ_TIMEOUT = 10.0 +_WRITE_TIMEOUT = 10.0 _CONNECT_TIMEOUT = 5.0 # Allowed characters for NUT protocol identifiers (UPS names, variable names). @@ -84,14 +85,26 @@ class NutClient: await self._command(f"PASSWORD {self._password}") async def disconnect(self) -> None: - """Send LOGOUT and close the TCP connection.""" + """Send LOGOUT and close the TCP connection. + + ``drain`` is bounded by ``_WRITE_TIMEOUT`` so a half-closed peer + cannot hold the disconnect indefinitely — a tracker tick would + otherwise be pinned by a stuck NUT server and block the scheduler + slot (``max_instances=1``). 
+ """ if self._writer is not None: try: self._writer.write(b"LOGOUT\n") - await self._writer.drain() - except OSError: + await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT) + except (OSError, asyncio.TimeoutError): pass self._writer.close() + try: + await asyncio.wait_for( + self._writer.wait_closed(), timeout=_WRITE_TIMEOUT, + ) + except (OSError, asyncio.TimeoutError): + pass self._reader = None self._writer = None @@ -135,7 +148,10 @@ class NutClient: if self._writer is None: raise NutClientError("Not connected") self._writer.write(f"{cmd}\n".encode()) - await self._writer.drain() + try: + await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT) + except asyncio.TimeoutError as exc: + raise NutClientError("Write timeout") from exc async def _readline(self) -> str: """Read one line from upsd, stripping trailing newline.""" diff --git a/packages/core/src/notify_bridge_core/storage.py b/packages/core/src/notify_bridge_core/storage.py index d92161e..16140ee 100644 --- a/packages/core/src/notify_bridge_core/storage.py +++ b/packages/core/src/notify_bridge_core/storage.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio import json import logging +import os from pathlib import Path from typing import Any, Protocol, runtime_checkable @@ -19,34 +21,58 @@ class StorageBackend(Protocol): async def remove(self) -> None: ... +def _read_file(path: Path) -> str | None: + if not path.exists(): + return None + return path.read_text(encoding="utf-8") + + +def _atomic_write(path: Path, payload: str) -> None: + """Write atomically: tmp file + rename. 
Prevents half-written files on crash.""" + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + tmp.write_text(payload, encoding="utf-8") + os.replace(tmp, path) + + +def _remove_file(path: Path) -> None: + if path.exists(): + path.unlink() + + class JsonFileBackend: - """Simple JSON file storage backend.""" + """Simple JSON file storage backend. + + All blocking I/O is wrapped in ``asyncio.to_thread`` so callers can + ``await load() / save() / remove()`` without stalling the event loop. + """ def __init__(self, path: Path) -> None: self._path = path async def load(self) -> dict[str, Any] | None: - if not self._path.exists(): + try: + text = await asyncio.to_thread(_read_file, self._path) + except OSError as err: + _LOGGER.warning("Failed to load %s: %s", self._path, err) + return None + if text is None: return None try: - text = self._path.read_text(encoding="utf-8") return json.loads(text) - except (json.JSONDecodeError, OSError) as err: - _LOGGER.warning("Failed to load %s: %s", self._path, err) + except json.JSONDecodeError as err: + _LOGGER.warning("Failed to parse %s: %s", self._path, err) return None async def save(self, data: dict[str, Any]) -> None: + payload = json.dumps(data, default=str) try: - self._path.parent.mkdir(parents=True, exist_ok=True) - self._path.write_text( - json.dumps(data, default=str), encoding="utf-8" - ) + await asyncio.to_thread(_atomic_write, self._path, payload) except OSError as err: _LOGGER.error("Failed to save %s: %s", self._path, err) async def remove(self) -> None: try: - if self._path.exists(): - self._path.unlink() + await asyncio.to_thread(_remove_file, self._path) except OSError as err: _LOGGER.error("Failed to remove %s: %s", self._path, err) diff --git a/packages/server/pyproject.toml b/packages/server/pyproject.toml index e77d724..7e3c567 100644 --- a/packages/server/pyproject.toml +++ b/packages/server/pyproject.toml @@ -28,6 +28,7 @@ dev = [ "pytest>=8.0", 
"pytest-asyncio>=0.23", "httpx>=0.27", + "aioresponses>=0.7", ] [project.scripts] @@ -35,3 +36,14 @@ notify-bridge = "notify_bridge_server.main:run" [tool.hatch.build.targets.wheel] packages = ["src/notify_bridge_server"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +# Silence DeprecationWarnings attributed to the passlib and bcrypt modules +# only; warnings from every other module (SQLAlchemy, migrations, our own +# code) keep pytest's default handling and stay visible in the test output. +filterwarnings = [ + "ignore::DeprecationWarning:passlib", + "ignore::DeprecationWarning:bcrypt", +] diff --git a/packages/server/src/notify_bridge_server/api/email_bots.py b/packages/server/src/notify_bridge_server/api/email_bots.py index 479aeb4..e6a33c3 100644 --- a/packages/server/src/notify_bridge_server/api/email_bots.py +++ b/packages/server/src/notify_bridge_server/api/email_bots.py @@ -93,7 +93,14 @@ async def update_email_bot( session: AsyncSession = Depends(get_session), ): bot = await _get_user_bot(session, bot_id, user.id) - for field, value in body.model_dump(exclude_unset=True).items(): + updates = body.model_dump(exclude_unset=True) + # Drop the masked value the GET response returns so the stored password + # is preserved if the user saves without retyping it. 
+ if "smtp_password" in updates: + pw = updates["smtp_password"] + if isinstance(pw, str) and pw.startswith("***"): + updates.pop("smtp_password") + for field, value in updates.items(): setattr(bot, field, value) session.add(bot) await session.commit() diff --git a/packages/server/src/notify_bridge_server/api/matrix_bots.py b/packages/server/src/notify_bridge_server/api/matrix_bots.py index 644a6a7..f5861e7 100644 --- a/packages/server/src/notify_bridge_server/api/matrix_bots.py +++ b/packages/server/src/notify_bridge_server/api/matrix_bots.py @@ -7,6 +7,11 @@ from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from notify_bridge_core.notifications.ssrf import ( + UnsafeURLError, + avalidate_outbound_url, +) + from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import MatrixBot, User @@ -33,6 +38,21 @@ class MatrixBotUpdate(BaseModel): display_name: str | None = None +def _is_masked_secret(value: str | None) -> bool: + """True when a field still carries our masked placeholder.""" + return bool(value) and (value.startswith("***") or "..." 
in value) + + +async def _validate_homeserver_url(url: str) -> None: + """Reject homeserver URLs that point to blocked networks.""" + try: + await avalidate_outbound_url(url) + except UnsafeURLError as err: + raise HTTPException( + status_code=400, detail=f"Invalid homeserver_url: {err}" + ) from err + + @router.get("") async def list_matrix_bots( user: User = Depends(get_current_user), @@ -50,6 +70,7 @@ async def create_matrix_bot( user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): + await _validate_homeserver_url(body.homeserver_url) bot = MatrixBot(user_id=user.id, **body.model_dump()) session.add(bot) await session.commit() @@ -74,7 +95,19 @@ async def update_matrix_bot( session: AsyncSession = Depends(get_session), ): bot = await _get_user_bot(session, bot_id, user.id) - for field, value in body.model_dump(exclude_unset=True).items(): + updates = body.model_dump(exclude_unset=True) + + # Re-validate homeserver_url whenever the client supplies a new one so + # no private/loopback target can ever be saved, even via update. + if "homeserver_url" in updates and updates["homeserver_url"]: + await _validate_homeserver_url(updates["homeserver_url"]) + + # Never accept the masked placeholder the GET response returns. If the + # client echoes it back, keep the stored secret. + if "access_token" in updates and _is_masked_secret(updates["access_token"]): + updates.pop("access_token") + + for field, value in updates.items(): setattr(bot, field, value) session.add(bot) await session.commit() @@ -108,15 +141,17 @@ async def test_matrix_bot( If room_id is not provided, just verifies the access token by calling /whoami. """ bot = await _get_user_bot(session, bot_id, user.id) + # Defense-in-depth: even though create/update validate the URL, a bot row + # written before this guard was added could still point at a blocked host. 
+ await _validate_homeserver_url(bot.homeserver_url) import aiohttp from ..services.http_session import get_http_session http = await get_http_session() - # Verify token with /whoami whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami" headers = {"Authorization": f"Bearer {bot.access_token}"} try: - async with http.get(whoami_url, headers=headers) as resp: + async with http.get(whoami_url, headers=headers, allow_redirects=False) as resp: if resp.status != 200: body = await resp.text() return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"} @@ -126,7 +161,6 @@ async def test_matrix_bot( result = {"success": True, "user_id": whoami.get("user_id", "")} - # Optionally send a test message if room_id: from ..services.notifier import _get_test_message from notify_bridge_core.notifications.matrix.client import MatrixClient @@ -148,7 +182,7 @@ def _response(bot: MatrixBot) -> dict: "name": bot.name, "icon": bot.icon, "homeserver_url": bot.homeserver_url, - "access_token": f"{bot.access_token[:8]}...{bot.access_token[-4:]}" if len(bot.access_token) > 12 else "***", + "access_token": f"***{bot.access_token[-4:]}" if len(bot.access_token) > 4 else "***", "display_name": bot.display_name, "created_at": bot.created_at.isoformat(), } diff --git a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py index 5967d6f..43a4afe 100644 --- a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py +++ b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py @@ -22,7 +22,7 @@ from ..database.models import ( User, ) from ..services.notifier import send_test_notification -from ..services.test_dispatch import dispatch_test_notification +from ..services.manual_dispatch import dispatch_test_notification from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) diff --git 
a/packages/server/src/notify_bridge_server/api/notification_trackers.py b/packages/server/src/notify_bridge_server/api/notification_trackers.py index bca3181..eed6e07 100644 --- a/packages/server/src/notify_bridge_server/api/notification_trackers.py +++ b/packages/server/src/notify_bridge_server/api/notification_trackers.py @@ -11,6 +11,7 @@ from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import ( EventLog, + NotificationTarget, NotificationTracker, NotificationTrackerState, NotificationTrackerTarget, @@ -54,11 +55,79 @@ async def list_notification_trackers( user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): + # Batched loader: pull trackers, then all their tracker-target links in + # a single query, then the referenced targets in a single query. Avoids + # the old 1 + N + N*M pattern that ran ~60 round-trips for 10 trackers. result = await session.exec( select(NotificationTracker).where(NotificationTracker.user_id == user.id) ) - trackers = result.all() - return [await _tracker_response(session, t) for t in trackers] + trackers = list(result.all()) + if not trackers: + return [] + + tracker_ids = [t.id for t in trackers] + tt_result = await session.exec( + select(NotificationTrackerTarget).where( + NotificationTrackerTarget.tracker_id.in_(tracker_ids) + ) + ) + tt_rows = list(tt_result.all()) + + target_ids = {tt.target_id for tt in tt_rows} + targets_by_id: dict[int, NotificationTarget] = {} + if target_ids: + tgt_result = await session.exec( + select(NotificationTarget).where(NotificationTarget.id.in_(target_ids)) + ) + targets_by_id = {t.id: t for t in tgt_result.all()} + + tts_by_tracker: dict[int, list[NotificationTrackerTarget]] = {} + for tt in tt_rows: + tts_by_tracker.setdefault(tt.tracker_id, []).append(tt) + + return [ + _build_tracker_response(t, tts_by_tracker.get(t.id, []), targets_by_id) + for t in trackers + ] + + +def _build_tracker_response( + 
t: NotificationTracker, + tts: list[NotificationTrackerTarget], + targets_by_id: dict[int, NotificationTarget], +) -> dict: + """In-memory assembler for a tracker + its pre-loaded links/targets.""" + tracker_targets = [] + for tt in tts: + target = targets_by_id.get(tt.target_id) + tracker_targets.append({ + "id": tt.id, + "tracker_id": tt.tracker_id, + "target_id": tt.target_id, + "target_name": target.name if target else None, + "target_type": target.type if target else None, + "target_icon": target.icon if target else None, + "tracking_config_id": tt.tracking_config_id, + "template_config_id": tt.template_config_id, + "enabled": tt.enabled, + "quiet_hours_start": tt.quiet_hours_start, + "quiet_hours_end": tt.quiet_hours_end, + "created_at": tt.created_at.isoformat(), + }) + return { + "id": t.id, + "name": t.name, + "icon": t.icon, + "provider_id": t.provider_id, + "collection_ids": t.collection_ids, + "scan_interval": t.scan_interval, + "batch_duration": t.batch_duration, + "default_tracking_config_id": t.default_tracking_config_id, + "default_template_config_id": t.default_template_config_id, + "enabled": t.enabled, + "tracker_targets": tracker_targets, + "created_at": t.created_at.isoformat(), + } @router.post("", status_code=status.HTTP_201_CREATED) diff --git a/packages/server/src/notify_bridge_server/api/providers.py b/packages/server/src/notify_bridge_server/api/providers.py index 400b8a5..71b06bc 100644 --- a/packages/server/src/notify_bridge_server/api/providers.py +++ b/packages/server/src/notify_bridge_server/api/providers.py @@ -306,16 +306,31 @@ async def update_provider( if body.icon is not None: provider.icon = body.icon - config_changed = body.config is not None and body.config != provider.config if body.config is not None: - _validate_provider_config(provider.type, body.config) - provider.config = body.config + # Merge rather than replace so the masked secrets the frontend + # receives on GET cannot silently nuke the stored values when the + # 
user saves without re-entering them. Any field that still carries + # our mask placeholder ("***…") is dropped from the incoming body. + incoming = dict(body.config) + for secret_field in ( + "api_key", "api_token", "webhook_secret", "password", + "client_secret", "refresh_token", + ): + value = incoming.get(secret_field) + if isinstance(value, str) and value.startswith("***"): + incoming.pop(secret_field, None) + new_config = {**provider.config, **incoming} + _validate_provider_config(provider.type, new_config) + config_changed = new_config != provider.config + provider.config = new_config - # Re-validate connection when config changes for known provider types - if config_changed: - test_result = await _validate_provider_connection(provider) - if test_result.get("external_domain"): - provider.config = {**provider.config, "external_domain": test_result["external_domain"]} + if config_changed: + test_result = await _validate_provider_connection(provider) + if test_result.get("external_domain"): + provider.config = { + **provider.config, + "external_domain": test_result["external_domain"], + } session.add(provider) await session.commit() diff --git a/packages/server/src/notify_bridge_server/api/users.py b/packages/server/src/notify_bridge_server/api/users.py index d9648bc..ec6972e 100644 --- a/packages/server/src/notify_bridge_server/api/users.py +++ b/packages/server/src/notify_bridge_server/api/users.py @@ -1,5 +1,6 @@ """User management API routes (admin only).""" +import asyncio import logging from fastapi import APIRouter, Depends, HTTPException, status @@ -14,6 +15,15 @@ from ..auth.dependencies import require_admin from ..database.engine import get_session from ..database.models import User + +async def _hash_password(password: str) -> str: + """Run bcrypt off the event loop. 
Matches the helper in auth/routes.py.""" + + def _work() -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + return await asyncio.to_thread(_work) + _LOGGER = logging.getLogger(__name__) router = APIRouter(prefix="/api/users", tags=["users"]) @@ -36,8 +46,12 @@ async def list_users( admin: User = Depends(require_admin), session: AsyncSession = Depends(get_session), ): - """List all users (admin only).""" - result = await session.exec(select(User)) + """List all users (admin only). + + Excludes the internal ``__system__`` placeholder (id=0) used as the + owner of default templates/configs — it is never a real account. + """ + result = await session.exec(select(User).where(User.id != 0)) return [ {"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()} for u in result.all() @@ -61,7 +75,7 @@ async def create_user( user = User( username=body.username, - hashed_password=bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode(), + hashed_password=await _hash_password(body.password), role=body.role if body.role in ("admin", "user") else "user", ) session.add(user) @@ -162,7 +176,7 @@ async def reset_user_password( raise HTTPException(status_code=404, detail="User not found") if len(body.new_password) < 8: raise HTTPException(status_code=400, detail="Password must be at least 8 characters") - user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode() + user.hashed_password = await _hash_password(body.new_password) # Invalidate all prior JWTs issued for this user — matches the self-serve # password-change path in auth/routes.py. 
user.token_version = (user.token_version or 1) + 1 diff --git a/packages/server/src/notify_bridge_server/api/webhooks.py b/packages/server/src/notify_bridge_server/api/webhooks.py index 1bdab14..973691f 100644 --- a/packages/server/src/notify_bridge_server/api/webhooks.py +++ b/packages/server/src/notify_bridge_server/api/webhooks.py @@ -37,6 +37,42 @@ _LOGGER = logging.getLogger(__name__) router = APIRouter(prefix="/api/webhooks", tags=["webhooks"]) +# Hard cap on inbound webhook body size (1 MiB is far larger than anything +# legitimate providers send and keeps the worst-case memory footprint bounded +# when a malicious peer lies about Content-Length or streams slowly). +_MAX_WEBHOOK_BODY_BYTES = 1_000_000 + + +async def _read_bounded_body(request: Request, limit: int = _MAX_WEBHOOK_BODY_BYTES) -> bytes: + """Reject oversized inbound bodies before they exhaust memory. + + First checks ``Content-Length`` (fast-path for honest peers), then + streams the body in chunks enforcing the same cap on actual bytes + received so a peer that lies about Content-Length cannot slip through. 
+ """ + declared = request.headers.get("content-length") + if declared: + try: + if int(declared) > limit: + raise HTTPException( + status_code=413, + detail=f"Payload too large (max {limit} bytes)", + ) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid Content-Length") + + chunks: list[bytes] = [] + size = 0 + async for chunk in request.stream(): + size += len(chunk) + if size > limit: + raise HTTPException( + status_code=413, + detail=f"Payload too large (max {limit} bytes)", + ) + chunks.append(chunk) + return b"".join(chunks) + async def _get_provider_by_token( session: AsyncSession, token: str, expected_type: str, @@ -169,7 +205,8 @@ async def _dispatch_webhook_event( )) # Dispatch to targets - dispatcher = NotificationDispatcher() + from ..services.http_session import get_http_session + dispatcher = NotificationDispatcher(session=await get_http_session()) target_configs = _build_target_configs(event, link_data, provider_config, app_tz) if target_configs: results = await dispatcher.dispatch(event, target_configs) @@ -203,7 +240,7 @@ async def gitea_webhook(token: str, request: Request): webhook_secret = (provider.config or {}).get("webhook_secret", "") # Read raw body for HMAC check - raw_body = await request.body() + raw_body = await _read_bounded_body(request) if not webhook_secret: raise HTTPException( @@ -221,8 +258,8 @@ async def gitea_webhook(token: str, request: Request): return {"ok": True, "skipped": "no event header"} try: - payload = await request.json() - except (json.JSONDecodeError, ValueError): + payload = json.loads(raw_body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError, ValueError): raise HTTPException(status_code=400, detail="Invalid JSON") event = parse_gitea_webhook(event_header, payload, provider.name) @@ -280,10 +317,10 @@ async def planka_webhook(token: str, request: Request): if not _verify_planka_token(webhook_secret, request): raise HTTPException(status_code=403, detail="Invalid token") 
- # Parse payload + # Parse payload from the bounded raw_body we already read. try: - payload = await request.json() - except (json.JSONDecodeError, ValueError): + payload = json.loads(raw_body.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError, ValueError): raise HTTPException(status_code=400, detail="Invalid JSON") event_type = payload.get("type", "") @@ -446,23 +483,22 @@ async def generic_webhook(token: str, request: Request): store_payloads = provider_config.get("store_payloads", True) max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100) - raw_body = await request.body() + raw_body = await _read_bounded_body(request) - # Enforce payload size limit BEFORE parsing JSON - if len(raw_body) > 1_000_000: - raise HTTPException(status_code=413, detail="Payload too large (max 1 MB)") + # Bounded read above already enforces the size cap; no need to re-check. if not _verify_generic_webhook_auth(provider_config, request, raw_body): raise HTTPException(status_code=403, detail="Authentication failed") safe_headers = _filter_headers(dict(request.headers)) - # Parse JSON payload + # Parse JSON payload from the already-bounded raw_body (request.body() + # has been consumed, so request.json() is no longer usable here). 
try: - payload = await request.json() + payload = json.loads(raw_body.decode("utf-8")) if not isinstance(payload, dict): raise ValueError("Payload must be a JSON object") - except (json.JSONDecodeError, ValueError): + except (UnicodeDecodeError, json.JSONDecodeError, ValueError): if store_payloads: async with AsyncSession(get_engine()) as log_session: await _save_webhook_log( diff --git a/packages/server/src/notify_bridge_server/auth/jwt.py b/packages/server/src/notify_bridge_server/auth/jwt.py index a57f5a5..f35bdf7 100644 --- a/packages/server/src/notify_bridge_server/auth/jwt.py +++ b/packages/server/src/notify_bridge_server/auth/jwt.py @@ -7,30 +7,51 @@ import jwt from ..config import settings ALGORITHM = "HS256" +_LEEWAY_SECONDS = 10 + + +def _now_utc() -> datetime: + return datetime.now(timezone.utc) def create_access_token(user_id: int, role: str, token_version: int = 1) -> str: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes) + now = _now_utc() + expire = now + timedelta(minutes=settings.access_token_expire_minutes) payload = { + "iss": settings.jwt_issuer, + "aud": settings.jwt_audience, "sub": str(user_id), "role": role, "type": "access", "ver": token_version, + "iat": now, "exp": expire, } return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM) def create_refresh_token(user_id: int, token_version: int = 1) -> str: - expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days) + now = _now_utc() + expire = now + timedelta(days=settings.refresh_token_expire_days) payload = { + "iss": settings.jwt_issuer, + "aud": settings.jwt_audience, "sub": str(user_id), "type": "refresh", "ver": token_version, + "iat": now, "exp": expire, } return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM) def decode_token(token: str) -> dict: - return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM]) + return jwt.decode( + token, + settings.secret_key, + 
algorithms=[ALGORITHM], + audience=settings.jwt_audience, + issuer=settings.jwt_issuer, + leeway=_LEEWAY_SECONDS, + options={"require": ["exp", "sub", "iss", "aud", "type"]}, + ) diff --git a/packages/server/src/notify_bridge_server/auth/routes.py b/packages/server/src/notify_bridge_server/auth/routes.py index ab69f8c..e7e9f3a 100644 --- a/packages/server/src/notify_bridge_server/auth/routes.py +++ b/packages/server/src/notify_bridge_server/auth/routes.py @@ -1,5 +1,7 @@ """Authentication API routes.""" +import asyncio + from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel from slowapi import Limiter @@ -16,7 +18,9 @@ from .jwt import create_access_token, create_refresh_token, decode_token router = APIRouter(prefix="/api/auth", tags=["auth"]) -limiter = Limiter(key_func=get_remote_address) +# Default rate limit applied by SlowAPIMiddleware to every route that does NOT +# specify its own @limiter.limit(...) — protects against blanket abuse. +limiter = Limiter(key_func=get_remote_address, default_limits=["600/minute"]) class SetupRequest(BaseModel): @@ -45,27 +49,52 @@ class RefreshRequest(BaseModel): refresh_token: str -def _hash_password(password: str) -> str: - return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() +async def _hash_password(password: str) -> str: + """bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop.""" + + def _work() -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + return await asyncio.to_thread(_work) -def _verify_password(password: str, hashed: str) -> bool: - return bcrypt.checkpw(password.encode(), hashed.encode()) +async def _verify_password(password: str, hashed: str) -> bool: + def _work() -> bool: + try: + return bcrypt.checkpw(password.encode(), hashed.encode()) + except ValueError: + # Malformed hash in DB — treat as mismatch, never raise to caller. 
+ return False + + return await asyncio.to_thread(_work) @router.post("/setup", response_model=TokenResponse) @limiter.limit("3/minute") async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)): - result = await session.exec(select(func.count()).select_from(User)) - count = result.one() - if count > 0: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.") - if len(body.password) < 8: raise HTTPException(status_code=400, detail="Password must be at least 8 characters") - user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin") - session.add(user) - await session.commit() + # Compute hash BEFORE opening the transaction so we don't hold a writer lock + # during the CPU-bound bcrypt work. + hashed = await _hash_password(body.password) + + # Serialize setup via an INSERT-inside-transaction-with-count-guard. + # SQLite's writer lock plus the count check inside the transaction closes + # the TOCTOU window between two concurrent POSTs. We ignore id=0 — that's + # the internal "__system__" placeholder used for ownership of default + # templates, never a real admin. 
+ async with session.begin(): + result = await session.exec( + select(func.count()).select_from(User).where(User.id != 0) + ) + count = result.one() + if count > 0: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Setup already completed.", + ) + user = User(username=body.username, hashed_password=hashed, role="admin") + session.add(user) await session.refresh(user) return TokenResponse( @@ -79,7 +108,13 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)): result = await session.exec(select(User).where(User.username == body.username)) user = result.first() - if not user or not _verify_password(body.password, user.hashed_password): + # Always run a bcrypt verification to keep the response time constant, + # preventing username-enumeration via timing side channel. + password_ok = await _verify_password( + body.password, + user.hashed_password if user else "$2b$12$" + "a" * 53, + ) + if not user or not password_ok: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password") return TokenResponse( @@ -124,16 +159,18 @@ class PasswordChangeRequest(BaseModel): @router.put("/password") +@limiter.limit("10/minute") async def change_password( + request: Request, body: PasswordChangeRequest, user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): - if not _verify_password(body.current_password, user.hashed_password): + if not await _verify_password(body.current_password, user.hashed_password): raise HTTPException(status_code=400, detail="Current password is incorrect") if len(body.new_password) < 8: raise HTTPException(status_code=400, detail="New password must be at least 8 characters") - user.hashed_password = _hash_password(body.new_password) + user.hashed_password = await _hash_password(body.new_password) user.token_version = (user.token_version 
or 1) + 1 session.add(user) await session.commit() @@ -141,7 +178,12 @@ async def change_password( @router.get("/needs-setup") -async def needs_setup(session: AsyncSession = Depends(get_session)): - result = await session.exec(select(func.count()).select_from(User)) +@limiter.limit("30/minute") +async def needs_setup(request: Request, session: AsyncSession = Depends(get_session)): + # Exclude the internal __system__ placeholder (id=0) from the count so + # a fresh install still reports needs_setup=True. + result = await session.exec( + select(func.count()).select_from(User).where(User.id != 0) + ) count = result.one() return {"needs_setup": count == 0} diff --git a/packages/server/src/notify_bridge_server/config.py b/packages/server/src/notify_bridge_server/config.py index e27c94e..128aebe 100644 --- a/packages/server/src/notify_bridge_server/config.py +++ b/packages/server/src/notify_bridge_server/config.py @@ -2,8 +2,20 @@ from pathlib import Path from typing import Any +from urllib.parse import urlparse + from pydantic_settings import BaseSettings +# Secret keys we will actively refuse. These cover the default template value +# and dev-only literals that have appeared in scripts or documentation. +_FORBIDDEN_SECRETS: frozenset[str] = frozenset( + { + "change-me-in-production", + "test-secret-key-minimum-32-chars", + "dev-secret-key-not-for-production", + } +) + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -13,29 +25,25 @@ class Settings(BaseSettings): secret_key: str = "change-me-in-production" - def model_post_init(self, __context: Any) -> None: - if self.secret_key == "change-me-in-production": - raise ValueError( - "SECURITY: Refusing to start with the default secret_key. " - "Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) " - "before starting the server (debug mode included)." 
- ) - if len(self.secret_key) < 32: - raise ValueError( - "SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters." - ) - if "*" in self.cors_allowed_origins.split(","): - raise ValueError( - "SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled." - ) - - access_token_expire_minutes: int = 60 + access_token_expire_minutes: int = 15 refresh_token_expire_days: int = 30 + jwt_issuer: str = "notify-bridge" + jwt_audience: str = "notify-bridge-api" + host: str = "0.0.0.0" port: int = 8420 debug: bool = False + # Comma-separated list of trusted proxy IPs uvicorn will honor for + # X-Forwarded-For / X-Forwarded-Proto. Use "*" ONLY when you trust the + # network (never directly on the internet). Default matches uvicorn. + forwarded_allow_ips: str = "127.0.0.1" + + # How long to wait for in-flight requests / scheduler jobs before force + # killing on SIGTERM. + graceful_shutdown_seconds: int = 60 + anthropic_api_key: str = "" ai_model: str = "claude-sonnet-4-20250514" ai_max_tokens: int = 1024 @@ -49,9 +57,6 @@ class Settings(BaseSettings): """Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker).""" # --- Logging --- - # Boot-time logging configuration. DB AppSetting rows (``log_level`` / - # ``log_levels`` / ``log_format``) override these after startup, letting - # operators adjust levels from the settings UI without a restart. log_level: str = "INFO" """Root log level for the app loggers (``DEBUG``/``INFO``/``WARNING``/``ERROR``).""" @@ -61,8 +66,43 @@ class Settings(BaseSettings): log_levels: str = "" """Comma-separated per-module overrides, e.g. ``notify_bridge_core.notifications.telegram.client=DEBUG,sqlalchemy.engine=INFO``.""" + # --- Retention --- + event_log_retention_days: int = 30 + """Days of event_log history to retain. 
0 disables the retention job.""" + model_config = {"env_prefix": "NOTIFY_BRIDGE_"} + def model_post_init(self, __context: Any) -> None: + if self.secret_key in _FORBIDDEN_SECRETS: + raise ValueError( + "SECURITY: Refusing to start with a known/default secret_key. " + "Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) " + "before starting the server." + ) + if len(self.secret_key) < 32: + raise ValueError( + "SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters." + ) + origins = [o.strip() for o in self.cors_allowed_origins.split(",") if o.strip()] + if "*" in origins: + raise ValueError( + "SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled." + ) + for origin in origins: + parsed = urlparse(origin) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + raise ValueError( + f"CORS origin {origin!r} is invalid — must include scheme (http/https) and host." + ) + if self.access_token_expire_minutes <= 0: + raise ValueError("access_token_expire_minutes must be > 0") + if self.refresh_token_expire_days <= 0: + raise ValueError("refresh_token_expire_days must be > 0") + if not (1 <= self.port <= 65535): + raise ValueError("port must be in range 1..65535") + if self.event_log_retention_days < 0: + raise ValueError("event_log_retention_days must be >= 0") + @property def effective_database_url(self) -> str: if self.database_url: diff --git a/packages/server/src/notify_bridge_server/database/engine.py b/packages/server/src/notify_bridge_server/database/engine.py index 2712f99..d6ba702 100644 --- a/packages/server/src/notify_bridge_server/database/engine.py +++ b/packages/server/src/notify_bridge_server/database/engine.py @@ -1,23 +1,59 @@ """Database engine and session management.""" from collections.abc import AsyncGenerator +import logging +from sqlalchemy import event from sqlmodel import SQLModel from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio import AsyncEngine, 
create_async_engine from ..config import settings +_LOGGER = logging.getLogger(__name__) + _engine: AsyncEngine | None = None +def _install_sqlite_pragmas(engine: AsyncEngine) -> None: + """Apply production-grade SQLite PRAGMAs on every new connection. + + WAL mode lets readers and writers work concurrently without blocking; + ``busy_timeout`` gives contending writers a chance instead of instant + SQLITE_BUSY; ``foreign_keys`` enforces the FK constraints declared in the + models (SQLite disables them by default); ``synchronous=NORMAL`` is a + safe-by-default durability trade-off that is standard in WAL mode. + """ + + @event.listens_for(engine.sync_engine, "connect") + def _pragmas(dbapi_conn, _record): # pragma: no cover — driver hook + cur = dbapi_conn.cursor() + try: + cur.execute("PRAGMA journal_mode=WAL") + cur.execute("PRAGMA synchronous=NORMAL") + cur.execute("PRAGMA foreign_keys=ON") + cur.execute("PRAGMA busy_timeout=10000") + cur.execute("PRAGMA temp_store=MEMORY") + finally: + cur.close() + + def get_engine() -> AsyncEngine: global _engine if _engine is None: + url = settings.effective_database_url + connect_args: dict = {} + if url.startswith("sqlite"): + connect_args["timeout"] = 30 _engine = create_async_engine( - settings.effective_database_url, + url, echo=settings.debug, + pool_pre_ping=True, + connect_args=connect_args, ) + if url.startswith("sqlite"): + _install_sqlite_pragmas(_engine) + _LOGGER.info("Database engine initialized: %s", url.split("://", 1)[0]) return _engine @@ -31,3 +67,11 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]: engine = get_engine() async with AsyncSession(engine) as session: yield session + + +async def dispose_engine() -> None: + """Close the engine's connection pool. 
Call during graceful shutdown.""" + global _engine + if _engine is not None: + await _engine.dispose() + _engine = None diff --git a/packages/server/src/notify_bridge_server/database/migrations.py b/packages/server/src/notify_bridge_server/database/migrations.py index 2a58b36..84dbc50 100644 --- a/packages/server/src/notify_bridge_server/database/migrations.py +++ b/packages/server/src/notify_bridge_server/database/migrations.py @@ -1282,3 +1282,141 @@ async def migrate_user_token_version(engine: AsyncEngine) -> None: text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1") ) logger.info("Added token_version column to user table") + + +# --------------------------------------------------------------------------- +# Performance indexes — covers every FK / owner column the list endpoints +# and the webhook hot-path filter on. All use CREATE INDEX IF NOT EXISTS so +# they are safe to re-run on every boot. +# --------------------------------------------------------------------------- + +_INDEXES: list[tuple[str, str, str]] = [ + # (index_name, table, columns) + ("ix_service_provider_user_id", "service_provider", "user_id"), + ("ix_telegram_bot_user_id", "telegram_bot", "user_id"), + ("ix_matrix_bot_user_id", "matrix_bot", "user_id"), + ("ix_email_bot_user_id", "email_bot", "user_id"), + ("ix_telegram_chat_bot_id", "telegram_chat", "bot_id"), + ("ix_tracking_config_user_id", "tracking_config", "user_id"), + ("ix_tracking_config_provider_type", "tracking_config", "provider_type"), + ("ix_notification_target_user_id", "notification_target", "user_id"), + ("ix_notification_target_type", "notification_target", "type"), + ("ix_notification_tracker_user_id", "notification_tracker", "user_id"), + ("ix_notification_tracker_provider_id", "notification_tracker", "provider_id"), + # Composite for the webhook hot path: WHERE provider_id = ? 
AND enabled = true + ( + "ix_notification_tracker_provider_enabled", + "notification_tracker", + "provider_id, enabled", + ), + ("ix_command_config_user_id", "command_config", "user_id"), + ("ix_command_template_config_user_id", "command_template_config", "user_id"), + ("ix_command_tracker_user_id", "command_tracker", "user_id"), + ("ix_command_tracker_provider_id", "command_tracker", "provider_id"), + ("ix_action_user_id", "action", "user_id"), + ("ix_action_provider_id", "action", "provider_id"), + # Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC + ("ix_event_log_user_created", "event_log", "user_id, created_at DESC"), + ("ix_event_log_provider_id", "event_log", "provider_id"), + ("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"), + ("ix_event_log_action_id", "event_log", "action_id"), + # Webhook log hot path: WHERE provider_id = ? ORDER BY created_at DESC + ( + "ix_webhook_payload_log_provider_created", + "webhook_payload_log", + "provider_id, created_at DESC", + ), + # Notification tracker join tables + ( + "ix_notification_tracker_target_notification_tracker_id", + "notification_tracker_target", + "notification_tracker_id", + ), + ( + "ix_notification_tracker_target_target_id", + "notification_tracker_target", + "target_id", + ), + ("ix_target_receiver_target_id", "target_receiver", "target_id"), + ("ix_template_slot_config_id", "template_slot", "config_id"), + ("ix_command_template_slot_config_id", "command_template_slot", "config_id"), + ("ix_action_rule_action_id", "action_rule", "action_id"), + ("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"), +] + + +async def migrate_performance_indexes(engine: AsyncEngine) -> None: + """Create missing performance indexes on hot query paths. + + Every index is created with IF NOT EXISTS so the migration is safe to + replay on every boot. 
We only create the index when the table exists — + early boots before other migrations land would otherwise raise. + """ + async with engine.begin() as conn: + for name, table, columns in _INDEXES: + _assert_ident(name, "index") + _assert_ident(table, "table") + # Columns list is a trusted literal constructed above — never user input. + if not await _has_table(conn, table): + continue + try: + await conn.execute( + text(f"CREATE INDEX IF NOT EXISTS {name} ON {table} ({columns})") + ) + except Exception: # pragma: no cover — log and continue + logger.warning( + "Failed to create index %s on %s(%s)", + name, table, columns, exc_info=True, + ) + + +# --------------------------------------------------------------------------- +# Schema version tracking — lightweight alternative to Alembic while the +# hand-rolled idempotent migrations remain the source of truth. Gives +# operators a single-row answer to "what schema is this DB at" and lets +# future upgrades short-circuit migrations that already ran. 
+# --------------------------------------------------------------------------- + +CURRENT_SCHEMA_VERSION = 1 + + +async def migrate_schema_version(engine: AsyncEngine) -> None: + """Create schema_version table and bump it to CURRENT_SCHEMA_VERSION.""" + async with engine.begin() as conn: + await conn.execute( + text( + "CREATE TABLE IF NOT EXISTS schema_version (" + " id INTEGER PRIMARY KEY CHECK (id = 1)," + " version INTEGER NOT NULL," + " applied_at TEXT NOT NULL" + ")" + ) + ) + row = await conn.run_sync( + lambda sc: sc.execute( + text("SELECT version FROM schema_version WHERE id = 1") + ).fetchone() + ) + from datetime import datetime, timezone + now = datetime.now(timezone.utc).isoformat() + if row is None: + await conn.execute( + text( + "INSERT INTO schema_version (id, version, applied_at) " + "VALUES (1, :v, :t)" + ), + {"v": CURRENT_SCHEMA_VERSION, "t": now}, + ) + logger.info("Initialized schema_version at %d", CURRENT_SCHEMA_VERSION) + elif int(row[0]) < CURRENT_SCHEMA_VERSION: + await conn.execute( + text( + "UPDATE schema_version SET version = :v, applied_at = :t " + "WHERE id = 1" + ), + {"v": CURRENT_SCHEMA_VERSION, "t": now}, + ) + logger.info( + "Bumped schema_version from %s to %d", + row[0], CURRENT_SCHEMA_VERSION, + ) diff --git a/packages/server/src/notify_bridge_server/database/seeds.py b/packages/server/src/notify_bridge_server/database/seeds.py index c81d67a..683293a 100644 --- a/packages/server/src/notify_bridge_server/database/seeds.py +++ b/packages/server/src/notify_bridge_server/database/seeds.py @@ -394,8 +394,37 @@ async def _seed_default_command_configs() -> None: # Public entry point # --------------------------------------------------------------------------- +async def _ensure_system_user() -> None: + """Ensure a User row with id=0 exists. + + Historically the app used ``user_id=0`` as a sentinel for "system-owned" + defaults (tracking configs, templates, etc.). 
Now that we enable + ``PRAGMA foreign_keys=ON`` at connect time, those inserts would fail + with ``FOREIGN KEY constraint failed`` unless a placeholder user row + with the matching id exists. + """ + engine = get_engine() + async with engine.begin() as conn: + # INSERT OR IGNORE so re-running seeds is cheap and idempotent. + await conn.execute( + text( + "INSERT OR IGNORE INTO user " + "(id, username, hashed_password, role, token_version, created_at) " + "VALUES (0, :u, :p, :r, 1, :t)" + ), + { + "u": "__system__", + # Invalid bcrypt hash — nobody can ever log in as this user. + "p": "!disabled!", + "r": "system", + "t": datetime.now(timezone.utc).isoformat(), + }, + ) + + async def seed_all() -> None: """Run all seed functions in order.""" + await _ensure_system_user() await _seed_default_templates() await _seed_default_command_templates() await _seed_default_tracking_configs() diff --git a/packages/server/src/notify_bridge_server/main.py b/packages/server/src/notify_bridge_server/main.py index d73b9ea..4726b4a 100644 --- a/packages/server/src/notify_bridge_server/main.py +++ b/packages/server/src/notify_bridge_server/main.py @@ -52,12 +52,31 @@ from .api.webhook_logs import router as webhook_logs_router from .api.backup import router as backup_router +# Readiness flag — flipped to True once the scheduler has started and the +# app is fully initialized. Exposed via /api/ready for orchestrators. 
+_READY: bool = False + + @asynccontextmanager async def lifespan(app: FastAPI): + global _READY await init_db() # Run data migrations (idempotent) from .database.engine import get_engine - from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version + from .database.migrations import ( + migrate_schema, + migrate_tracker_targets, + migrate_entity_refactor, + migrate_template_slots, + migrate_target_receivers, + migrate_template_locale, + migrate_receivers_from_config, + migrate_command_slot_locale, + migrate_notification_slot_locale, + migrate_user_token_version, + migrate_performance_indexes, + migrate_schema_version, + ) engine = get_engine() await migrate_schema(engine) await migrate_tracker_targets(engine) @@ -69,6 +88,8 @@ async def lifespan(app: FastAPI): await migrate_command_slot_locale(engine) await migrate_notification_slot_locale(engine) await migrate_user_token_version(engine) + await migrate_performance_indexes(engine) + await migrate_schema_version(engine) from .database.seeds import seed_all await seed_all() # Apply DB-backed logging settings (override env-based boot config). @@ -100,16 +121,28 @@ async def lifespan(app: FastAPI): set_webhook_secret(_secret or None) from .services.scheduler import start_scheduler, get_scheduler await start_scheduler() + _READY = True yield - # Graceful shutdown - from .services.http_session import close_http_session - await close_http_session() + # Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish + # before we close their HTTP session. Then close the shared session and + # dispose the DB engine. 
+ _READY = False scheduler = get_scheduler() if scheduler.running: - scheduler.shutdown() + scheduler.shutdown(wait=True) + from .services.http_session import close_http_session + await close_http_session() + from .database.engine import dispose_engine + await dispose_engine() -app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan) +try: + from importlib.metadata import version as _pkg_version + _APP_VERSION = _pkg_version("notify-bridge-server") +except Exception: # pragma: no cover — editable install edge cases + _APP_VERSION = "0.0.0+unknown" + +app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan) # --- Security headers --- from starlette.middleware.base import BaseHTTPMiddleware @@ -117,6 +150,19 @@ from starlette.requests import Request as StarletteRequest from starlette.responses import Response as StarletteResponse +_CSP = ( + "default-src 'self'; " + "img-src 'self' data: blob: https:; " + "style-src 'self' 'unsafe-inline'; " + "script-src 'self'; " + "connect-src 'self'; " + "font-src 'self' data:; " + "base-uri 'self'; " + "form-action 'self'; " + "frame-ancestors 'none'" +) + + class SecurityHeadersMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: StarletteRequest, call_next): response: StarletteResponse = await call_next(request) @@ -124,6 +170,14 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware): response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers.setdefault("Content-Security-Policy", _CSP) + # HSTS only makes sense over HTTPS; set when the edge terminates TLS + # and forwards X-Forwarded-Proto=https. 
+ if request.headers.get("x-forwarded-proto") == "https": + response.headers.setdefault( + "Strict-Transport-Security", + "max-age=31536000; includeSubDomains", + ) return response @@ -176,7 +230,22 @@ app.include_router(backup_router) @app.get("/api/health") async def health(): - return {"status": "ok"} + """Liveness: process is up and responding. Always returns 200 once the + ASGI app has started. Keep this endpoint anonymous and trivially cheap.""" + return {"status": "ok", "version": _APP_VERSION} + + +@app.get("/api/ready") +async def ready(): + """Readiness: migrations and scheduler have started, app can serve traffic. + + Returns 503 until the lifespan startup sequence has completed. Use this + for orchestrator readiness probes (Docker, Kubernetes). + """ + if not _READY: + from starlette.responses import JSONResponse + return JSONResponse({"status": "starting"}, status_code=503) + return {"status": "ready", "version": _APP_VERSION} # --- Serve frontend static files (production) --- @@ -209,4 +278,12 @@ if _cfg.static_dir and Path(_cfg.static_dir).is_dir(): def run(): import uvicorn - uvicorn.run(app, host=_cfg.host, port=_cfg.port) + uvicorn.run( + app, + host=_cfg.host, + port=_cfg.port, + proxy_headers=True, + forwarded_allow_ips=_cfg.forwarded_allow_ips or "127.0.0.1", + timeout_graceful_shutdown=_cfg.graceful_shutdown_seconds, + access_log=not _cfg.debug, + ) diff --git a/packages/server/src/notify_bridge_server/services/backup_service.py b/packages/server/src/notify_bridge_server/services/backup_service.py index 25d7e8e..0292b57 100644 --- a/packages/server/src/notify_bridge_server/services/backup_service.py +++ b/packages/server/src/notify_bridge_server/services/backup_service.py @@ -387,10 +387,9 @@ async def export_backup_to_file( ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S") filename = f"backup-{ts}.json" filepath = backup_dir / filename - filepath.write_text( - json.dumps(backup.model_dump(), indent=2, ensure_ascii=False), - 
encoding="utf-8", - ) + import asyncio as _asyncio + payload = json.dumps(backup.model_dump(), indent=2, ensure_ascii=False) + await _asyncio.to_thread(filepath.write_text, payload, encoding="utf-8") _LOGGER.info("Scheduled backup saved: %s", filepath) return filepath @@ -399,7 +398,13 @@ def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]: """Delete oldest backup files exceeding `keep` count. Returns deleted filenames.""" if not backup_dir.is_dir(): return [] - files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True) + # Sort by mtime (newest first) so behavior doesn't depend on the filename + # timestamp format, which could change later without updating this code. + files = sorted( + backup_dir.glob("backup-*.json"), + key=lambda f: f.stat().st_mtime, + reverse=True, + ) deleted = [] for old in files[keep:]: old.unlink() @@ -413,7 +418,13 @@ def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]: """List backup files in the directory with metadata.""" if not backup_dir.is_dir(): return [] - files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True) + # Sort by mtime (newest first) so behavior doesn't depend on the filename + # timestamp format, which could change later without updating this code. + files = sorted( + backup_dir.glob("backup-*.json"), + key=lambda f: f.stat().st_mtime, + reverse=True, + ) result = [] for f in files: stat = f.stat() diff --git a/packages/server/src/notify_bridge_server/services/http_session.py b/packages/server/src/notify_bridge_server/services/http_session.py index bf0021d..873f317 100644 --- a/packages/server/src/notify_bridge_server/services/http_session.py +++ b/packages/server/src/notify_bridge_server/services/http_session.py @@ -11,23 +11,36 @@ Call ``close_http_session()`` once during application shutdown. 
from __future__ import annotations +import asyncio + import aiohttp -_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30) +_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10) + _session: aiohttp.ClientSession | None = None +_lock = asyncio.Lock() async def get_http_session() -> aiohttp.ClientSession: - """Get or create the shared HTTP session.""" + """Get or create the shared HTTP session. + + Concurrent first-callers are serialized through ``_lock`` so we never + leak a second ClientSession / connector pair. Once established, hot + callers skip the lock via the fast-path check. + """ global _session - if _session is None or _session.closed: - _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT) + if _session is not None and not _session.closed: + return _session + async with _lock: + if _session is None or _session.closed: + _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT) return _session async def close_http_session() -> None: """Close the shared HTTP session (call on app shutdown).""" global _session - if _session is not None and not _session.closed: - await _session.close() + async with _lock: + if _session is not None and not _session.closed: + await _session.close() _session = None diff --git a/packages/server/src/notify_bridge_server/services/test_dispatch.py b/packages/server/src/notify_bridge_server/services/manual_dispatch.py similarity index 100% rename from packages/server/src/notify_bridge_server/services/test_dispatch.py rename to packages/server/src/notify_bridge_server/services/manual_dispatch.py diff --git a/packages/server/src/notify_bridge_server/services/scheduler.py b/packages/server/src/notify_bridge_server/services/scheduler.py index 3336f7b..9402bf4 100644 --- a/packages/server/src/notify_bridge_server/services/scheduler.py +++ b/packages/server/src/notify_bridge_server/services/scheduler.py @@ -85,7 +85,21 @@ def _compute_jitter(interval_seconds: int) -> int: def get_scheduler() -> AsyncIOScheduler: global 
_scheduler if _scheduler is None: - _scheduler = AsyncIOScheduler() + # Sensible production defaults applied to every job unless overridden: + # * coalesce — collapse a queue of missed runs into one firing after + # a restart / pause, instead of bursting to catch up. + # * misfire_grace_time — accept firings up to 5 min late without + # dropping them silently. + # * max_instances=1 — never run two copies of the same tracker tick + # concurrently; the scheduler already enforces this on add_job, + # but we also set it as the default for safety. + _scheduler = AsyncIOScheduler( + job_defaults={ + "coalesce": True, + "misfire_grace_time": 300, + "max_instances": 1, + }, + ) return _scheduler @@ -279,21 +293,38 @@ async def _refresh_telegram_chat_titles() -> None: async def _cleanup_old_events() -> None: - """Delete EventLog entries older than 90 days.""" + """Delete EventLog / WebhookPayloadLog / ActionExecution rows older than the + configured retention window. A retention of 0 disables the job. 
+ """ from datetime import datetime, timedelta, timezone from sqlmodel import delete from sqlmodel.ext.asyncio.session import AsyncSession + from ..config import settings from ..database.engine import get_engine - from ..database.models import EventLog + from ..database.models import ActionExecution, EventLog, WebhookPayloadLog - cutoff = datetime.now(timezone.utc) - timedelta(days=90) + days = settings.event_log_retention_days + if days <= 0: + _LOGGER.debug("Event log retention disabled (days=0); skipping cleanup") + return + + cutoff = datetime.now(timezone.utc) - timedelta(days=days) engine = get_engine() async with AsyncSession(engine) as session: await session.exec(delete(EventLog).where(EventLog.created_at < cutoff)) + await session.exec( + delete(WebhookPayloadLog).where(WebhookPayloadLog.created_at < cutoff) + ) + await session.exec( + delete(ActionExecution).where(ActionExecution.started_at < cutoff) + ) await session.commit() - _LOGGER.info("Cleaned up event log entries older than %s", cutoff.date()) + _LOGGER.info( + "Cleaned event_log / webhook_payload_log / action_execution older than %s", + cutoff.date(), + ) async def _load_tracker_jobs() -> None: diff --git a/packages/server/src/notify_bridge_server/services/telegram_poller.py b/packages/server/src/notify_bridge_server/services/telegram_poller.py index 7e7e6e7..261773a 100644 --- a/packages/server/src/notify_bridge_server/services/telegram_poller.py +++ b/packages/server/src/notify_bridge_server/services/telegram_poller.py @@ -127,7 +127,14 @@ async def stop_bot_if_unused(bot_id: int) -> None: def schedule_bot_polling(bot_id: int) -> None: - """Add a polling job for a bot (idempotent).""" + """Add a polling job for a bot (idempotent). 
+ + We schedule at a 30 s interval, but each tick calls ``getUpdates`` with + ``timeout=25`` — Telegram holds the connection open until either an + update arrives or the timeout elapses, so in practice the bot streams + updates with sub-second latency while consuming ~2 API calls / minute + per bot (down from 20 under the old 3 s short-poll). + """ scheduler = get_scheduler() job_id = f"telegram_poll_{bot_id}" if scheduler.get_job(job_id): @@ -135,13 +142,13 @@ def schedule_bot_polling(bot_id: int) -> None: scheduler.add_job( _poll_bot, "interval", - seconds=3, + seconds=30, id=job_id, args=[bot_id], replace_existing=True, max_instances=1, ) - _LOGGER.info("Started polling for bot %d", bot_id) + _LOGGER.info("Started polling for bot %d (long-poll, 25s timeout)", bot_id) def unschedule_bot_polling(bot_id: int) -> None: @@ -233,8 +240,10 @@ async def _poll_bot(bot_id: int) -> None: from .http_session import get_http_session http = await get_http_session() client = TelegramClient(http, bot_token) + # Long-poll: hold connection open until an update arrives or 25 s + # elapse. Drastically cuts API calls vs. 3 s short-poll. 
result = await client.get_updates( - offset=offset + 1 if offset else None, limit=50, + offset=offset + 1 if offset else None, limit=50, timeout=25, ) if not result.get("success"): err_text = str(result.get("error") or "") diff --git a/packages/server/src/notify_bridge_server/services/watcher.py b/packages/server/src/notify_bridge_server/services/watcher.py index a3f3c33..f00dcc7 100644 --- a/packages/server/src/notify_bridge_server/services/watcher.py +++ b/packages/server/src/notify_bridge_server/services/watcher.py @@ -369,7 +369,13 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]: if events and link_data: url_cache, asset_cache = await _get_telegram_caches() - dispatcher = NotificationDispatcher(url_cache=url_cache, asset_cache=asset_cache) + from .http_session import get_http_session + shared_session = await get_http_session() + dispatcher = NotificationDispatcher( + url_cache=url_cache, + asset_cache=asset_cache, + session=shared_session, + ) for event in events: _LOGGER.info( "Dispatching event %s for %s (added=%d removed=%d)", diff --git a/packages/server/tests/__init__.py b/packages/server/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/server/tests/conftest.py b/packages/server/tests/conftest.py new file mode 100644 index 0000000..56b1ace --- /dev/null +++ b/packages/server/tests/conftest.py @@ -0,0 +1,33 @@ +"""Shared pytest fixtures. + +We set the required env vars before any ``notify_bridge_server`` module is +imported so ``Settings()`` passes its startup validation and opens the DB +in a writable temp directory instead of the production ``/data`` default. +""" + +from __future__ import annotations + +import os +import tempfile +from pathlib import Path + +import pytest + +# Provision a writable temp data dir BEFORE the server package is imported — +# Settings() materializes at import time, so env-var overrides have to land +# here (conftest.py) to be effective. 
+_TMP = Path(tempfile.mkdtemp(prefix="notify-bridge-tests-")) +os.environ["NOTIFY_BRIDGE_DATA_DIR"] = str(_TMP) + +os.environ.setdefault( + "NOTIFY_BRIDGE_SECRET_KEY", + "pytest-secret-key-" + "x" * 40, +) +os.environ.setdefault("NOTIFY_BRIDGE_DEBUG", "false") +os.environ.setdefault("NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS", "http://localhost:8420") + + +@pytest.fixture(scope="session") +def tmp_data_dir() -> Path: + """Expose the already-provisioned temp dir to tests that want the path.""" + return _TMP diff --git a/packages/server/tests/test_config.py b/packages/server/tests/test_config.py new file mode 100644 index 0000000..0e8cac5 --- /dev/null +++ b/packages/server/tests/test_config.py @@ -0,0 +1,64 @@ +"""Unit tests for the Settings validator.""" + +from __future__ import annotations + +import pytest + +from notify_bridge_server.config import Settings, _FORBIDDEN_SECRETS + + +_GOOD = "a" * 40 + + +def _make(**kwargs) -> Settings: + defaults = dict(secret_key=_GOOD, cors_allowed_origins="http://localhost:8420") + defaults.update(kwargs) + return Settings(**defaults) + + +class TestSecretKey: + def test_accepts_long_random_key(self) -> None: + _make() + + def test_rejects_default(self) -> None: + with pytest.raises(ValueError, match="SECURITY"): + _make(secret_key="change-me-in-production") + + def test_rejects_known_dev_keys(self) -> None: + for bad in _FORBIDDEN_SECRETS: + with pytest.raises(ValueError): + _make(secret_key=bad) + + def test_rejects_short_key(self) -> None: + with pytest.raises(ValueError, match="32 characters"): + _make(secret_key="short") + + +class TestCors: + def test_rejects_wildcard(self) -> None: + with pytest.raises(ValueError, match="wildcard"): + _make(cors_allowed_origins="*") + + def test_rejects_missing_scheme(self) -> None: + with pytest.raises(ValueError, match="scheme"): + _make(cors_allowed_origins="example.com") + + def test_accepts_multiple(self) -> None: + cfg = 
_make(cors_allowed_origins="http://localhost:8420,https://example.com") + assert "http://localhost:8420" in cfg.cors_allowed_origins + + +class TestNumericValidation: + def test_rejects_zero_access_token_expiry(self) -> None: + with pytest.raises(ValueError): + _make(access_token_expire_minutes=0) + + def test_rejects_invalid_port(self) -> None: + with pytest.raises(ValueError): + _make(port=0) + with pytest.raises(ValueError): + _make(port=70000) + + def test_rejects_negative_retention(self) -> None: + with pytest.raises(ValueError): + _make(event_log_retention_days=-1) diff --git a/packages/server/tests/test_discord_retry.py b/packages/server/tests/test_discord_retry.py new file mode 100644 index 0000000..263f138 --- /dev/null +++ b/packages/server/tests/test_discord_retry.py @@ -0,0 +1,46 @@ +"""Discord client 429-retry bounding.""" + +from __future__ import annotations + +import pytest +from aioresponses import aioresponses + +import aiohttp + +from notify_bridge_core.notifications.discord.client import DiscordClient + + +WEBHOOK = "https://discord.com/api/webhooks/123/abc" + + +@pytest.mark.asyncio +async def test_bounded_retries_on_persistent_429() -> None: + """If every response is 429, the client gives up after _MAX_RETRIES.""" + with aioresponses() as mocked: + mocked.post(WEBHOOK, status=429, headers={"Retry-After": "0.001"}, repeat=True) + + async with aiohttp.ClientSession() as sess: + client = DiscordClient(sess) + result = await client.send(WEBHOOK, "hello") + + assert result["success"] is False + # Either the custom "Rate limited" message or the bare HTTP 429 from the + # final attempt — both indicate bounded retries without infinite recursion. + assert "429" in result["error"] or "Rate limited" in result["error"] + + +@pytest.mark.asyncio +async def test_caps_retry_after() -> None: + """A malicious Retry-After: 99999 must not pin the task for hours.""" + with aioresponses() as mocked: + # First call: absurd Retry-After. Second call: success. 
+ mocked.post(WEBHOOK, status=429, headers={"Retry-After": "99999"}) + mocked.post(WEBHOOK, status=204) + + async with aiohttp.ClientSession() as sess: + client = DiscordClient(sess) + # Override the cap to something trivial so the test completes fast. + client._MAX_RETRY_AFTER = 0.001 # type: ignore[attr-defined] + result = await client.send(WEBHOOK, "hello") + + assert result["success"] is True diff --git a/packages/server/tests/test_health.py b/packages/server/tests/test_health.py new file mode 100644 index 0000000..0f05d21 --- /dev/null +++ b/packages/server/tests/test_health.py @@ -0,0 +1,39 @@ +"""Smoke test: app imports, /api/health returns 200, version string present.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + + +def test_health_endpoint(tmp_data_dir) -> None: # noqa: ARG001 — fixture applies env + from notify_bridge_server.main import app + + # TestClient runs the lifespan on enter/exit, so migrations run once + # against the temp data dir — a genuine integration smoke check. + with TestClient(app) as client: + resp = client.get("/api/health") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["version"] != "0.0.0+unknown" + assert "." in body["version"] # looks like a real version + + +def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001 + from notify_bridge_server.main import app + + with TestClient(app) as client: + resp = client.get("/api/ready") + # By the time TestClient yields, lifespan startup has completed. 
+ assert resp.status_code == 200 + assert resp.json()["status"] == "ready" + + +def test_health_is_anonymous(tmp_data_dir) -> None: # noqa: ARG001 + """/api/health must not require auth — the Docker healthcheck depends on it.""" + from notify_bridge_server.main import app + + with TestClient(app) as client: + resp = client.get("/api/health") + assert resp.status_code == 200 diff --git a/packages/server/tests/test_jwt.py b/packages/server/tests/test_jwt.py new file mode 100644 index 0000000..a0320aa --- /dev/null +++ b/packages/server/tests/test_jwt.py @@ -0,0 +1,74 @@ +"""JWT encode/decode round-trips.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +import jwt as pyjwt +import pytest + +from notify_bridge_server.auth.jwt import ( + ALGORITHM, + create_access_token, + create_refresh_token, + decode_token, +) +from notify_bridge_server.config import settings + + +def test_access_token_round_trip() -> None: + token = create_access_token(user_id=1, role="admin", token_version=3) + payload = decode_token(token) + assert payload["sub"] == "1" + assert payload["type"] == "access" + assert payload["role"] == "admin" + assert payload["ver"] == 3 + assert payload["iss"] == settings.jwt_issuer + assert payload["aud"] == settings.jwt_audience + + +def test_refresh_token_round_trip() -> None: + token = create_refresh_token(user_id=7, token_version=2) + payload = decode_token(token) + assert payload["type"] == "refresh" + assert payload["sub"] == "7" + + +def test_decode_rejects_wrong_audience() -> None: + """A token signed with our key but for a different audience is rejected.""" + now = datetime.now(timezone.utc) + forged = pyjwt.encode( + { + "iss": settings.jwt_issuer, + "aud": "other-service", + "sub": "1", + "type": "access", + "ver": 1, + "iat": now, + "exp": now + timedelta(minutes=5), + }, + settings.secret_key, + algorithm=ALGORITHM, + ) + with pytest.raises(pyjwt.InvalidAudienceError): + decode_token(forged) + + +def 
test_decode_rejects_none_alg() -> None: + """An ``alg: none`` token must never be accepted.""" + now = datetime.now(timezone.utc) + forged = pyjwt.encode( + { + "iss": settings.jwt_issuer, + "aud": settings.jwt_audience, + "sub": "1", + "type": "access", + "ver": 1, + "iat": now, + "exp": now + timedelta(minutes=5), + }, + "", + algorithm="none", + ) + with pytest.raises(pyjwt.InvalidAlgorithmError): + decode_token(forged) diff --git a/packages/server/tests/test_ssrf.py b/packages/server/tests/test_ssrf.py new file mode 100644 index 0000000..bdc5086 --- /dev/null +++ b/packages/server/tests/test_ssrf.py @@ -0,0 +1,59 @@ +"""SSRF guard regression tests.""" + +from __future__ import annotations + +import pytest + +from notify_bridge_core.notifications.ssrf import ( + UnsafeURLError, + avalidate_outbound_url, + validate_outbound_url, +) + + +class TestScheme: + def test_rejects_file_scheme(self) -> None: + with pytest.raises(UnsafeURLError): + validate_outbound_url("file:///etc/passwd") + + def test_rejects_gopher(self) -> None: + with pytest.raises(UnsafeURLError): + validate_outbound_url("gopher://example.com/") + + def test_accepts_https(self) -> None: + # A well-known public host — validated via real DNS so this test is + # skipped when offline. 
+ try: + validate_outbound_url("https://example.com/") + except UnsafeURLError as err: + if "DNS" in str(err): + pytest.skip("No DNS in test environment") + raise + + +class TestBlockedRanges: + @pytest.mark.parametrize( + "url", + [ + "http://127.0.0.1/", + "http://10.0.0.1/", + "http://192.168.1.1/", + "http://169.254.169.254/latest/meta-data/", + "http://[::1]/", + ], + ) + def test_rejects_literal_private(self, url: str) -> None: + with pytest.raises(UnsafeURLError): + validate_outbound_url(url) + + +class TestAsyncValidator: + @pytest.mark.asyncio + async def test_async_rejects_loopback(self) -> None: + with pytest.raises(UnsafeURLError): + await avalidate_outbound_url("http://127.0.0.1/") + + @pytest.mark.asyncio + async def test_async_rejects_bad_scheme(self) -> None: + with pytest.raises(UnsafeURLError): + await avalidate_outbound_url("file:///etc/passwd") diff --git a/scripts/restart-backend.sh b/scripts/restart-backend.sh index b846629..6647849 100644 --- a/scripts/restart-backend.sh +++ b/scripts/restart-backend.sh @@ -24,7 +24,7 @@ fi # Start backend export NOTIFY_BRIDGE_DATA_DIR=./test-data -export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars +export NOTIFY_BRIDGE_SECRET_KEY=dev-only-pwIOUsKmfn4CYWQ9hCRs5lmI3GgrVlXSu2nqFzGW # Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF # guard rejects private addresses by default, which would make trackers fail. export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1