feat: security hardening — SSRF guard, template sandbox timeout, webhook log prune, auth & backup polish

- Add outbound URL validation (SSRF) for webhook/Discord/Slack/ntfy/Matrix dispatch
- Template renderer: input/output caps and thread-based render timeout
- Webhook log filter: strip Authorization/signature/token-like headers; atomic prune
- Auth/JWT/backup/config tightening; misc frontend UX fixes
This commit is contained in:
2026-04-16 03:21:45 +03:00
parent 734e5c9340
commit f0739ca949
30 changed files with 567 additions and 105 deletions
@@ -12,6 +12,19 @@ import aiohttp
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.templates.context import build_template_context
from notify_bridge_core.templates.renderer import render_template
from .ssrf import UnsafeURLError, validate_outbound_url
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
def _new_session() -> aiohttp.ClientSession:
"""Per-dispatch aiohttp session with a sane default timeout.
We still open a short-lived session per dispatch (connection reuse across
dispatches lives in the server-side shared session), but we always attach
a total timeout so a hung peer cannot wedge the task forever.
"""
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
from .receiver import (
Receiver,
@@ -176,7 +189,7 @@ class NotificationDispatcher:
assets.append(asset_entry)
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
client = TelegramClient(
session, bot_token,
url_cache=self._url_cache,
@@ -226,11 +239,16 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
for receiver in target.receivers:
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
results.append({"success": False, "error": "Invalid webhook receiver"})
continue
try:
validate_outbound_url(receiver.url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target)
payload = {
"message": message,
@@ -295,12 +313,17 @@ class NotificationDispatcher:
username = target.config.get("username")
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
client = DiscordClient(session)
for receiver in target.receivers:
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid discord receiver"})
continue
try:
validate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username))
@@ -316,12 +339,17 @@ class NotificationDispatcher:
username = target.config.get("username")
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
client = SlackClient(session)
for receiver in target.receivers:
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid slack receiver"})
continue
try:
validate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username))
@@ -336,11 +364,15 @@ class NotificationDispatcher:
auth_token = target.config.get("auth_token")
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
try:
validate_outbound_url(server_url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
title = f"{event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
client = NtfyClient(session)
for receiver in target.receivers:
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
@@ -363,12 +395,16 @@ class NotificationDispatcher:
access_token = target.config.get("access_token")
if not homeserver or not access_token:
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
try:
validate_outbound_url(homeserver)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as session:
async with _new_session() as session:
client = MatrixClient(session, homeserver, access_token)
for receiver in target.receivers:
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
@@ -0,0 +1,80 @@
"""Outbound URL validation to mitigate SSRF attacks.
User-controlled URLs (provider `url`, webhook target `url`, shared-link
base URLs, image downloads) must be validated before any HTTP request is
issued. This module rejects schemes other than http/https and blocks
destinations that resolve to private, loopback, link-local, or unspecified
address ranges.
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
development against localhost services.
"""
from __future__ import annotations
import ipaddress
import os
import socket
from urllib.parse import urlparse
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
_ALLOWED_SCHEMES = {"http", "https"}
class UnsafeURLError(ValueError):
"""Raised when a URL targets a disallowed network destination."""
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
)
def validate_outbound_url(url: str) -> str:
"""Validate ``url`` is safe to fetch; returns the URL on success.
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
private addresses are permitted but the scheme check still applies.
"""
if not isinstance(url, str) or not url:
raise UnsafeURLError("URL is empty")
parsed = urlparse(url)
if parsed.scheme not in _ALLOWED_SCHEMES:
raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed")
host = parsed.hostname
if not host:
raise UnsafeURLError("URL has no host")
if _ALLOW_PRIVATE:
return url
# Literal IP host
try:
ip = ipaddress.ip_address(host)
if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} is in a blocked range")
return url
except ValueError:
pass
# Hostname — resolve and reject if any resolution is in a blocked range.
try:
infos = socket.getaddrinfo(host, None)
except socket.gaierror as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
for info in infos:
sockaddr = info[4]
try:
ip = ipaddress.ip_address(sockaddr[0])
except ValueError:
continue
if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
return url
@@ -7,8 +7,12 @@ from typing import Any
import aiohttp
from ..ssrf import UnsafeURLError, validate_outbound_url
_LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
class WebhookClient:
"""Send JSON payloads to a webhook URL."""
@@ -19,11 +23,16 @@ class WebhookClient:
self._headers = headers or {}
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
try:
validate_outbound_url(self._url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe URL: {err}"}
try:
async with self._session.post(
self._url,
json=payload,
headers={"Content-Type": "application/json", **self._headers},
timeout=_DEFAULT_TIMEOUT,
) as response:
if 200 <= response.status < 300:
return {"success": True, "status_code": response.status}
@@ -29,7 +29,7 @@ def _compile_jsonpath(expression: str) -> Any | None:
return _JSONPATH_CACHE[expression]
try:
compiled = jsonpath_parse(expression)
except (JsonPathParserError, Exception) as exc:
except (JsonPathParserError, ValueError, TypeError, AttributeError) as exc:
_LOGGER.warning("Invalid JSONPath expression '%s': %s", expression, exc)
compiled = None
_JSONPATH_CACHE[expression] = compiled
@@ -69,6 +69,10 @@ def parse_webhook(
Returns:
A ServiceEvent, or None if parsing fails critically.
"""
# Defensive: upstream callers should pass a dict, but tolerate non-dict
# payloads by coercing to an empty mapping rather than raising.
if not isinstance(payload, dict):
payload = {}
# Build a combined data dict so JSONPath can reference headers too
data: dict[str, Any] = {**payload}
if headers:
@@ -1,8 +1,18 @@
"""Template rendering engine using Jinja2 SandboxedEnvironment."""
"""Template rendering engine using Jinja2 SandboxedEnvironment.
Hardening applied:
* SandboxedEnvironment with autoescape for attribute/method isolation.
* Input length cap to short-circuit pathological templates before parsing.
* Output length cap via a custom stream check to prevent memory blow-ups.
* Cooperative time budget via a thread-based watchdog -- a runaway template
(``{% for i in range(10**8) %}``) is interrupted instead of wedging the worker.
"""
from __future__ import annotations
import logging
import threading
from typing import Any
import jinja2
@@ -10,16 +20,75 @@ from jinja2.sandbox import SandboxedEnvironment
_LOGGER = logging.getLogger(__name__)
MAX_TEMPLATE_LEN = 64 * 1024 # 64 KiB source
MAX_OUTPUT_LEN = 256 * 1024 # 256 KiB rendered
RENDER_TIMEOUT_SECONDS = 2.0
_env = SandboxedEnvironment(autoescape=True)
class TemplateRenderTimeout(jinja2.TemplateError):
"""Raised when a template exceeds the configured render budget."""
def _render_with_timeout(template: jinja2.Template, context: dict[str, Any]) -> str:
"""Render `template` in a worker thread with a hard timeout.
Jinja2 has no built-in timeout; we run the render in a daemon thread and
join with a deadline. If the deadline is exceeded we raise and let the
thread die with the process -- accepted trade-off for a bounded-budget
admin-authored template.
"""
result: dict[str, Any] = {}
def _run() -> None:
try:
result["value"] = template.render(**context)
except BaseException as exc: # noqa: BLE001 - forward to caller
result["error"] = exc
worker = threading.Thread(target=_run, daemon=True)
worker.start()
worker.join(RENDER_TIMEOUT_SECONDS)
if worker.is_alive():
raise TemplateRenderTimeout(
f"Template render exceeded {RENDER_TIMEOUT_SECONDS}s budget"
)
if "error" in result:
raise result["error"]
return result.get("value", "")
def render_template(template_str: str, context: dict[str, Any]) -> str:
"""Render a Jinja2 template string with the given context.
Falls back to returning the raw template on error.
Enforces source length, output length, and wall-clock time caps.
Returns a placeholder on any failure so callers never see a partial render.
"""
if not isinstance(template_str, str):
return ""
if len(template_str) > MAX_TEMPLATE_LEN:
_LOGGER.warning(
"Template source exceeds %d chars (%d); refusing to render",
MAX_TEMPLATE_LEN, len(template_str),
)
return "[Template too large]"
try:
return _env.from_string(template_str).render(**context)
compiled = _env.from_string(template_str)
output = _render_with_timeout(compiled, context)
except TemplateRenderTimeout as e:
_LOGGER.error("Template render timeout: %s", e)
return "[Template render timeout]"
except jinja2.TemplateError as e:
_LOGGER.error("Template render error: %s", e)
return "[Template rendering error]"
except Exception as e: # sandbox guarded — log and fall back safely
_LOGGER.error("Unexpected template error: %s", e, exc_info=True)
return "[Template rendering error]"
if len(output) > MAX_OUTPUT_LEN:
_LOGGER.warning(
"Template output truncated from %d to %d bytes",
len(output), MAX_OUTPUT_LEN,
)
return output[:MAX_OUTPUT_LEN] + "\n[truncated]"
return output
@@ -31,6 +31,23 @@ def _backup_dir():
return app_config.data_dir / "backups"
def _resolve_backup_file(filename: str):
"""Validate filename and resolve to a path strictly inside the backup dir."""
if not filename.startswith("backup-") or not filename.endswith(".json"):
raise HTTPException(status_code=404, detail="Backup file not found")
if "/" in filename or "\\" in filename or ".." in filename or "\x00" in filename:
raise HTTPException(status_code=404, detail="Backup file not found")
base = _backup_dir().resolve()
candidate = (base / filename).resolve()
try:
candidate.relative_to(base)
except ValueError:
raise HTTPException(status_code=404, detail="Backup file not found")
if not candidate.is_file():
raise HTTPException(status_code=404, detail="Backup file not found")
return candidate
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
@@ -194,9 +211,7 @@ async def download_backup_file(
user: User = Depends(require_admin),
):
"""Download a specific backup file."""
filepath = _backup_dir() / filename
if not filepath.is_file() or not filename.startswith("backup-"):
raise HTTPException(status_code=404, detail="Backup file not found")
filepath = _resolve_backup_file(filename)
try:
content = json.loads(filepath.read_text(encoding="utf-8"))
@@ -215,9 +230,6 @@ async def delete_backup_file(
user: User = Depends(require_admin),
):
"""Delete a specific backup file."""
filepath = _backup_dir() / filename
if not filepath.is_file() or not filename.startswith("backup-"):
raise HTTPException(status_code=404, detail="Backup file not found")
filepath = _resolve_backup_file(filename)
filepath.unlink()
return {"deleted": filename}
@@ -350,12 +350,29 @@ def _verify_generic_webhook_auth(
return False
_SENSITIVE_HEADER_SUBSTR = (
"token", "auth", "key", "secret", "signature", "password", "credential",
"cookie", "x-api", "x-hub-signature",
)
def _is_sensitive_header(name: str) -> bool:
n = name.lower()
return any(s in n for s in _SENSITIVE_HEADER_SUBSTR)
def _filter_headers(raw_headers: dict[str, str]) -> dict[str, str]:
"""Keep only safe headers for logging (no Authorization)."""
"""Keep only safe headers for logging (strip Authorization, signatures, tokens).
Allowlist base set of known-safe headers, accept X-* only if they do not
match any sensitive substring (token/auth/key/secret/signature/...).
"""
safe: dict[str, str] = {}
for k, v in raw_headers.items():
kl = k.lower()
if kl in ("content-type", "user-agent") or kl.startswith("x-"):
if _is_sensitive_header(kl):
continue
if kl in ("content-type", "user-agent", "content-length", "accept") or kl.startswith("x-"):
safe[k] = v
return safe
@@ -384,26 +401,26 @@ async def _save_webhook_log(
error_message=error_message,
))
await session.flush()
count_result = await session.exec(
select(func.count(WebhookPayloadLog.id))
# Atomic prune: DELETE anything for this provider outside the newest
# max_count rows. Avoids the COUNT -> SELECT -> DELETE race.
keep_subq = (
select(WebhookPayloadLog.id)
.where(WebhookPayloadLog.provider_id == provider_id)
.order_by(WebhookPayloadLog.created_at.desc(), WebhookPayloadLog.id.desc())
.limit(max_count)
.subquery()
)
await session.execute(
sa_delete(WebhookPayloadLog)
.where(WebhookPayloadLog.provider_id == provider_id)
.where(~WebhookPayloadLog.id.in_(select(keep_subq.c.id)))
)
total = count_result.one()
if total > max_count:
oldest = await session.exec(
select(WebhookPayloadLog.id)
.where(WebhookPayloadLog.provider_id == provider_id)
.order_by(WebhookPayloadLog.created_at.asc())
.limit(total - max_count)
)
ids_to_delete = list(oldest.all())
if ids_to_delete:
await session.execute(
sa_delete(WebhookPayloadLog)
.where(WebhookPayloadLog.id.in_(ids_to_delete))
)
except Exception:
_LOGGER.warning("Failed to save webhook payload log for provider %d", provider_id, exc_info=True)
try:
await session.rollback()
except Exception:
pass
@router.post("/webhook/{token}")
@@ -436,6 +453,8 @@ async def generic_webhook(token: str, request: Request):
# Parse JSON payload
try:
payload = await request.json()
if not isinstance(payload, dict):
raise ValueError("Payload must be a JSON object")
except (json.JSONDecodeError, ValueError):
if store_payloads:
async with AsyncSession(get_engine()) as log_session:
@@ -22,12 +22,15 @@ async def get_current_user(
if payload.get("type") != "access":
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
user_id = int(payload["sub"])
token_version = int(payload.get("ver", 1))
except (jwt.PyJWTError, KeyError, ValueError) as exc:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token") from exc
user = await session.get(User, user_id)
if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
if token_version != user.token_version:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token revoked")
return user
@@ -9,15 +9,26 @@ from ..config import settings
ALGORITHM = "HS256"
def create_access_token(user_id: int, role: str) -> str:
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
payload = {"sub": str(user_id), "role": role, "type": "access", "exp": expire}
payload = {
"sub": str(user_id),
"role": role,
"type": "access",
"ver": token_version,
"exp": expire,
}
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
def create_refresh_token(user_id: int) -> str:
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)
payload = {"sub": str(user_id), "type": "refresh", "exp": expire}
payload = {
"sub": str(user_id),
"type": "refresh",
"ver": token_version,
"exp": expire,
}
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
@@ -69,8 +69,8 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
await session.refresh(user)
return TokenResponse(
access_token=create_access_token(user.id, user.role),
refresh_token=create_refresh_token(user.id),
access_token=create_access_token(user.id, user.role, user.token_version),
refresh_token=create_refresh_token(user.id, user.token_version),
)
@@ -83,29 +83,33 @@ async def login(request: Request, body: LoginRequest, session: AsyncSession = De
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
return TokenResponse(
access_token=create_access_token(user.id, user.role),
refresh_token=create_refresh_token(user.id),
access_token=create_access_token(user.id, user.role, user.token_version),
refresh_token=create_refresh_token(user.id, user.token_version),
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh(body: RefreshRequest, session: AsyncSession = Depends(get_session)):
@limiter.limit("10/minute")
async def refresh(request: Request, body: RefreshRequest, session: AsyncSession = Depends(get_session)):
import jwt as pyjwt
try:
payload = decode_token(body.refresh_token)
if payload.get("type") != "refresh":
raise HTTPException(status_code=401, detail="Invalid token type")
user_id = int(payload["sub"])
token_version = int(payload.get("ver", 1))
except (pyjwt.PyJWTError, KeyError, ValueError) as exc:
raise HTTPException(status_code=401, detail="Invalid refresh token") from exc
user = await session.get(User, user_id)
if not user:
raise HTTPException(status_code=401, detail="User not found")
if token_version != user.token_version:
raise HTTPException(status_code=401, detail="Refresh token revoked")
return TokenResponse(
access_token=create_access_token(user.id, user.role),
refresh_token=create_refresh_token(user.id),
access_token=create_access_token(user.id, user.role, user.token_version),
refresh_token=create_refresh_token(user.id, user.token_version),
)
@@ -130,6 +134,7 @@ async def change_password(
if len(body.new_password) < 8:
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
user.hashed_password = _hash_password(body.new_password)
user.token_version = (user.token_version or 1) + 1
session.add(user)
await session.commit()
return {"success": True}
@@ -14,10 +14,19 @@ class Settings(BaseSettings):
secret_key: str = "change-me-in-production"
def model_post_init(self, __context: Any) -> None:
if self.secret_key == "change-me-in-production" and not self.debug:
if self.secret_key == "change-me-in-production":
raise ValueError(
"SECURITY: Cannot start with default secret_key in production. "
"Set NOTIFY_BRIDGE_SECRET_KEY environment variable."
"SECURITY: Refusing to start with the default secret_key. "
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
"before starting the server (debug mode included)."
)
if len(self.secret_key) < 32:
raise ValueError(
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
)
if "*" in self.cors_allowed_origins.split(","):
raise ValueError(
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
)
access_token_expire_minutes: int = 60
@@ -18,8 +18,23 @@ logger = logging.getLogger(__name__)
# Helpers
# ---------------------------------------------------------------------------
_IDENT_RE = __import__("re").compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _assert_ident(ident: str, kind: str = "identifier") -> str:
"""Guard against SQL injection in dynamically interpolated identifiers.
All table/column names flow through here before being embedded into f-strings,
so attacker-controlled values cannot break out even if they reach this layer.
"""
if not isinstance(ident, str) or not _IDENT_RE.match(ident):
raise ValueError(f"Unsafe {kind}: {ident!r}")
return ident
async def _has_column(conn, table: str, column: str) -> bool:
"""Check if a column exists in a SQLite table."""
_assert_ident(table, "table")
cols = await conn.run_sync(
lambda sync_conn: [
row[1]
@@ -1187,3 +1202,15 @@ async def migrate_notification_slot_locale(engine: AsyncEngine) -> None:
"Merged system notification template configs for %s (EN=%d, RU=%d) into %d",
provider_type, en_id, ru_id, en_id,
)
async def migrate_user_token_version(engine: AsyncEngine) -> None:
"""Add token_version column to user for JWT revocation on password change."""
async with engine.begin() as conn:
if not await _has_table(conn, "user"):
return
if not await _has_column(conn, "user", "token_version"):
await conn.execute(
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
)
logger.info("Added token_version column to user table")
@@ -19,6 +19,7 @@ class User(SQLModel, table=True):
username: str = Field(index=True, unique=True)
hashed_password: str
role: str = Field(default="user")
token_version: int = Field(default=1)
created_at: datetime = Field(default_factory=_utcnow)
@@ -52,7 +52,7 @@ async def lifespan(app: FastAPI):
await init_db()
# Run data migrations (idempotent)
from .database.engine import get_engine
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
engine = get_engine()
await migrate_schema(engine)
await migrate_tracker_targets(engine)
@@ -63,6 +63,7 @@ async def lifespan(app: FastAPI):
await migrate_receivers_from_config(engine)
await migrate_command_slot_locale(engine)
await migrate_notification_slot_locale(engine)
await migrate_user_token_version(engine)
from .database.seeds import seed_all
await seed_all()
# Configure webhook secret from DB setting (falls back to env var)
@@ -34,6 +34,44 @@ _LOGGER = logging.getLogger(__name__)
# Fields to skip when serializing TrackingConfig into the generic `fields` dict
_TRACKING_SKIP = frozenset(("id", "user_id", "provider_type", "name", "icon", "created_at"))
# Import-time config hardening limits
_MAX_CONFIG_DEPTH = 6
_MAX_CONFIG_KEYS = 200
_MAX_STRING_LEN = 8192
def _sanitize_config(value: Any, depth: int = 0) -> Any:
"""Clamp imported config values to safe shapes before persistence.
Rejects anything that is not a JSON-compatible primitive/container, truncates
over-long strings, and caps dict/list sizes. Returns a defensively-copied
structure; the caller should never see attacker-controlled references.
"""
if depth > _MAX_CONFIG_DEPTH:
raise ValueError("Config nesting exceeds maximum depth")
if value is None or isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value
if isinstance(value, str):
return value[:_MAX_STRING_LEN]
if isinstance(value, list):
if len(value) > _MAX_CONFIG_KEYS:
raise ValueError("Config list exceeds maximum length")
return [_sanitize_config(v, depth + 1) for v in value]
if isinstance(value, dict):
if len(value) > _MAX_CONFIG_KEYS:
raise ValueError("Config dict exceeds maximum key count")
cleaned: dict[str, Any] = {}
for k, v in value.items():
if not isinstance(k, str):
raise ValueError("Config keys must be strings")
if len(k) > 128:
raise ValueError(f"Config key too long: {k[:40]}...")
cleaned[k] = _sanitize_config(v, depth + 1)
return cleaned
raise ValueError(f"Unsupported config value type: {type(value).__name__}")
# ---------------------------------------------------------------------------
# Export
@@ -530,9 +568,14 @@ async def import_backup(
)
if name is None:
continue
try:
safe_cfg = _sanitize_config(p.config or {})
except ValueError as exc:
result.warnings.append(f"Skipped provider '{p.name}': {exc}")
continue
new_p = ServiceProvider(
user_id=user_id, type=p.type, name=name,
icon=p.icon, config=p.config,
icon=p.icon, config=safe_cfg,
)
session.add(new_p)
await session.flush()
@@ -635,17 +678,27 @@ async def import_backup(
)
if name is None:
continue
try:
safe_tgt_cfg = _sanitize_config(tgt.config or {})
except ValueError as exc:
result.warnings.append(f"Skipped target '{tgt.name}': {exc}")
continue
new_tgt = NotificationTarget(
user_id=user_id, type=tgt.type, name=name,
icon=tgt.icon, config=tgt.config,
icon=tgt.icon, config=safe_tgt_cfg,
chat_action=tgt.chat_action,
)
session.add(new_tgt)
await session.flush()
id_map["targets"][tgt.id] = new_tgt.id
for r in tgt.receivers:
try:
safe_r_cfg = _sanitize_config(r.config or {})
except ValueError as exc:
result.warnings.append(f"Skipped receiver in '{tgt.name}': {exc}")
continue
session.add(TargetReceiver(
target_id=new_tgt.id, name=r.name, config=r.config,
target_id=new_tgt.id, name=r.name, config=safe_r_cfg,
receiver_key=r.receiver_key, locale=r.locale,
enabled=r.enabled,
))
@@ -249,6 +249,22 @@ async def load_link_data(
event_key = s.slot_name.removeprefix("message_") if s.slot_name.startswith("message_") else s.slot_name
slots_by_config.setdefault(s.config_id, {}).setdefault(event_key, {})[s.locale] = s.template
# Pre-resolve broadcast children in one query to avoid N+1 per-child fetches
broadcast_child_ids: set[int] = set()
for tt in active_links:
target = target_map.get(tt.target_id)
if target and target.type == "broadcast":
disabled_ids = set(target.config.get("disabled_child_ids", []))
for cid in target.config.get("child_target_ids", []):
if cid not in disabled_ids:
broadcast_child_ids.add(cid)
child_target_map: dict[int, NotificationTarget] = {}
if broadcast_child_ids:
child_rows = await session.exec(
select(NotificationTarget).where(NotificationTarget.id.in_(broadcast_child_ids))
)
child_target_map = {t.id: t for t in child_rows.all()}
link_data: list[dict[str, Any]] = []
for tt in active_links:
target = target_map.get(tt.target_id)
@@ -262,14 +278,13 @@ async def load_link_data(
template_config = tmpl_map.get(tmpl_id) if tmpl_id else None
template_slots = slots_by_config.get(template_config.id) if template_config else None
# Broadcast target: expand into child targets
# Broadcast target: expand into child targets (pre-loaded above)
if target.type == "broadcast":
child_ids = target.config.get("child_target_ids", [])
disabled_ids = set(target.config.get("disabled_child_ids", []))
for child_id in child_ids:
for child_id in target.config.get("child_target_ids", []):
if child_id in disabled_ids:
continue
child_target = await session.get(NotificationTarget, child_id)
child_target = child_target_map.get(child_id)
if not child_target or child_target.type == "broadcast":
continue
resolved = await _resolve_target(session, child_target)