diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml
index 56b8a5d..fcbe1c1 100644
--- a/.gitea/workflows/build.yml
+++ b/.gitea/workflows/build.yml
@@ -1,13 +1,75 @@
-name: Build Docker Image
+name: Build and Test
on:
+ push:
+ branches: [master, main]
+ pull_request:
+ branches: [master, main]
workflow_dispatch:
jobs:
- build:
+ test-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - name: Build Docker image
- run: docker build -t notify-bridge:dev .
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+
+ - name: Install core + server + dev deps
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e ./packages/core
+ python -m pip install -e "./packages/server[dev]"
+
+ - name: Run pytest (server)
+ run: |
+ cd packages/server
+ pytest -q --maxfail=1
+
+ test-frontend:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Set up Node
+ uses: actions/setup-node@v4
+ with:
+ node-version: "22"
+ cache: "npm"
+ cache-dependency-path: frontend/package-lock.json
+
+ - name: Install deps
+ run: |
+ cd frontend
+ npm ci
+
+ - name: Svelte check
+ run: |
+ cd frontend
+ npm run check || echo "::warning::svelte-check reported warnings"
+
+ - name: Build
+ run: |
+ cd frontend
+ npm run build
+
+ build-image:
+ needs: [test-backend, test-frontend]
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Build Docker image (no push)
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ push: false
+ tags: notify-bridge:ci-${{ gitea.sha }}
+ cache-from: type=gha
+ cache-to: type=gha,mode=max
diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml
index a6891fa..a907977 100644
--- a/.gitea/workflows/release.yml
+++ b/.gitea/workflows/release.yml
@@ -10,7 +10,22 @@ env:
IMAGE_NAME: alexei.dolgolyov/notify-bridge
jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ - name: Install + test
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e ./packages/core
+ python -m pip install -e "./packages/server[dev]"
+ cd packages/server && pytest -q --maxfail=1
+
release:
+ needs: test
runs-on: ubuntu-latest
steps:
- name: Checkout repo
@@ -44,16 +59,28 @@ jobs:
- name: Build and push Docker image
uses: docker/build-push-action@v5
+ id: docker_build
with:
context: .
push: true
tags: |
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}
+ ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ gitea.sha }}
${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }}
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
+ - name: Vulnerability scan (trivy)
+ uses: aquasecurity/trivy-action@master
+ continue-on-error: true
+ with:
+ image-ref: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
+ format: table
+ exit-code: 0
+ severity: HIGH,CRITICAL
+ ignore-unfixed: true
+
- name: Trigger redeploy webhook
if: steps.version.outputs.is_pre == 'false'
continue-on-error: true
diff --git a/docker-compose.yml b/docker-compose.yml
index 4a9d7b3..339e904 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,18 +10,38 @@ services:
volumes:
- notify-bridge-data:/data
environment:
+ # REQUIRED — any 32+ byte random string. `openssl rand -hex 32` is one way.
- NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)}
- - NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-*}
- # Homelab target: allow outbound requests to RFC1918 / link-local addresses.
- # The SSRF guard otherwise rejects 10.*/172.16.*/192.168.*/169.254.* hosts,
- # which breaks tracking of Immich / Gitea / etc. running on the same LAN.
- - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
+ # Comma-separated list of allowed browser origins. Wildcard `*` is
+ # rejected on startup because credentials are enabled.
+ - NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-http://localhost:8420}
+ # Trusted proxy IPs whose X-Forwarded-For / X-Forwarded-Proto we honor.
+ # Set this to your reverse proxy's IP (e.g. 172.17.0.1 for the default
+ # docker bridge, or `*` only if the container is NOT reachable from the
+ # public internet).
+ - NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS=${NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS:-127.0.0.1}
+ # Opt-in SSRF bypass for private/loopback/link-local hosts (homelab
+ # scenario — tracking an Immich/Gitea instance on the same LAN). DO NOT
+ # enable on a publicly exposed instance.
+ # - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
healthcheck:
- test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/health')"]
+ # Use /api/ready (not /api/health) so the container is only reported
+ # healthy after migrations and the scheduler finish booting.
+ test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/ready', timeout=3)"]
interval: 30s
timeout: 5s
retries: 3
- start_period: 10s
+ start_period: 30s
+ read_only: true
+ tmpfs:
+ - /tmp
+ security_opt:
+ - no-new-privileges:true
+ cap_drop:
+ - ALL
+ mem_limit: 512m
+ cpus: 1.0
+ pids_limit: 256
volumes:
notify-bridge-data:
diff --git a/frontend/src/lib/i18n/en.json b/frontend/src/lib/i18n/en.json
index 0ff5703..ba41a39 100644
--- a/frontend/src/lib/i18n/en.json
+++ b/frontend/src/lib/i18n/en.json
@@ -55,7 +55,8 @@
"passwordTooShort": "Password must be at least 8 characters",
"or": "or",
"loginFailed": "Login failed",
- "setupFailed": "Setup failed"
+ "setupFailed": "Setup failed",
+ "backendUnreachable": "Cannot reach the server. Check that it's running and try again."
},
"dashboard": {
"title": "Dashboard",
diff --git a/frontend/src/lib/i18n/ru.json b/frontend/src/lib/i18n/ru.json
index d1dde07..ee439a0 100644
--- a/frontend/src/lib/i18n/ru.json
+++ b/frontend/src/lib/i18n/ru.json
@@ -55,7 +55,8 @@
"passwordTooShort": "Пароль должен быть не менее 8 символов",
"or": "или",
"loginFailed": "Ошибка входа",
- "setupFailed": "Ошибка настройки"
+ "setupFailed": "Ошибка настройки",
+ "backendUnreachable": "Не удалось подключиться к серверу. Убедитесь, что он запущен, и повторите попытку."
},
"dashboard": {
"title": "Главная",
diff --git a/frontend/src/routes/login/+page.svelte b/frontend/src/routes/login/+page.svelte
index 1528fe9..bff7ec2 100644
--- a/frontend/src/routes/login/+page.svelte
+++ b/frontend/src/routes/login/+page.svelte
@@ -15,13 +15,32 @@
let submitting = $state(false);
let mounted = $state(false);
+ let backendDown = $state(false);
+
onMount(async () => {
initTheme();
mounted = true;
+ // If the user is already signed in (valid access token in storage),
+ // there is no reason to show them the login form. loadUser() runs in
+ // the root layout; we just check the resolved state after a short tick.
+ const { isAuthenticated } = await import('$lib/api');
+ if (isAuthenticated()) {
+ try {
+ await api('/auth/me');
+ goto('/');
+ return;
+ } catch {
+ // Token was stale; fall through to the login form.
+ }
+ }
try {
const res = await api<{ needs_setup: boolean }>('/auth/needs-setup');
if (res.needs_setup) goto('/setup');
- } catch { /* ignore */ }
+ } catch {
+ // The backend is unreachable — surface that distinctly so the user
+ // doesn't blame the login form for a network/backend problem.
+ backendDown = true;
+ }
});
async function handleSubmit(e: SubmitEvent) {
@@ -62,7 +81,12 @@
{t('auth.signInTitle')}
- {#if error}
+ {#if backendDown}
+
+
+ {t('auth.backendUnreachable')}
+
+ {:else if error}
{error}
diff --git a/packages/core/src/notify_bridge_core/notifications/discord/client.py b/packages/core/src/notify_bridge_core/notifications/discord/client.py
index 20dbc2e..cd0ea24 100644
--- a/packages/core/src/notify_bridge_core/notifications/discord/client.py
+++ b/packages/core/src/notify_bridge_core/notifications/discord/client.py
@@ -52,22 +52,46 @@ class DiscordClient:
return {"success": True}
+ _MAX_RETRIES = 3
+ _MAX_RETRY_AFTER = 60.0
+
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
- try:
- async with self._session.post(
- url, json=payload, headers={"Content-Type": "application/json"}
- ) as resp:
- if resp.status == 429:
- retry_after = float(resp.headers.get("Retry-After", "2"))
- _LOGGER.warning("Discord rate limited, retrying after %.1fs", retry_after)
- await asyncio.sleep(retry_after)
- return await self._post(url, payload)
- if 200 <= resp.status < 300:
- return {"success": True}
- body = await resp.text()
- return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
- except aiohttp.ClientError as e:
- return {"success": False, "error": str(e)}
+ """POST with bounded 429 retry.
+
+ We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
+ _MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
+ pin the dispatch task indefinitely.
+ """
+ for attempt in range(self._MAX_RETRIES + 1):
+ try:
+ async with self._session.post(
+ url,
+ json=payload,
+ headers={"Content-Type": "application/json"},
+ allow_redirects=False,
+ ) as resp:
+ if resp.status == 429 and attempt < self._MAX_RETRIES:
+ try:
+ retry_after = float(resp.headers.get("Retry-After", "2"))
+ except (TypeError, ValueError):
+ retry_after = 2.0
+ retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
+ _LOGGER.warning(
+ "Discord rate limited, retrying after %.1fs (attempt %d/%d)",
+ retry_after, attempt + 1, self._MAX_RETRIES,
+ )
+ await asyncio.sleep(retry_after)
+ continue
+ if 200 <= resp.status < 300:
+ return {"success": True}
+ body = await resp.text()
+ return {
+ "success": False,
+ "error": f"HTTP {resp.status}: {body[:200]}",
+ }
+ except aiohttp.ClientError as e:
+ return {"success": False, "error": str(e)}
+ return {"success": False, "error": "Rate limited (retries exhausted)"}
def _split_message(text: str, limit: int) -> list[str]:
diff --git a/packages/core/src/notify_bridge_core/notifications/dispatcher.py b/packages/core/src/notify_bridge_core/notifications/dispatcher.py
index c7aecc0..71c0091 100644
--- a/packages/core/src/notify_bridge_core/notifications/dispatcher.py
+++ b/packages/core/src/notify_bridge_core/notifications/dispatcher.py
@@ -3,10 +3,11 @@
from __future__ import annotations
import asyncio
+import contextlib
import logging
import uuid
from dataclasses import dataclass, field
-from typing import Any
+from typing import Any, AsyncIterator
import aiohttp
@@ -14,7 +15,7 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.templates.context import build_template_context
from notify_bridge_core.templates.renderer import render_template
-from .ssrf import UnsafeURLError, validate_outbound_url
+from .ssrf import UnsafeURLError, avalidate_outbound_url
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
@@ -84,9 +85,28 @@ class NotificationDispatcher:
*,
url_cache: TelegramFileCache | None = None,
asset_cache: TelegramFileCache | None = None,
+ session: aiohttp.ClientSession | None = None,
) -> None:
self._url_cache = url_cache
self._asset_cache = asset_cache
+ # Optional shared session owned by the caller; when supplied we reuse
+ # its connection pool instead of opening a fresh per-dispatch session
+ # (saves a TLS handshake per outbound call).
+ self._shared_session = session
+
+ @contextlib.asynccontextmanager
+ async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
+ """Yield an aiohttp session, reusing the shared one if provided.
+
+ When a shared session was passed in ``__init__`` we yield it without
+ closing (the caller owns its lifetime). Otherwise we open a
+ short-lived session with our default timeout and close it on exit.
+ """
+ if self._shared_session is not None and not self._shared_session.closed:
+ yield self._shared_session
+ return
+ async with self._session_ctx() as session:
+ yield session
async def dispatch(
self,
@@ -308,7 +328,7 @@ class NotificationDispatcher:
media_assets.append(asset)
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
# Preload all asset bytes once so (a) TelegramClient can skip its
# own download and (b) we know exact upload sizes in time for the
# oversize warning in the rendered text.
@@ -378,13 +398,13 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
for receiver in target.receivers:
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
results.append({"success": False, "error": "Invalid webhook receiver"})
continue
try:
- validate_outbound_url(receiver.url)
+ await avalidate_outbound_url(receiver.url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
@@ -452,14 +472,14 @@ class NotificationDispatcher:
username = target.config.get("username")
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
client = DiscordClient(session)
for receiver in target.receivers:
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid discord receiver"})
continue
try:
- validate_outbound_url(receiver.webhook_url)
+ await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
@@ -478,14 +498,14 @@ class NotificationDispatcher:
username = target.config.get("username")
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
client = SlackClient(session)
for receiver in target.receivers:
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid slack receiver"})
continue
try:
- validate_outbound_url(receiver.webhook_url)
+ await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
@@ -504,14 +524,14 @@ class NotificationDispatcher:
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
try:
- validate_outbound_url(server_url)
+ await avalidate_outbound_url(server_url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
title = f"{event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
client = NtfyClient(session)
for receiver in target.receivers:
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
@@ -535,7 +555,7 @@ class NotificationDispatcher:
if not homeserver or not access_token:
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
try:
- validate_outbound_url(homeserver)
+ await avalidate_outbound_url(homeserver)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
@@ -543,7 +563,7 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
- async with _new_session() as session:
+ async with self._session_ctx() as session:
client = MatrixClient(session, homeserver, access_token)
for receiver in target.receivers:
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
diff --git a/packages/core/src/notify_bridge_core/notifications/matrix/client.py b/packages/core/src/notify_bridge_core/notifications/matrix/client.py
index f7c06f3..224239f 100644
--- a/packages/core/src/notify_bridge_core/notifications/matrix/client.py
+++ b/packages/core/src/notify_bridge_core/notifications/matrix/client.py
@@ -68,7 +68,9 @@ class MatrixClient:
}
try:
- async with self._session.put(url, json=body, headers=headers) as resp:
+ async with self._session.put(
+ url, json=body, headers=headers, allow_redirects=False,
+ ) as resp:
if 200 <= resp.status < 300:
return {"success": True}
resp_body = await resp.text()
diff --git a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py
index 41717bf..d5be0f9 100644
--- a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py
+++ b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py
@@ -51,7 +51,9 @@ class NtfyClient:
headers["Authorization"] = f"Bearer {auth_token}"
try:
- async with self._session.post(url, json=payload, headers=headers) as resp:
+ async with self._session.post(
+ url, json=payload, headers=headers, allow_redirects=False,
+ ) as resp:
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
diff --git a/packages/core/src/notify_bridge_core/notifications/slack/client.py b/packages/core/src/notify_bridge_core/notifications/slack/client.py
index ea985ca..681b286 100644
--- a/packages/core/src/notify_bridge_core/notifications/slack/client.py
+++ b/packages/core/src/notify_bridge_core/notifications/slack/client.py
@@ -38,6 +38,7 @@ class SlackClient:
webhook_url,
json=payload,
headers={"Content-Type": "application/json"},
+ allow_redirects=False,
) as resp:
if resp.status == 429:
_LOGGER.warning("Slack rate limited")
diff --git a/packages/core/src/notify_bridge_core/notifications/ssrf.py b/packages/core/src/notify_bridge_core/notifications/ssrf.py
index e0902bb..66aea5a 100644
--- a/packages/core/src/notify_bridge_core/notifications/ssrf.py
+++ b/packages/core/src/notify_bridge_core/notifications/ssrf.py
@@ -12,14 +12,25 @@ development against localhost services.
from __future__ import annotations
+import asyncio
import ipaddress
+import logging
import os
import socket
from urllib.parse import urlparse
+_LOGGER = logging.getLogger(__name__)
+
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
_ALLOWED_SCHEMES = {"http", "https"}
+if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
+ _LOGGER.warning(
+ "SSRF guard: private-URL bypass ENABLED "
+ "(NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1). Requests to RFC1918 / "
+ "loopback / link-local hosts will be permitted."
+ )
+
class UnsafeURLError(ValueError):
"""Raised when a URL targets a disallowed network destination."""
@@ -36,13 +47,7 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
)
-def validate_outbound_url(url: str) -> str:
- """Validate ``url`` is safe to fetch; returns the URL on success.
-
- Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
- is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
- private addresses are permitted but the scheme check still applies.
- """
+def _check_scheme_host(url: str) -> tuple[str, str]:
if not isinstance(url, str) or not url:
raise UnsafeURLError("URL is empty")
parsed = urlparse(url)
@@ -51,6 +56,31 @@ def validate_outbound_url(url: str) -> str:
host = parsed.hostname
if not host:
raise UnsafeURLError("URL has no host")
+ return parsed.scheme, host
+
+
+def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
+ for info in infos:
+ sockaddr = info[4]
+ try:
+ ip = ipaddress.ip_address(sockaddr[0])
+ except ValueError:
+ continue
+ if _is_blocked_ip(ip):
+ raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
+
+
+def validate_outbound_url(url: str) -> str:
+ """Validate ``url`` is safe to fetch; returns the URL on success.
+
+ Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
+ is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
+ private addresses are permitted but the scheme check still applies.
+
+ Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
+ :func:`avalidate_outbound_url` from async code paths.
+ """
+ _, host = _check_scheme_host(url)
if _ALLOW_PRIVATE:
return url
@@ -64,17 +94,37 @@ def validate_outbound_url(url: str) -> str:
except ValueError:
pass
- # Hostname — resolve and reject if any resolution is in a blocked range.
try:
infos = socket.getaddrinfo(host, None)
except socket.gaierror as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
- for info in infos:
- sockaddr = info[4]
- try:
- ip = ipaddress.ip_address(sockaddr[0])
- except ValueError:
- continue
- if _is_blocked_ip(ip):
- raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
+ _check_resolved_addresses(host, infos)
+ return url
+
+
+async def avalidate_outbound_url(url: str) -> str:
+ """Async variant that resolves DNS via the running loop's resolver.
+
+ Use this from ``async def`` code paths to avoid blocking the event
+ loop on DNS lookups.
+ """
+ _, host = _check_scheme_host(url)
+
+ if _ALLOW_PRIVATE:
+ return url
+
+ try:
+ ip = ipaddress.ip_address(host)
+ if _is_blocked_ip(ip):
+ raise UnsafeURLError(f"Host {host} is in a blocked range")
+ return url
+ except ValueError:
+ pass
+
+ loop = asyncio.get_running_loop()
+ try:
+ infos = await loop.getaddrinfo(host, None)
+ except socket.gaierror as exc:
+ raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
+ _check_resolved_addresses(host, infos)
return url
diff --git a/packages/core/src/notify_bridge_core/notifications/webhook/client.py b/packages/core/src/notify_bridge_core/notifications/webhook/client.py
index c5ef57f..0f3cafd 100644
--- a/packages/core/src/notify_bridge_core/notifications/webhook/client.py
+++ b/packages/core/src/notify_bridge_core/notifications/webhook/client.py
@@ -7,7 +7,7 @@ from typing import Any
import aiohttp
-from ..ssrf import UnsafeURLError, validate_outbound_url
+from ..ssrf import UnsafeURLError, avalidate_outbound_url
_LOGGER = logging.getLogger(__name__)
@@ -24,7 +24,7 @@ class WebhookClient:
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
try:
- validate_outbound_url(self._url)
+ await avalidate_outbound_url(self._url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe URL: {err}"}
try:
@@ -33,6 +33,7 @@ class WebhookClient:
json=payload,
headers={"Content-Type": "application/json", **self._headers},
timeout=_DEFAULT_TIMEOUT,
+ allow_redirects=False,
) as response:
if 200 <= response.status < 300:
return {"success": True, "status_code": response.status}
diff --git a/packages/core/src/notify_bridge_core/providers/nut/client.py b/packages/core/src/notify_bridge_core/providers/nut/client.py
index c1a8dbe..3d0d9d4 100644
--- a/packages/core/src/notify_bridge_core/providers/nut/client.py
+++ b/packages/core/src/notify_bridge_core/providers/nut/client.py
@@ -12,6 +12,7 @@ _LOGGER = logging.getLogger(__name__)
_DEFAULT_PORT = 3493
_READ_TIMEOUT = 10.0
+_WRITE_TIMEOUT = 10.0
_CONNECT_TIMEOUT = 5.0
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
@@ -84,14 +85,26 @@ class NutClient:
await self._command(f"PASSWORD {self._password}")
async def disconnect(self) -> None:
- """Send LOGOUT and close the TCP connection."""
+ """Send LOGOUT and close the TCP connection.
+
+ ``drain`` is bounded by ``_WRITE_TIMEOUT`` so a half-closed peer
+ cannot hold the disconnect indefinitely — a tracker tick would
+ otherwise be pinned by a stuck NUT server and block the scheduler
+ slot (``max_instances=1``).
+ """
if self._writer is not None:
try:
self._writer.write(b"LOGOUT\n")
- await self._writer.drain()
- except OSError:
+ await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
+ except (OSError, asyncio.TimeoutError):
pass
self._writer.close()
+ try:
+ await asyncio.wait_for(
+ self._writer.wait_closed(), timeout=_WRITE_TIMEOUT,
+ )
+ except (OSError, asyncio.TimeoutError):
+ pass
self._reader = None
self._writer = None
@@ -135,7 +148,10 @@ class NutClient:
if self._writer is None:
raise NutClientError("Not connected")
self._writer.write(f"{cmd}\n".encode())
- await self._writer.drain()
+ try:
+ await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
+ except asyncio.TimeoutError as exc:
+ raise NutClientError("Write timeout") from exc
async def _readline(self) -> str:
"""Read one line from upsd, stripping trailing newline."""
diff --git a/packages/core/src/notify_bridge_core/storage.py b/packages/core/src/notify_bridge_core/storage.py
index d92161e..16140ee 100644
--- a/packages/core/src/notify_bridge_core/storage.py
+++ b/packages/core/src/notify_bridge_core/storage.py
@@ -2,8 +2,10 @@
from __future__ import annotations
+import asyncio
import json
import logging
+import os
from pathlib import Path
from typing import Any, Protocol, runtime_checkable
@@ -19,34 +21,58 @@ class StorageBackend(Protocol):
async def remove(self) -> None: ...
+def _read_file(path: Path) -> str | None:
+ if not path.exists():
+ return None
+ return path.read_text(encoding="utf-8")
+
+
+def _atomic_write(path: Path, payload: str) -> None:
+ """Write atomically: tmp file + rename. Prevents half-written files on crash."""
+ path.parent.mkdir(parents=True, exist_ok=True)
+ tmp = path.with_suffix(path.suffix + ".tmp")
+ tmp.write_text(payload, encoding="utf-8")
+ os.replace(tmp, path)
+
+
+def _remove_file(path: Path) -> None:
+ if path.exists():
+ path.unlink()
+
+
class JsonFileBackend:
- """Simple JSON file storage backend."""
+ """Simple JSON file storage backend.
+
+ All blocking I/O is wrapped in ``asyncio.to_thread`` so callers can
+ ``await load() / save() / remove()`` without stalling the event loop.
+ """
def __init__(self, path: Path) -> None:
self._path = path
async def load(self) -> dict[str, Any] | None:
- if not self._path.exists():
+ try:
+ text = await asyncio.to_thread(_read_file, self._path)
+ except OSError as err:
+ _LOGGER.warning("Failed to load %s: %s", self._path, err)
+ return None
+ if text is None:
return None
try:
- text = self._path.read_text(encoding="utf-8")
return json.loads(text)
- except (json.JSONDecodeError, OSError) as err:
- _LOGGER.warning("Failed to load %s: %s", self._path, err)
+ except json.JSONDecodeError as err:
+ _LOGGER.warning("Failed to parse %s: %s", self._path, err)
return None
async def save(self, data: dict[str, Any]) -> None:
+ payload = json.dumps(data, default=str)
try:
- self._path.parent.mkdir(parents=True, exist_ok=True)
- self._path.write_text(
- json.dumps(data, default=str), encoding="utf-8"
- )
+ await asyncio.to_thread(_atomic_write, self._path, payload)
except OSError as err:
_LOGGER.error("Failed to save %s: %s", self._path, err)
async def remove(self) -> None:
try:
- if self._path.exists():
- self._path.unlink()
+ await asyncio.to_thread(_remove_file, self._path)
except OSError as err:
_LOGGER.error("Failed to remove %s: %s", self._path, err)
diff --git a/packages/server/pyproject.toml b/packages/server/pyproject.toml
index e77d724..7e3c567 100644
--- a/packages/server/pyproject.toml
+++ b/packages/server/pyproject.toml
@@ -28,6 +28,7 @@ dev = [
"pytest>=8.0",
"pytest-asyncio>=0.23",
"httpx>=0.27",
+ "aioresponses>=0.7",
]
[project.scripts]
@@ -35,3 +36,14 @@ notify-bridge = "notify_bridge_server.main:run"
[tool.hatch.build.targets.wheel]
packages = ["src/notify_bridge_server"]
+
+[tool.pytest.ini_options]
+asyncio_mode = "auto"
+testpaths = ["tests"]
+# The default filter doesn't let SQLAlchemy warnings fail the suite, which
+# matters because our migrations emit a handful of deprecation warnings we
+# don't want to suppress at source.
+filterwarnings = [
+ "ignore::DeprecationWarning:passlib",
+ "ignore::DeprecationWarning:bcrypt",
+]
diff --git a/packages/server/src/notify_bridge_server/api/email_bots.py b/packages/server/src/notify_bridge_server/api/email_bots.py
index 479aeb4..e6a33c3 100644
--- a/packages/server/src/notify_bridge_server/api/email_bots.py
+++ b/packages/server/src/notify_bridge_server/api/email_bots.py
@@ -93,7 +93,14 @@ async def update_email_bot(
session: AsyncSession = Depends(get_session),
):
bot = await _get_user_bot(session, bot_id, user.id)
- for field, value in body.model_dump(exclude_unset=True).items():
+ updates = body.model_dump(exclude_unset=True)
+ # Reject the masked value the GET response returns so the stored password
+ # is preserved if the user saves without retyping it.
+ if "smtp_password" in updates:
+ pw = updates["smtp_password"]
+ if isinstance(pw, str) and pw.startswith("***"):
+ updates.pop("smtp_password")
+ for field, value in updates.items():
setattr(bot, field, value)
session.add(bot)
await session.commit()
diff --git a/packages/server/src/notify_bridge_server/api/matrix_bots.py b/packages/server/src/notify_bridge_server/api/matrix_bots.py
index 644a6a7..f5861e7 100644
--- a/packages/server/src/notify_bridge_server/api/matrix_bots.py
+++ b/packages/server/src/notify_bridge_server/api/matrix_bots.py
@@ -7,6 +7,11 @@ from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
+from notify_bridge_core.notifications.ssrf import (
+ UnsafeURLError,
+ avalidate_outbound_url,
+)
+
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import MatrixBot, User
@@ -33,6 +38,21 @@ class MatrixBotUpdate(BaseModel):
display_name: str | None = None
+def _is_masked_secret(value: str | None) -> bool:
+ """True when a field still carries our masked placeholder."""
+ return bool(value) and (value.startswith("***") or "..." in value)
+
+
+async def _validate_homeserver_url(url: str) -> None:
+ """Reject homeserver URLs that point to blocked networks."""
+ try:
+ await avalidate_outbound_url(url)
+ except UnsafeURLError as err:
+ raise HTTPException(
+ status_code=400, detail=f"Invalid homeserver_url: {err}"
+ ) from err
+
+
@router.get("")
async def list_matrix_bots(
user: User = Depends(get_current_user),
@@ -50,6 +70,7 @@ async def create_matrix_bot(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
+ await _validate_homeserver_url(body.homeserver_url)
bot = MatrixBot(user_id=user.id, **body.model_dump())
session.add(bot)
await session.commit()
@@ -74,7 +95,19 @@ async def update_matrix_bot(
session: AsyncSession = Depends(get_session),
):
bot = await _get_user_bot(session, bot_id, user.id)
- for field, value in body.model_dump(exclude_unset=True).items():
+ updates = body.model_dump(exclude_unset=True)
+
+ # Re-validate homeserver_url whenever the client supplies a new one so
+ # no private/loopback target can ever be saved, even via update.
+ if "homeserver_url" in updates and updates["homeserver_url"]:
+ await _validate_homeserver_url(updates["homeserver_url"])
+
+ # Never accept the masked placeholder the GET response returns. If the
+ # client echoes it back, keep the stored secret.
+ if "access_token" in updates and _is_masked_secret(updates["access_token"]):
+ updates.pop("access_token")
+
+ for field, value in updates.items():
setattr(bot, field, value)
session.add(bot)
await session.commit()
@@ -108,15 +141,17 @@ async def test_matrix_bot(
If room_id is not provided, just verifies the access token by calling /whoami.
"""
bot = await _get_user_bot(session, bot_id, user.id)
+ # Defense-in-depth: even though create/update validate the URL, a bot row
+ # written before this guard was added could still point at a blocked host.
+ await _validate_homeserver_url(bot.homeserver_url)
import aiohttp
from ..services.http_session import get_http_session
http = await get_http_session()
- # Verify token with /whoami
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
headers = {"Authorization": f"Bearer {bot.access_token}"}
try:
- async with http.get(whoami_url, headers=headers) as resp:
+ async with http.get(whoami_url, headers=headers, allow_redirects=False) as resp:
if resp.status != 200:
body = await resp.text()
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
@@ -126,7 +161,6 @@ async def test_matrix_bot(
result = {"success": True, "user_id": whoami.get("user_id", "")}
- # Optionally send a test message
if room_id:
from ..services.notifier import _get_test_message
from notify_bridge_core.notifications.matrix.client import MatrixClient
@@ -148,7 +182,7 @@ def _response(bot: MatrixBot) -> dict:
"name": bot.name,
"icon": bot.icon,
"homeserver_url": bot.homeserver_url,
- "access_token": f"{bot.access_token[:8]}...{bot.access_token[-4:]}" if len(bot.access_token) > 12 else "***",
+ "access_token": f"***{bot.access_token[-4:]}" if len(bot.access_token) > 4 else "***",
"display_name": bot.display_name,
"created_at": bot.created_at.isoformat(),
}
diff --git a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py
index 5967d6f..43a4afe 100644
--- a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py
+++ b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py
@@ -22,7 +22,7 @@ from ..database.models import (
User,
)
from ..services.notifier import send_test_notification
-from ..services.test_dispatch import dispatch_test_notification
+from ..services.manual_dispatch import dispatch_test_notification
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
diff --git a/packages/server/src/notify_bridge_server/api/notification_trackers.py b/packages/server/src/notify_bridge_server/api/notification_trackers.py
index bca3181..eed6e07 100644
--- a/packages/server/src/notify_bridge_server/api/notification_trackers.py
+++ b/packages/server/src/notify_bridge_server/api/notification_trackers.py
@@ -11,6 +11,7 @@ from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import (
EventLog,
+ NotificationTarget,
NotificationTracker,
NotificationTrackerState,
NotificationTrackerTarget,
@@ -54,11 +55,79 @@ async def list_notification_trackers(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
+ # Batched loader: pull trackers, then all their tracker-target links in
+ # a single query, then the referenced targets in a single query. Avoids
+ # the old 1 + N + N*M pattern that ran ~60 round-trips for 10 trackers.
result = await session.exec(
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
)
- trackers = result.all()
- return [await _tracker_response(session, t) for t in trackers]
+ trackers = list(result.all())
+ if not trackers:
+ return []
+
+ tracker_ids = [t.id for t in trackers]
+ tt_result = await session.exec(
+ select(NotificationTrackerTarget).where(
+ NotificationTrackerTarget.tracker_id.in_(tracker_ids)
+ )
+ )
+ tt_rows = list(tt_result.all())
+
+ target_ids = {tt.target_id for tt in tt_rows}
+ targets_by_id: dict[int, NotificationTarget] = {}
+ if target_ids:
+ tgt_result = await session.exec(
+ select(NotificationTarget).where(NotificationTarget.id.in_(target_ids))
+ )
+ targets_by_id = {t.id: t for t in tgt_result.all()}
+
+ tts_by_tracker: dict[int, list[NotificationTrackerTarget]] = {}
+ for tt in tt_rows:
+ tts_by_tracker.setdefault(tt.tracker_id, []).append(tt)
+
+ return [
+ _build_tracker_response(t, tts_by_tracker.get(t.id, []), targets_by_id)
+ for t in trackers
+ ]
+
+
+def _build_tracker_response(
+ t: NotificationTracker,
+ tts: list[NotificationTrackerTarget],
+ targets_by_id: dict[int, NotificationTarget],
+) -> dict:
+ """In-memory assembler for a tracker + its pre-loaded links/targets."""
+ tracker_targets = []
+ for tt in tts:
+ target = targets_by_id.get(tt.target_id)
+ tracker_targets.append({
+ "id": tt.id,
+ "tracker_id": tt.tracker_id,
+ "target_id": tt.target_id,
+ "target_name": target.name if target else None,
+ "target_type": target.type if target else None,
+ "target_icon": target.icon if target else None,
+ "tracking_config_id": tt.tracking_config_id,
+ "template_config_id": tt.template_config_id,
+ "enabled": tt.enabled,
+ "quiet_hours_start": tt.quiet_hours_start,
+ "quiet_hours_end": tt.quiet_hours_end,
+ "created_at": tt.created_at.isoformat(),
+ })
+ return {
+ "id": t.id,
+ "name": t.name,
+ "icon": t.icon,
+ "provider_id": t.provider_id,
+ "collection_ids": t.collection_ids,
+ "scan_interval": t.scan_interval,
+ "batch_duration": t.batch_duration,
+ "default_tracking_config_id": t.default_tracking_config_id,
+ "default_template_config_id": t.default_template_config_id,
+ "enabled": t.enabled,
+ "tracker_targets": tracker_targets,
+ "created_at": t.created_at.isoformat(),
+ }
@router.post("", status_code=status.HTTP_201_CREATED)
diff --git a/packages/server/src/notify_bridge_server/api/providers.py b/packages/server/src/notify_bridge_server/api/providers.py
index 400b8a5..71b06bc 100644
--- a/packages/server/src/notify_bridge_server/api/providers.py
+++ b/packages/server/src/notify_bridge_server/api/providers.py
@@ -306,16 +306,31 @@ async def update_provider(
if body.icon is not None:
provider.icon = body.icon
- config_changed = body.config is not None and body.config != provider.config
if body.config is not None:
- _validate_provider_config(provider.type, body.config)
- provider.config = body.config
+ # Merge rather than replace so the masked secrets the frontend
+ # receives on GET cannot silently nuke the stored values when the
+ # user saves without re-entering them. Any field that still carries
+ # our mask placeholder ("***…") is dropped from the incoming body.
+ incoming = dict(body.config)
+ for secret_field in (
+ "api_key", "api_token", "webhook_secret", "password",
+ "client_secret", "refresh_token",
+ ):
+ value = incoming.get(secret_field)
+ if isinstance(value, str) and value.startswith("***"):
+ incoming.pop(secret_field, None)
+ new_config = {**provider.config, **incoming}
+ _validate_provider_config(provider.type, new_config)
+ config_changed = new_config != provider.config
+ provider.config = new_config
- # Re-validate connection when config changes for known provider types
- if config_changed:
- test_result = await _validate_provider_connection(provider)
- if test_result.get("external_domain"):
- provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
+ if config_changed:
+ test_result = await _validate_provider_connection(provider)
+ if test_result.get("external_domain"):
+ provider.config = {
+ **provider.config,
+ "external_domain": test_result["external_domain"],
+ }
session.add(provider)
await session.commit()
diff --git a/packages/server/src/notify_bridge_server/api/users.py b/packages/server/src/notify_bridge_server/api/users.py
index d9648bc..ec6972e 100644
--- a/packages/server/src/notify_bridge_server/api/users.py
+++ b/packages/server/src/notify_bridge_server/api/users.py
@@ -1,5 +1,6 @@
"""User management API routes (admin only)."""
+import asyncio
import logging
from fastapi import APIRouter, Depends, HTTPException, status
@@ -14,6 +15,15 @@ from ..auth.dependencies import require_admin
from ..database.engine import get_session
from ..database.models import User
+
+async def _hash_password(password: str) -> str:
+ """Run bcrypt off the event loop. Matches the helper in auth/routes.py."""
+
+ def _work() -> str:
+ return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
+
+ return await asyncio.to_thread(_work)
+
_LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/users", tags=["users"])
@@ -36,8 +46,12 @@ async def list_users(
admin: User = Depends(require_admin),
session: AsyncSession = Depends(get_session),
):
- """List all users (admin only)."""
- result = await session.exec(select(User))
+ """List all users (admin only).
+
+ Excludes the internal ``__system__`` placeholder (id=0) used as the
+ owner of default templates/configs — it is never a real account.
+ """
+ result = await session.exec(select(User).where(User.id != 0))
return [
{"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()}
for u in result.all()
@@ -61,7 +75,7 @@ async def create_user(
user = User(
username=body.username,
- hashed_password=bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode(),
+ hashed_password=await _hash_password(body.password),
role=body.role if body.role in ("admin", "user") else "user",
)
session.add(user)
@@ -162,7 +176,7 @@ async def reset_user_password(
raise HTTPException(status_code=404, detail="User not found")
if len(body.new_password) < 8:
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
- user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
+ user.hashed_password = await _hash_password(body.new_password)
# Invalidate all prior JWTs issued for this user — matches the self-serve
# password-change path in auth/routes.py.
user.token_version = (user.token_version or 1) + 1
diff --git a/packages/server/src/notify_bridge_server/api/webhooks.py b/packages/server/src/notify_bridge_server/api/webhooks.py
index 1bdab14..973691f 100644
--- a/packages/server/src/notify_bridge_server/api/webhooks.py
+++ b/packages/server/src/notify_bridge_server/api/webhooks.py
@@ -37,6 +37,42 @@ _LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
+# Hard cap on inbound webhook body size (1 MiB is far larger than anything
+# legitimate providers send and keeps the worst-case memory footprint bounded
+# when a malicious peer lies about Content-Length or streams slowly).
+_MAX_WEBHOOK_BODY_BYTES = 1_000_000
+
+
+async def _read_bounded_body(request: Request, limit: int = _MAX_WEBHOOK_BODY_BYTES) -> bytes:
+ """Reject oversized inbound bodies before they exhaust memory.
+
+ First checks ``Content-Length`` (fast-path for honest peers), then
+ streams the body in chunks enforcing the same cap on actual bytes
+ received so a peer that lies about Content-Length cannot slip through.
+ """
+ declared = request.headers.get("content-length")
+ if declared:
+ try:
+ if int(declared) > limit:
+ raise HTTPException(
+ status_code=413,
+ detail=f"Payload too large (max {limit} bytes)",
+ )
+ except ValueError:
+ raise HTTPException(status_code=400, detail="Invalid Content-Length")
+
+ chunks: list[bytes] = []
+ size = 0
+ async for chunk in request.stream():
+ size += len(chunk)
+ if size > limit:
+ raise HTTPException(
+ status_code=413,
+ detail=f"Payload too large (max {limit} bytes)",
+ )
+ chunks.append(chunk)
+ return b"".join(chunks)
+
async def _get_provider_by_token(
session: AsyncSession, token: str, expected_type: str,
@@ -169,7 +205,8 @@ async def _dispatch_webhook_event(
))
# Dispatch to targets
- dispatcher = NotificationDispatcher()
+ from ..services.http_session import get_http_session
+ dispatcher = NotificationDispatcher(session=await get_http_session())
target_configs = _build_target_configs(event, link_data, provider_config, app_tz)
if target_configs:
results = await dispatcher.dispatch(event, target_configs)
@@ -203,7 +240,7 @@ async def gitea_webhook(token: str, request: Request):
webhook_secret = (provider.config or {}).get("webhook_secret", "")
# Read raw body for HMAC check
- raw_body = await request.body()
+ raw_body = await _read_bounded_body(request)
if not webhook_secret:
raise HTTPException(
@@ -221,8 +258,8 @@ async def gitea_webhook(token: str, request: Request):
return {"ok": True, "skipped": "no event header"}
try:
- payload = await request.json()
- except (json.JSONDecodeError, ValueError):
+ payload = json.loads(raw_body.decode("utf-8"))
+ except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
raise HTTPException(status_code=400, detail="Invalid JSON")
event = parse_gitea_webhook(event_header, payload, provider.name)
@@ -280,10 +317,10 @@ async def planka_webhook(token: str, request: Request):
if not _verify_planka_token(webhook_secret, request):
raise HTTPException(status_code=403, detail="Invalid token")
- # Parse payload
+ # Parse payload from the bounded raw_body we already read.
try:
- payload = await request.json()
- except (json.JSONDecodeError, ValueError):
+ payload = json.loads(raw_body.decode("utf-8"))
+ except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
raise HTTPException(status_code=400, detail="Invalid JSON")
event_type = payload.get("type", "")
@@ -446,23 +483,22 @@ async def generic_webhook(token: str, request: Request):
store_payloads = provider_config.get("store_payloads", True)
max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100)
- raw_body = await request.body()
+ raw_body = await _read_bounded_body(request)
- # Enforce payload size limit BEFORE parsing JSON
- if len(raw_body) > 1_000_000:
- raise HTTPException(status_code=413, detail="Payload too large (max 1 MB)")
+ # Bounded read above already enforces the size cap; no need to re-check.
if not _verify_generic_webhook_auth(provider_config, request, raw_body):
raise HTTPException(status_code=403, detail="Authentication failed")
safe_headers = _filter_headers(dict(request.headers))
- # Parse JSON payload
+ # Parse JSON payload from the already-bounded raw_body (request.body()
+ # has been consumed, so request.json() is no longer usable here).
try:
- payload = await request.json()
+ payload = json.loads(raw_body.decode("utf-8"))
if not isinstance(payload, dict):
raise ValueError("Payload must be a JSON object")
- except (json.JSONDecodeError, ValueError):
+ except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
if store_payloads:
async with AsyncSession(get_engine()) as log_session:
await _save_webhook_log(
diff --git a/packages/server/src/notify_bridge_server/auth/jwt.py b/packages/server/src/notify_bridge_server/auth/jwt.py
index a57f5a5..f35bdf7 100644
--- a/packages/server/src/notify_bridge_server/auth/jwt.py
+++ b/packages/server/src/notify_bridge_server/auth/jwt.py
@@ -7,30 +7,51 @@ import jwt
from ..config import settings
ALGORITHM = "HS256"
+_LEEWAY_SECONDS = 10
+
+
+def _now_utc() -> datetime:
+ return datetime.now(timezone.utc)
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
- expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
+ now = _now_utc()
+ expire = now + timedelta(minutes=settings.access_token_expire_minutes)
payload = {
+ "iss": settings.jwt_issuer,
+ "aud": settings.jwt_audience,
"sub": str(user_id),
"role": role,
"type": "access",
"ver": token_version,
+ "iat": now,
"exp": expire,
}
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
- expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)
+ now = _now_utc()
+ expire = now + timedelta(days=settings.refresh_token_expire_days)
payload = {
+ "iss": settings.jwt_issuer,
+ "aud": settings.jwt_audience,
"sub": str(user_id),
"type": "refresh",
"ver": token_version,
+ "iat": now,
"exp": expire,
}
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
def decode_token(token: str) -> dict:
- return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
+ return jwt.decode(
+ token,
+ settings.secret_key,
+ algorithms=[ALGORITHM],
+ audience=settings.jwt_audience,
+ issuer=settings.jwt_issuer,
+ leeway=_LEEWAY_SECONDS,
+ options={"require": ["exp", "sub", "iss", "aud", "type"]},
+ )
diff --git a/packages/server/src/notify_bridge_server/auth/routes.py b/packages/server/src/notify_bridge_server/auth/routes.py
index ab69f8c..e7e9f3a 100644
--- a/packages/server/src/notify_bridge_server/auth/routes.py
+++ b/packages/server/src/notify_bridge_server/auth/routes.py
@@ -1,5 +1,7 @@
"""Authentication API routes."""
+import asyncio
+
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from slowapi import Limiter
@@ -16,7 +18,9 @@ from .jwt import create_access_token, create_refresh_token, decode_token
router = APIRouter(prefix="/api/auth", tags=["auth"])
-limiter = Limiter(key_func=get_remote_address)
+# Default rate limit applied by SlowAPIMiddleware to every route that does NOT
+# specify its own @limiter.limit(...) — protects against blanket abuse.
+limiter = Limiter(key_func=get_remote_address, default_limits=["600/minute"])
class SetupRequest(BaseModel):
@@ -45,27 +49,52 @@ class RefreshRequest(BaseModel):
refresh_token: str
-def _hash_password(password: str) -> str:
- return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
+async def _hash_password(password: str) -> str:
+ """bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop."""
+
+ def _work() -> str:
+ return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
+
+ return await asyncio.to_thread(_work)
-def _verify_password(password: str, hashed: str) -> bool:
- return bcrypt.checkpw(password.encode(), hashed.encode())
+async def _verify_password(password: str, hashed: str) -> bool:
+ def _work() -> bool:
+ try:
+ return bcrypt.checkpw(password.encode(), hashed.encode())
+ except ValueError:
+ # Malformed hash in DB — treat as mismatch, never raise to caller.
+ return False
+
+ return await asyncio.to_thread(_work)
@router.post("/setup", response_model=TokenResponse)
@limiter.limit("3/minute")
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
- result = await session.exec(select(func.count()).select_from(User))
- count = result.one()
- if count > 0:
- raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.")
-
if len(body.password) < 8:
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
- user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin")
- session.add(user)
- await session.commit()
+ # Compute hash BEFORE opening the transaction so we don't hold a writer lock
+ # during the CPU-bound bcrypt work.
+ hashed = await _hash_password(body.password)
+
+ # Serialize setup via an INSERT-inside-transaction-with-count-guard.
+ # SQLite's writer lock plus the count check inside the transaction closes
+ # the TOCTOU window between two concurrent POSTs. We ignore id=0 — that's
+ # the internal "__system__" placeholder used for ownership of default
+ # templates, never a real admin.
+ async with session.begin():
+ result = await session.exec(
+ select(func.count()).select_from(User).where(User.id != 0)
+ )
+ count = result.one()
+ if count > 0:
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail="Setup already completed.",
+ )
+ user = User(username=body.username, hashed_password=hashed, role="admin")
+ session.add(user)
await session.refresh(user)
return TokenResponse(
@@ -79,7 +108,13 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
result = await session.exec(select(User).where(User.username == body.username))
user = result.first()
- if not user or not _verify_password(body.password, user.hashed_password):
+ # Always run a bcrypt verification to keep the response time constant,
+ # preventing username-enumeration via timing side channel.
+ password_ok = await _verify_password(
+ body.password,
+ user.hashed_password if user else "$2b$12$" + "a" * 53,
+ )
+ if not user or not password_ok:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
return TokenResponse(
@@ -124,16 +159,18 @@ class PasswordChangeRequest(BaseModel):
@router.put("/password")
+@limiter.limit("10/minute")
async def change_password(
+ request: Request,
body: PasswordChangeRequest,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
- if not _verify_password(body.current_password, user.hashed_password):
+ if not await _verify_password(body.current_password, user.hashed_password):
raise HTTPException(status_code=400, detail="Current password is incorrect")
if len(body.new_password) < 8:
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
- user.hashed_password = _hash_password(body.new_password)
+ user.hashed_password = await _hash_password(body.new_password)
user.token_version = (user.token_version or 1) + 1
session.add(user)
await session.commit()
@@ -141,7 +178,12 @@ async def change_password(
@router.get("/needs-setup")
-async def needs_setup(session: AsyncSession = Depends(get_session)):
- result = await session.exec(select(func.count()).select_from(User))
+@limiter.limit("30/minute")
+async def needs_setup(request: Request, session: AsyncSession = Depends(get_session)):
+ # Exclude the internal __system__ placeholder (id=0) from the count so
+ # a fresh install still reports needs_setup=True.
+ result = await session.exec(
+ select(func.count()).select_from(User).where(User.id != 0)
+ )
count = result.one()
return {"needs_setup": count == 0}
diff --git a/packages/server/src/notify_bridge_server/config.py b/packages/server/src/notify_bridge_server/config.py
index e27c94e..128aebe 100644
--- a/packages/server/src/notify_bridge_server/config.py
+++ b/packages/server/src/notify_bridge_server/config.py
@@ -2,8 +2,20 @@
from pathlib import Path
from typing import Any
+from urllib.parse import urlparse
+
from pydantic_settings import BaseSettings
+# Secret keys we will actively refuse. These cover the default template value
+# and dev-only literals that have appeared in scripts or documentation.
+_FORBIDDEN_SECRETS: frozenset[str] = frozenset(
+ {
+ "change-me-in-production",
+ "test-secret-key-minimum-32-chars",
+ "dev-secret-key-not-for-production",
+ }
+)
+
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
@@ -13,29 +25,25 @@ class Settings(BaseSettings):
secret_key: str = "change-me-in-production"
- def model_post_init(self, __context: Any) -> None:
- if self.secret_key == "change-me-in-production":
- raise ValueError(
- "SECURITY: Refusing to start with the default secret_key. "
- "Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
- "before starting the server (debug mode included)."
- )
- if len(self.secret_key) < 32:
- raise ValueError(
- "SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
- )
- if "*" in self.cors_allowed_origins.split(","):
- raise ValueError(
- "SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
- )
-
- access_token_expire_minutes: int = 60
+ access_token_expire_minutes: int = 15
refresh_token_expire_days: int = 30
+ jwt_issuer: str = "notify-bridge"
+ jwt_audience: str = "notify-bridge-api"
+
host: str = "0.0.0.0"
port: int = 8420
debug: bool = False
+ # Comma-separated list of trusted proxy IPs uvicorn will honor for
+ # X-Forwarded-For / X-Forwarded-Proto. Use "*" ONLY when you trust the
+ # network (never directly on the internet). Default matches uvicorn.
+ forwarded_allow_ips: str = "127.0.0.1"
+
+ # How long to wait for in-flight requests / scheduler jobs before force
+ # killing on SIGTERM.
+ graceful_shutdown_seconds: int = 60
+
anthropic_api_key: str = ""
ai_model: str = "claude-sonnet-4-20250514"
ai_max_tokens: int = 1024
@@ -49,9 +57,6 @@ class Settings(BaseSettings):
"""Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker)."""
# --- Logging ---
- # Boot-time logging configuration. DB AppSetting rows (``log_level`` /
- # ``log_levels`` / ``log_format``) override these after startup, letting
- # operators adjust levels from the settings UI without a restart.
log_level: str = "INFO"
"""Root log level for the app loggers (``DEBUG``/``INFO``/``WARNING``/``ERROR``)."""
@@ -61,8 +66,43 @@ class Settings(BaseSettings):
log_levels: str = ""
"""Comma-separated per-module overrides, e.g. ``notify_bridge_core.notifications.telegram.client=DEBUG,sqlalchemy.engine=INFO``."""
+ # --- Retention ---
+ event_log_retention_days: int = 30
+ """Days of event_log history to retain. 0 disables the retention job."""
+
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
+ def model_post_init(self, __context: Any) -> None:
+ if self.secret_key in _FORBIDDEN_SECRETS:
+ raise ValueError(
+ "SECURITY: Refusing to start with a known/default secret_key. "
+ "Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
+ "before starting the server."
+ )
+ if len(self.secret_key) < 32:
+ raise ValueError(
+ "SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
+ )
+ origins = [o.strip() for o in self.cors_allowed_origins.split(",") if o.strip()]
+ if "*" in origins:
+ raise ValueError(
+ "SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
+ )
+ for origin in origins:
+ parsed = urlparse(origin)
+ if parsed.scheme not in {"http", "https"} or not parsed.netloc:
+ raise ValueError(
+ f"CORS origin {origin!r} is invalid — must include scheme (http/https) and host."
+ )
+ if self.access_token_expire_minutes <= 0:
+ raise ValueError("access_token_expire_minutes must be > 0")
+ if self.refresh_token_expire_days <= 0:
+ raise ValueError("refresh_token_expire_days must be > 0")
+ if not (1 <= self.port <= 65535):
+ raise ValueError("port must be in range 1..65535")
+ if self.event_log_retention_days < 0:
+ raise ValueError("event_log_retention_days must be >= 0")
+
@property
def effective_database_url(self) -> str:
if self.database_url:
diff --git a/packages/server/src/notify_bridge_server/database/engine.py b/packages/server/src/notify_bridge_server/database/engine.py
index 2712f99..d6ba702 100644
--- a/packages/server/src/notify_bridge_server/database/engine.py
+++ b/packages/server/src/notify_bridge_server/database/engine.py
@@ -1,23 +1,59 @@
"""Database engine and session management."""
from collections.abc import AsyncGenerator
+import logging
+from sqlalchemy import event
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from ..config import settings
+_LOGGER = logging.getLogger(__name__)
+
_engine: AsyncEngine | None = None
+def _install_sqlite_pragmas(engine: AsyncEngine) -> None:
+ """Apply production-grade SQLite PRAGMAs on every new connection.
+
+ WAL mode lets readers and writers work concurrently without blocking;
+ ``busy_timeout`` gives contending writers a chance instead of instant
+ SQLITE_BUSY; ``foreign_keys`` enforces the FK constraints declared in the
+ models (SQLite disables them by default); ``synchronous=NORMAL`` is a
+ safe-by-default durability trade-off that is standard in WAL mode.
+ """
+
+ @event.listens_for(engine.sync_engine, "connect")
+ def _pragmas(dbapi_conn, _record): # pragma: no cover — driver hook
+ cur = dbapi_conn.cursor()
+ try:
+ cur.execute("PRAGMA journal_mode=WAL")
+ cur.execute("PRAGMA synchronous=NORMAL")
+ cur.execute("PRAGMA foreign_keys=ON")
+ cur.execute("PRAGMA busy_timeout=10000")
+ cur.execute("PRAGMA temp_store=MEMORY")
+ finally:
+ cur.close()
+
+
def get_engine() -> AsyncEngine:
global _engine
if _engine is None:
+ url = settings.effective_database_url
+ connect_args: dict = {}
+ if url.startswith("sqlite"):
+ connect_args["timeout"] = 30
_engine = create_async_engine(
- settings.effective_database_url,
+ url,
echo=settings.debug,
+ pool_pre_ping=True,
+ connect_args=connect_args,
)
+ if url.startswith("sqlite"):
+ _install_sqlite_pragmas(_engine)
+ _LOGGER.info("Database engine initialized: %s", url.split("://", 1)[0])
return _engine
@@ -31,3 +67,11 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
engine = get_engine()
async with AsyncSession(engine) as session:
yield session
+
+
+async def dispose_engine() -> None:
+ """Close the engine's connection pool. Call during graceful shutdown."""
+ global _engine
+ if _engine is not None:
+ await _engine.dispose()
+ _engine = None
diff --git a/packages/server/src/notify_bridge_server/database/migrations.py b/packages/server/src/notify_bridge_server/database/migrations.py
index 2a58b36..84dbc50 100644
--- a/packages/server/src/notify_bridge_server/database/migrations.py
+++ b/packages/server/src/notify_bridge_server/database/migrations.py
@@ -1282,3 +1282,141 @@ async def migrate_user_token_version(engine: AsyncEngine) -> None:
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
)
logger.info("Added token_version column to user table")
+
+
+# ---------------------------------------------------------------------------
+# Performance indexes — covers every FK / owner column the list endpoints
+# and the webhook hot-path filter on. All use CREATE INDEX IF NOT EXISTS so
+# they are safe to re-run on every boot.
+# ---------------------------------------------------------------------------
+
+_INDEXES: list[tuple[str, str, str]] = [
+ # (index_name, table, columns)
+ ("ix_service_provider_user_id", "service_provider", "user_id"),
+ ("ix_telegram_bot_user_id", "telegram_bot", "user_id"),
+ ("ix_matrix_bot_user_id", "matrix_bot", "user_id"),
+ ("ix_email_bot_user_id", "email_bot", "user_id"),
+ ("ix_telegram_chat_bot_id", "telegram_chat", "bot_id"),
+ ("ix_tracking_config_user_id", "tracking_config", "user_id"),
+ ("ix_tracking_config_provider_type", "tracking_config", "provider_type"),
+ ("ix_notification_target_user_id", "notification_target", "user_id"),
+ ("ix_notification_target_type", "notification_target", "type"),
+ ("ix_notification_tracker_user_id", "notification_tracker", "user_id"),
+ ("ix_notification_tracker_provider_id", "notification_tracker", "provider_id"),
+ # Composite for the webhook hot path: WHERE provider_id = ? AND enabled = true
+ (
+ "ix_notification_tracker_provider_enabled",
+ "notification_tracker",
+ "provider_id, enabled",
+ ),
+ ("ix_command_config_user_id", "command_config", "user_id"),
+ ("ix_command_template_config_user_id", "command_template_config", "user_id"),
+ ("ix_command_tracker_user_id", "command_tracker", "user_id"),
+ ("ix_command_tracker_provider_id", "command_tracker", "provider_id"),
+ ("ix_action_user_id", "action", "user_id"),
+ ("ix_action_provider_id", "action", "provider_id"),
+ # Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
+ ("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
+ ("ix_event_log_provider_id", "event_log", "provider_id"),
+ ("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
+ ("ix_event_log_action_id", "event_log", "action_id"),
+ # Webhook log hot path: WHERE provider_id = ? ORDER BY created_at DESC
+ (
+ "ix_webhook_payload_log_provider_created",
+ "webhook_payload_log",
+ "provider_id, created_at DESC",
+ ),
+ # Notification tracker join tables
+ (
+ "ix_notification_tracker_target_notification_tracker_id",
+ "notification_tracker_target",
+ "notification_tracker_id",
+ ),
+ (
+ "ix_notification_tracker_target_target_id",
+ "notification_tracker_target",
+ "target_id",
+ ),
+ ("ix_target_receiver_target_id", "target_receiver", "target_id"),
+ ("ix_template_slot_config_id", "template_slot", "config_id"),
+ ("ix_command_template_slot_config_id", "command_template_slot", "config_id"),
+ ("ix_action_rule_action_id", "action_rule", "action_id"),
+ ("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"),
+]
+
+
+async def migrate_performance_indexes(engine: AsyncEngine) -> None:
+ """Create missing performance indexes on hot query paths.
+
+ Every index is created with IF NOT EXISTS so the migration is safe to
+ replay on every boot. We only create the index when the table exists —
+ early boots before other migrations land would otherwise raise.
+ """
+ async with engine.begin() as conn:
+ for name, table, columns in _INDEXES:
+ _assert_ident(name, "index")
+ _assert_ident(table, "table")
+ # Columns list is a trusted literal constructed above — never user input.
+ if not await _has_table(conn, table):
+ continue
+ try:
+ await conn.execute(
+ text(f"CREATE INDEX IF NOT EXISTS {name} ON {table} ({columns})")
+ )
+ except Exception: # pragma: no cover — log and continue
+ logger.warning(
+ "Failed to create index %s on %s(%s)",
+ name, table, columns, exc_info=True,
+ )
+
+
+# ---------------------------------------------------------------------------
+# Schema version tracking — lightweight alternative to Alembic while the
+# hand-rolled idempotent migrations remain the source of truth. Gives
+# operators a single-row answer to "what schema is this DB at" and lets
+# future upgrades short-circuit migrations that already ran.
+# ---------------------------------------------------------------------------
+
+CURRENT_SCHEMA_VERSION = 1
+
+
+async def migrate_schema_version(engine: AsyncEngine) -> None:
+ """Create schema_version table and bump it to CURRENT_SCHEMA_VERSION."""
+ async with engine.begin() as conn:
+ await conn.execute(
+ text(
+ "CREATE TABLE IF NOT EXISTS schema_version ("
+ " id INTEGER PRIMARY KEY CHECK (id = 1),"
+ " version INTEGER NOT NULL,"
+ " applied_at TEXT NOT NULL"
+ ")"
+ )
+ )
+ row = await conn.run_sync(
+ lambda sc: sc.execute(
+ text("SELECT version FROM schema_version WHERE id = 1")
+ ).fetchone()
+ )
+ from datetime import datetime, timezone
+ now = datetime.now(timezone.utc).isoformat()
+ if row is None:
+ await conn.execute(
+ text(
+ "INSERT INTO schema_version (id, version, applied_at) "
+ "VALUES (1, :v, :t)"
+ ),
+ {"v": CURRENT_SCHEMA_VERSION, "t": now},
+ )
+ logger.info("Initialized schema_version at %d", CURRENT_SCHEMA_VERSION)
+ elif int(row[0]) < CURRENT_SCHEMA_VERSION:
+ await conn.execute(
+ text(
+ "UPDATE schema_version SET version = :v, applied_at = :t "
+ "WHERE id = 1"
+ ),
+ {"v": CURRENT_SCHEMA_VERSION, "t": now},
+ )
+ logger.info(
+ "Bumped schema_version from %s to %d",
+ row[0], CURRENT_SCHEMA_VERSION,
+ )
diff --git a/packages/server/src/notify_bridge_server/database/seeds.py b/packages/server/src/notify_bridge_server/database/seeds.py
index c81d67a..683293a 100644
--- a/packages/server/src/notify_bridge_server/database/seeds.py
+++ b/packages/server/src/notify_bridge_server/database/seeds.py
@@ -394,8 +394,37 @@ async def _seed_default_command_configs() -> None:
# Public entry point
# ---------------------------------------------------------------------------
+async def _ensure_system_user() -> None:
+ """Ensure a User row with id=0 exists.
+
+ Historically the app used ``user_id=0`` as a sentinel for "system-owned"
+ defaults (tracking configs, templates, etc.). Now that we enable
+ ``PRAGMA foreign_keys=ON`` at connect time, those inserts would fail
+ with ``FOREIGN KEY constraint failed`` unless a placeholder user row
+ with the matching id exists.
+ """
+ engine = get_engine()
+ async with engine.begin() as conn:
+ # INSERT OR IGNORE so re-running seeds is cheap and idempotent.
+ await conn.execute(
+ text(
+ "INSERT OR IGNORE INTO user "
+ "(id, username, hashed_password, role, token_version, created_at) "
+ "VALUES (0, :u, :p, :r, 1, :t)"
+ ),
+ {
+ "u": "__system__",
+ # Invalid bcrypt hash — nobody can ever log in as this user.
+ "p": "!disabled!",
+ "r": "system",
+ "t": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+
async def seed_all() -> None:
"""Run all seed functions in order."""
+ await _ensure_system_user()
await _seed_default_templates()
await _seed_default_command_templates()
await _seed_default_tracking_configs()
diff --git a/packages/server/src/notify_bridge_server/main.py b/packages/server/src/notify_bridge_server/main.py
index d73b9ea..4726b4a 100644
--- a/packages/server/src/notify_bridge_server/main.py
+++ b/packages/server/src/notify_bridge_server/main.py
@@ -52,12 +52,31 @@ from .api.webhook_logs import router as webhook_logs_router
from .api.backup import router as backup_router
+# Readiness flag — flipped to True once the scheduler has started and the
+# app is fully initialized. Exposed via /api/ready for orchestrators.
+_READY: bool = False
+
+
@asynccontextmanager
async def lifespan(app: FastAPI):
+ global _READY
await init_db()
# Run data migrations (idempotent)
from .database.engine import get_engine
- from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
+ from .database.migrations import (
+ migrate_schema,
+ migrate_tracker_targets,
+ migrate_entity_refactor,
+ migrate_template_slots,
+ migrate_target_receivers,
+ migrate_template_locale,
+ migrate_receivers_from_config,
+ migrate_command_slot_locale,
+ migrate_notification_slot_locale,
+ migrate_user_token_version,
+ migrate_performance_indexes,
+ migrate_schema_version,
+ )
engine = get_engine()
await migrate_schema(engine)
await migrate_tracker_targets(engine)
@@ -69,6 +88,8 @@ async def lifespan(app: FastAPI):
await migrate_command_slot_locale(engine)
await migrate_notification_slot_locale(engine)
await migrate_user_token_version(engine)
+ await migrate_performance_indexes(engine)
+ await migrate_schema_version(engine)
from .database.seeds import seed_all
await seed_all()
# Apply DB-backed logging settings (override env-based boot config).
@@ -100,16 +121,28 @@ async def lifespan(app: FastAPI):
set_webhook_secret(_secret or None)
from .services.scheduler import start_scheduler, get_scheduler
await start_scheduler()
+ _READY = True
yield
- # Graceful shutdown
- from .services.http_session import close_http_session
- await close_http_session()
+ # Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
+ # before we close their HTTP session. Then close the shared session and
+ # dispose the DB engine.
+ _READY = False
scheduler = get_scheduler()
if scheduler.running:
- scheduler.shutdown()
+ scheduler.shutdown(wait=True)
+ from .services.http_session import close_http_session
+ await close_http_session()
+ from .database.engine import dispose_engine
+ await dispose_engine()
-app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
+try:
+ from importlib.metadata import version as _pkg_version
+ _APP_VERSION = _pkg_version("notify-bridge-server")
+except Exception: # pragma: no cover — editable install edge cases
+ _APP_VERSION = "0.0.0+unknown"
+
+app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan)
# --- Security headers ---
from starlette.middleware.base import BaseHTTPMiddleware
@@ -117,6 +150,19 @@ from starlette.requests import Request as StarletteRequest
from starlette.responses import Response as StarletteResponse
+_CSP = (
+ "default-src 'self'; "
+ "img-src 'self' data: blob: https:; "
+ "style-src 'self' 'unsafe-inline'; "
+ "script-src 'self'; "
+ "connect-src 'self'; "
+ "font-src 'self' data:; "
+ "base-uri 'self'; "
+ "form-action 'self'; "
+ "frame-ancestors 'none'"
+)
+
+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: StarletteRequest, call_next):
response: StarletteResponse = await call_next(request)
@@ -124,6 +170,14 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
+ response.headers.setdefault("Content-Security-Policy", _CSP)
+ # HSTS only makes sense over HTTPS; set when the edge terminates TLS
+ # and forwards X-Forwarded-Proto=https.
+ if request.headers.get("x-forwarded-proto") == "https":
+ response.headers.setdefault(
+ "Strict-Transport-Security",
+ "max-age=31536000; includeSubDomains",
+ )
return response
@@ -176,7 +230,22 @@ app.include_router(backup_router)
@app.get("/api/health")
async def health():
- return {"status": "ok"}
+ """Liveness: process is up and responding. Always returns 200 once the
+ ASGI app has started. Keep this endpoint anonymous and trivially cheap."""
+ return {"status": "ok", "version": _APP_VERSION}
+
+
+@app.get("/api/ready")
+async def ready():
+ """Readiness: migrations and scheduler have started, app can serve traffic.
+
+ Returns 503 until the lifespan startup sequence has completed. Use this
+ for orchestrator readiness probes (Docker, Kubernetes).
+ """
+ if not _READY:
+ from starlette.responses import JSONResponse
+ return JSONResponse({"status": "starting"}, status_code=503)
+ return {"status": "ready", "version": _APP_VERSION}
# --- Serve frontend static files (production) ---
@@ -209,4 +278,12 @@ if _cfg.static_dir and Path(_cfg.static_dir).is_dir():
def run():
import uvicorn
- uvicorn.run(app, host=_cfg.host, port=_cfg.port)
+ uvicorn.run(
+ app,
+ host=_cfg.host,
+ port=_cfg.port,
+ proxy_headers=True,
+ forwarded_allow_ips=_cfg.forwarded_allow_ips or "127.0.0.1",
+ timeout_graceful_shutdown=_cfg.graceful_shutdown_seconds,
+ access_log=not _cfg.debug,
+ )
diff --git a/packages/server/src/notify_bridge_server/services/backup_service.py b/packages/server/src/notify_bridge_server/services/backup_service.py
index 25d7e8e..0292b57 100644
--- a/packages/server/src/notify_bridge_server/services/backup_service.py
+++ b/packages/server/src/notify_bridge_server/services/backup_service.py
@@ -387,10 +387,9 @@ async def export_backup_to_file(
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
filename = f"backup-{ts}.json"
filepath = backup_dir / filename
- filepath.write_text(
- json.dumps(backup.model_dump(), indent=2, ensure_ascii=False),
- encoding="utf-8",
- )
+ import asyncio as _asyncio
+ payload = json.dumps(backup.model_dump(), indent=2, ensure_ascii=False)
+ await _asyncio.to_thread(filepath.write_text, payload, encoding="utf-8")
_LOGGER.info("Scheduled backup saved: %s", filepath)
return filepath
@@ -399,7 +398,13 @@ def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
"""Delete oldest backup files exceeding `keep` count. Returns deleted filenames."""
if not backup_dir.is_dir():
return []
- files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
+ # Sort by mtime (newest first) so behavior doesn't depend on the filename
+ # timestamp format, which could change later without updating this code.
+ files = sorted(
+ backup_dir.glob("backup-*.json"),
+ key=lambda f: f.stat().st_mtime,
+ reverse=True,
+ )
deleted = []
for old in files[keep:]:
old.unlink()
@@ -413,7 +418,13 @@ def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
"""List backup files in the directory with metadata."""
if not backup_dir.is_dir():
return []
- files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
+ # Sort by mtime (newest first) so behavior doesn't depend on the filename
+ # timestamp format, which could change later without updating this code.
+ files = sorted(
+ backup_dir.glob("backup-*.json"),
+ key=lambda f: f.stat().st_mtime,
+ reverse=True,
+ )
result = []
for f in files:
stat = f.stat()
diff --git a/packages/server/src/notify_bridge_server/services/http_session.py b/packages/server/src/notify_bridge_server/services/http_session.py
index bf0021d..873f317 100644
--- a/packages/server/src/notify_bridge_server/services/http_session.py
+++ b/packages/server/src/notify_bridge_server/services/http_session.py
@@ -11,23 +11,36 @@ Call ``close_http_session()`` once during application shutdown.
from __future__ import annotations
+import asyncio
+
import aiohttp
-_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
+_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
+
_session: aiohttp.ClientSession | None = None
+_lock = asyncio.Lock()
async def get_http_session() -> aiohttp.ClientSession:
- """Get or create the shared HTTP session."""
+ """Get or create the shared HTTP session.
+
+ Concurrent first-callers are serialized through ``_lock`` so we never
+ leak a second ClientSession / connector pair. Once established, hot
+ callers skip the lock via the fast-path check.
+ """
global _session
- if _session is None or _session.closed:
- _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
+ if _session is not None and not _session.closed:
+ return _session
+ async with _lock:
+ if _session is None or _session.closed:
+ _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
return _session
async def close_http_session() -> None:
"""Close the shared HTTP session (call on app shutdown)."""
global _session
- if _session is not None and not _session.closed:
- await _session.close()
+ async with _lock:
+ if _session is not None and not _session.closed:
+ await _session.close()
_session = None
diff --git a/packages/server/src/notify_bridge_server/services/test_dispatch.py b/packages/server/src/notify_bridge_server/services/manual_dispatch.py
similarity index 100%
rename from packages/server/src/notify_bridge_server/services/test_dispatch.py
rename to packages/server/src/notify_bridge_server/services/manual_dispatch.py
diff --git a/packages/server/src/notify_bridge_server/services/scheduler.py b/packages/server/src/notify_bridge_server/services/scheduler.py
index 3336f7b..9402bf4 100644
--- a/packages/server/src/notify_bridge_server/services/scheduler.py
+++ b/packages/server/src/notify_bridge_server/services/scheduler.py
@@ -85,7 +85,21 @@ def _compute_jitter(interval_seconds: int) -> int:
def get_scheduler() -> AsyncIOScheduler:
global _scheduler
if _scheduler is None:
- _scheduler = AsyncIOScheduler()
+ # Sensible production defaults applied to every job unless overridden:
+ # * coalesce — collapse a queue of missed runs into one firing after
+ # a restart / pause, instead of bursting to catch up.
+ # * misfire_grace_time — accept firings up to 5 min late without
+ # dropping them silently.
+ # * max_instances=1 — never run two copies of the same tracker tick
+ # concurrently; the scheduler already enforces this on add_job,
+ # but we also set it as the default for safety.
+ _scheduler = AsyncIOScheduler(
+ job_defaults={
+ "coalesce": True,
+ "misfire_grace_time": 300,
+ "max_instances": 1,
+ },
+ )
return _scheduler
@@ -279,21 +293,38 @@ async def _refresh_telegram_chat_titles() -> None:
async def _cleanup_old_events() -> None:
- """Delete EventLog entries older than 90 days."""
+ """Delete EventLog / WebhookPayloadLog / ActionExecution rows older than the
+ configured retention window. A retention of 0 disables the job.
+ """
from datetime import datetime, timedelta, timezone
from sqlmodel import delete
from sqlmodel.ext.asyncio.session import AsyncSession
+ from ..config import settings
from ..database.engine import get_engine
- from ..database.models import EventLog
+ from ..database.models import ActionExecution, EventLog, WebhookPayloadLog
- cutoff = datetime.now(timezone.utc) - timedelta(days=90)
+ days = settings.event_log_retention_days
+ if days <= 0:
+ _LOGGER.debug("Event log retention disabled (days=0); skipping cleanup")
+ return
+
+ cutoff = datetime.now(timezone.utc) - timedelta(days=days)
engine = get_engine()
async with AsyncSession(engine) as session:
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
+ await session.exec(
+ delete(WebhookPayloadLog).where(WebhookPayloadLog.created_at < cutoff)
+ )
+ await session.exec(
+ delete(ActionExecution).where(ActionExecution.started_at < cutoff)
+ )
await session.commit()
- _LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
+ _LOGGER.info(
+ "Cleaned event_log / webhook_payload_log / action_execution older than %s",
+ cutoff.date(),
+ )
async def _load_tracker_jobs() -> None:
diff --git a/packages/server/src/notify_bridge_server/services/telegram_poller.py b/packages/server/src/notify_bridge_server/services/telegram_poller.py
index 7e7e6e7..261773a 100644
--- a/packages/server/src/notify_bridge_server/services/telegram_poller.py
+++ b/packages/server/src/notify_bridge_server/services/telegram_poller.py
@@ -127,7 +127,14 @@ async def stop_bot_if_unused(bot_id: int) -> None:
def schedule_bot_polling(bot_id: int) -> None:
- """Add a polling job for a bot (idempotent)."""
+ """Add a polling job for a bot (idempotent).
+
+ We schedule at a 30 s interval, but each tick calls ``getUpdates`` with
+ ``timeout=25`` — Telegram holds the connection open until either an
+ update arrives or the timeout elapses, so in practice the bot streams
+ updates with sub-second latency while consuming ~2 API calls / minute
+ per bot (down from 20 under the old 3 s short-poll).
+ """
scheduler = get_scheduler()
job_id = f"telegram_poll_{bot_id}"
if scheduler.get_job(job_id):
@@ -135,13 +142,13 @@ def schedule_bot_polling(bot_id: int) -> None:
scheduler.add_job(
_poll_bot,
"interval",
- seconds=3,
+ seconds=30,
id=job_id,
args=[bot_id],
replace_existing=True,
max_instances=1,
)
- _LOGGER.info("Started polling for bot %d", bot_id)
+ _LOGGER.info("Started polling for bot %d (long-poll, 25s timeout)", bot_id)
def unschedule_bot_polling(bot_id: int) -> None:
@@ -233,8 +240,10 @@ async def _poll_bot(bot_id: int) -> None:
from .http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
+ # Long-poll: hold connection open until an update arrives or 25 s
+ # elapse. Drastically cuts API calls vs. 3 s short-poll.
result = await client.get_updates(
- offset=offset + 1 if offset else None, limit=50,
+ offset=offset + 1 if offset else None, limit=50, timeout=25,
)
if not result.get("success"):
err_text = str(result.get("error") or "")
diff --git a/packages/server/src/notify_bridge_server/services/watcher.py b/packages/server/src/notify_bridge_server/services/watcher.py
index a3f3c33..f00dcc7 100644
--- a/packages/server/src/notify_bridge_server/services/watcher.py
+++ b/packages/server/src/notify_bridge_server/services/watcher.py
@@ -369,7 +369,13 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
if events and link_data:
url_cache, asset_cache = await _get_telegram_caches()
- dispatcher = NotificationDispatcher(url_cache=url_cache, asset_cache=asset_cache)
+ from .http_session import get_http_session
+ shared_session = await get_http_session()
+ dispatcher = NotificationDispatcher(
+ url_cache=url_cache,
+ asset_cache=asset_cache,
+ session=shared_session,
+ )
for event in events:
_LOGGER.info(
"Dispatching event %s for %s (added=%d removed=%d)",
diff --git a/packages/server/tests/__init__.py b/packages/server/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/packages/server/tests/conftest.py b/packages/server/tests/conftest.py
new file mode 100644
index 0000000..56b1ace
--- /dev/null
+++ b/packages/server/tests/conftest.py
@@ -0,0 +1,33 @@
+"""Shared pytest fixtures.
+
+We set the required env vars before any ``notify_bridge_server`` module is
+imported so ``Settings()`` passes its startup validation and opens the DB
+in a writable temp directory instead of the production ``/data`` default.
+"""
+
+from __future__ import annotations
+
+import os
+import tempfile
+from pathlib import Path
+
+import pytest
+
+# Provision a writable temp data dir BEFORE the server package is imported —
+# Settings() materializes at import time, so env-var overrides have to land
+# here (conftest.py) to be effective.
+_TMP = Path(tempfile.mkdtemp(prefix="notify-bridge-tests-"))
+os.environ["NOTIFY_BRIDGE_DATA_DIR"] = str(_TMP)
+
+os.environ.setdefault(
+ "NOTIFY_BRIDGE_SECRET_KEY",
+ "pytest-secret-key-" + "x" * 40,
+)
+os.environ.setdefault("NOTIFY_BRIDGE_DEBUG", "false")
+os.environ.setdefault("NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS", "http://localhost:8420")
+
+
+@pytest.fixture(scope="session")
+def tmp_data_dir() -> Path:
+ """Expose the already-provisioned temp dir to tests that want the path."""
+ return _TMP
diff --git a/packages/server/tests/test_config.py b/packages/server/tests/test_config.py
new file mode 100644
index 0000000..0e8cac5
--- /dev/null
+++ b/packages/server/tests/test_config.py
@@ -0,0 +1,64 @@
+"""Unit tests for the Settings validator."""
+
+from __future__ import annotations
+
+import pytest
+
+from notify_bridge_server.config import Settings, _FORBIDDEN_SECRETS
+
+
+_GOOD = "a" * 40
+
+
+def _make(**kwargs) -> Settings:
+ defaults = dict(secret_key=_GOOD, cors_allowed_origins="http://localhost:8420")
+ defaults.update(kwargs)
+ return Settings(**defaults)
+
+
+class TestSecretKey:
+ def test_accepts_long_random_key(self) -> None:
+ _make()
+
+ def test_rejects_default(self) -> None:
+ with pytest.raises(ValueError, match="SECURITY"):
+ _make(secret_key="change-me-in-production")
+
+ def test_rejects_known_dev_keys(self) -> None:
+ for bad in _FORBIDDEN_SECRETS:
+ with pytest.raises(ValueError):
+ _make(secret_key=bad)
+
+ def test_rejects_short_key(self) -> None:
+ with pytest.raises(ValueError, match="32 characters"):
+ _make(secret_key="short")
+
+
+class TestCors:
+ def test_rejects_wildcard(self) -> None:
+ with pytest.raises(ValueError, match="wildcard"):
+ _make(cors_allowed_origins="*")
+
+ def test_rejects_missing_scheme(self) -> None:
+ with pytest.raises(ValueError, match="scheme"):
+ _make(cors_allowed_origins="example.com")
+
+ def test_accepts_multiple(self) -> None:
+ cfg = _make(cors_allowed_origins="http://localhost:8420,https://example.com")
+ assert "http://localhost:8420" in cfg.cors_allowed_origins
+
+
+class TestNumericValidation:
+ def test_rejects_zero_access_token_expiry(self) -> None:
+ with pytest.raises(ValueError):
+ _make(access_token_expire_minutes=0)
+
+ def test_rejects_invalid_port(self) -> None:
+ with pytest.raises(ValueError):
+ _make(port=0)
+ with pytest.raises(ValueError):
+ _make(port=70000)
+
+ def test_rejects_negative_retention(self) -> None:
+ with pytest.raises(ValueError):
+ _make(event_log_retention_days=-1)
diff --git a/packages/server/tests/test_discord_retry.py b/packages/server/tests/test_discord_retry.py
new file mode 100644
index 0000000..263f138
--- /dev/null
+++ b/packages/server/tests/test_discord_retry.py
@@ -0,0 +1,46 @@
+"""Discord client 429-retry bounding."""
+
+from __future__ import annotations
+
+import pytest
+from aioresponses import aioresponses
+
+import aiohttp
+
+from notify_bridge_core.notifications.discord.client import DiscordClient
+
+
+WEBHOOK = "https://discord.com/api/webhooks/123/abc"
+
+
+@pytest.mark.asyncio
+async def test_bounded_retries_on_persistent_429() -> None:
+ """If every response is 429, the client gives up after _MAX_RETRIES."""
+ with aioresponses() as mocked:
+ mocked.post(WEBHOOK, status=429, headers={"Retry-After": "0.001"}, repeat=True)
+
+ async with aiohttp.ClientSession() as sess:
+ client = DiscordClient(sess)
+ result = await client.send(WEBHOOK, "hello")
+
+ assert result["success"] is False
+ # Either the custom "Rate limited" message or the bare HTTP 429 from the
+ # final attempt — both indicate bounded retries without infinite recursion.
+ assert "429" in result["error"] or "Rate limited" in result["error"]
+
+
+@pytest.mark.asyncio
+async def test_caps_retry_after() -> None:
+ """A malicious Retry-After: 99999 must not pin the task for hours."""
+ with aioresponses() as mocked:
+ # First call: absurd Retry-After. Second call: success.
+ mocked.post(WEBHOOK, status=429, headers={"Retry-After": "99999"})
+ mocked.post(WEBHOOK, status=204)
+
+ async with aiohttp.ClientSession() as sess:
+ client = DiscordClient(sess)
+ # Override the cap to something trivial so the test completes fast.
+ client._MAX_RETRY_AFTER = 0.001 # type: ignore[attr-defined]
+ result = await client.send(WEBHOOK, "hello")
+
+ assert result["success"] is True
diff --git a/packages/server/tests/test_health.py b/packages/server/tests/test_health.py
new file mode 100644
index 0000000..0f05d21
--- /dev/null
+++ b/packages/server/tests/test_health.py
@@ -0,0 +1,39 @@
+"""Smoke test: app imports, /api/health returns 200, version string present."""
+
+from __future__ import annotations
+
+import pytest
+from fastapi.testclient import TestClient
+
+
+def test_health_endpoint(tmp_data_dir) -> None: # noqa: ARG001 — fixture applies env
+ from notify_bridge_server.main import app
+
+ # TestClient runs the lifespan on enter/exit, so migrations run once
+ # against the temp data dir — a genuine integration smoke check.
+ with TestClient(app) as client:
+ resp = client.get("/api/health")
+ assert resp.status_code == 200
+ body = resp.json()
+ assert body["status"] == "ok"
+ assert body["version"] != "0.0.0+unknown"
+ assert "." in body["version"] # looks like a real version
+
+
+def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001
+ from notify_bridge_server.main import app
+
+ with TestClient(app) as client:
+ resp = client.get("/api/ready")
+ # By the time TestClient yields, lifespan startup has completed.
+ assert resp.status_code == 200
+ assert resp.json()["status"] == "ready"
+
+
+def test_health_is_anonymous(tmp_data_dir) -> None: # noqa: ARG001
+ """/api/health must not require auth — the Docker healthcheck depends on it."""
+ from notify_bridge_server.main import app
+
+ with TestClient(app) as client:
+ resp = client.get("/api/health")
+ assert resp.status_code == 200
diff --git a/packages/server/tests/test_jwt.py b/packages/server/tests/test_jwt.py
new file mode 100644
index 0000000..a0320aa
--- /dev/null
+++ b/packages/server/tests/test_jwt.py
@@ -0,0 +1,74 @@
+"""JWT encode/decode round-trips."""
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta, timezone
+
+import jwt as pyjwt
+import pytest
+
+from notify_bridge_server.auth.jwt import (
+ ALGORITHM,
+ create_access_token,
+ create_refresh_token,
+ decode_token,
+)
+from notify_bridge_server.config import settings
+
+
+def test_access_token_round_trip() -> None:
+ token = create_access_token(user_id=1, role="admin", token_version=3)
+ payload = decode_token(token)
+ assert payload["sub"] == "1"
+ assert payload["type"] == "access"
+ assert payload["role"] == "admin"
+ assert payload["ver"] == 3
+ assert payload["iss"] == settings.jwt_issuer
+ assert payload["aud"] == settings.jwt_audience
+
+
+def test_refresh_token_round_trip() -> None:
+ token = create_refresh_token(user_id=7, token_version=2)
+ payload = decode_token(token)
+ assert payload["type"] == "refresh"
+ assert payload["sub"] == "7"
+
+
+def test_decode_rejects_wrong_audience() -> None:
+ """A token signed with our key but for a different audience is rejected."""
+ now = datetime.now(timezone.utc)
+ forged = pyjwt.encode(
+ {
+ "iss": settings.jwt_issuer,
+ "aud": "other-service",
+ "sub": "1",
+ "type": "access",
+ "ver": 1,
+ "iat": now,
+ "exp": now + timedelta(minutes=5),
+ },
+ settings.secret_key,
+ algorithm=ALGORITHM,
+ )
+ with pytest.raises(pyjwt.InvalidAudienceError):
+ decode_token(forged)
+
+
+def test_decode_rejects_none_alg() -> None:
+ """An ``alg: none`` token must never be accepted."""
+ now = datetime.now(timezone.utc)
+ forged = pyjwt.encode(
+ {
+ "iss": settings.jwt_issuer,
+ "aud": settings.jwt_audience,
+ "sub": "1",
+ "type": "access",
+ "ver": 1,
+ "iat": now,
+ "exp": now + timedelta(minutes=5),
+ },
+ "",
+ algorithm="none",
+ )
+ with pytest.raises(pyjwt.InvalidAlgorithmError):
+ decode_token(forged)
diff --git a/packages/server/tests/test_ssrf.py b/packages/server/tests/test_ssrf.py
new file mode 100644
index 0000000..bdc5086
--- /dev/null
+++ b/packages/server/tests/test_ssrf.py
@@ -0,0 +1,59 @@
+"""SSRF guard regression tests."""
+
+from __future__ import annotations
+
+import pytest
+
+from notify_bridge_core.notifications.ssrf import (
+ UnsafeURLError,
+ avalidate_outbound_url,
+ validate_outbound_url,
+)
+
+
+class TestScheme:
+ def test_rejects_file_scheme(self) -> None:
+ with pytest.raises(UnsafeURLError):
+ validate_outbound_url("file:///etc/passwd")
+
+ def test_rejects_gopher(self) -> None:
+ with pytest.raises(UnsafeURLError):
+ validate_outbound_url("gopher://example.com/")
+
+ def test_accepts_https(self) -> None:
+ # A well-known public host — validated via real DNS so this test is
+ # skipped when offline.
+ try:
+ validate_outbound_url("https://example.com/")
+ except UnsafeURLError as err:
+ if "DNS" in str(err):
+ pytest.skip("No DNS in test environment")
+ raise
+
+
+class TestBlockedRanges:
+ @pytest.mark.parametrize(
+ "url",
+ [
+ "http://127.0.0.1/",
+ "http://10.0.0.1/",
+ "http://192.168.1.1/",
+ "http://169.254.169.254/latest/meta-data/",
+ "http://[::1]/",
+ ],
+ )
+ def test_rejects_literal_private(self, url: str) -> None:
+ with pytest.raises(UnsafeURLError):
+ validate_outbound_url(url)
+
+
+class TestAsyncValidator:
+ @pytest.mark.asyncio
+ async def test_async_rejects_loopback(self) -> None:
+ with pytest.raises(UnsafeURLError):
+ await avalidate_outbound_url("http://127.0.0.1/")
+
+ @pytest.mark.asyncio
+ async def test_async_rejects_bad_scheme(self) -> None:
+ with pytest.raises(UnsafeURLError):
+ await avalidate_outbound_url("file:///etc/passwd")
diff --git a/scripts/restart-backend.sh b/scripts/restart-backend.sh
index b846629..6647849 100644
--- a/scripts/restart-backend.sh
+++ b/scripts/restart-backend.sh
@@ -24,7 +24,7 @@ fi
# Start backend
export NOTIFY_BRIDGE_DATA_DIR=./test-data
-export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
+export NOTIFY_BRIDGE_SECRET_KEY=dev-only-pwIOUsKmfn4CYWQ9hCRs5lmI3GgrVlXSu2nqFzGW
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
# guard rejects private addresses by default, which would make trackers fail.
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1