diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 802a711..42b3ab0 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -94,6 +94,9 @@ async function doRefreshAccessToken(): Promise { } const DEFAULT_TIMEOUT_MS = 30_000; +// Longer cap for fetchAuth — it's used for multipart uploads (backup restore) +// and binary downloads where a 30s limit can cut off a legit slow upload. +const DEFAULT_FETCHAUTH_TIMEOUT_MS = 120_000; export async function api( path: string, @@ -170,42 +173,58 @@ export async function api( */ export async function fetchAuth( path: string, - options: RequestInit = {}, + options: RequestInit & { timeoutMs?: number } = {}, ): Promise { const token = getToken(); const headers: Record = { ...(options.headers as Record) }; if (token) headers['Authorization'] = `Bearer ${token}`; const url = path.startsWith('http') ? path : `${API_BASE}${path}`; - let res = await fetch(url, { ...options, headers }); - if (res.status === 401 && token) { - const refreshed = await refreshAccessToken(); - if (refreshed) { - headers['Authorization'] = `Bearer ${getToken()}`; - res = await fetch(url, { ...options, headers }); + // Abort after timeout so uploads/downloads don't hang indefinitely if + // the backend stops responding. Callers can override per-request via + // options.timeoutMs or pass their own signal to opt out. + const { timeoutMs, ...fetchOptions } = options; + const controller = new AbortController(); + const timeout = setTimeout( + () => controller.abort(), + timeoutMs ?? DEFAULT_FETCHAUTH_TIMEOUT_MS, + ); + const signal = options.signal ?? controller.signal; + + try { + let res = await fetch(url, { ...fetchOptions, headers, signal }); + + if (res.status === 401 && token) { + const refreshed = await refreshAccessToken(); + if (refreshed) { + headers['Authorization'] = `Bearer ${getToken()}`; + res = await fetch(url, { ...fetchOptions, headers, signal }); + } } - } - if (res.status === 401) { - clearTokens(); - if (typeof window !== 'undefined') window.location.href = '/login'; - throw new ApiError('Unauthorized', 401); - } - - if (!res.ok) { - const err = await res.clone().json().catch(() => ({ detail: res.statusText })); - if (err && err.detail && typeof err.detail === 'object' && Array.isArray(err.detail.blocked_by)) { - const bb: BlockedByDetail = { - message: err.detail.message || `HTTP ${res.status}`, - entity: err.detail.entity || '', - blocked_by: err.detail.blocked_by, - }; - throw new ApiError(bb.message, res.status, bb); + if (res.status === 401) { + clearTokens(); + if (typeof window !== 'undefined') window.location.href = '/login'; + throw new ApiError('Unauthorized', 401); } - const msg = typeof err.detail === 'string' ? err.detail : (err.detail?.message || `HTTP ${res.status}`); - throw new ApiError(msg, res.status); - } - return res; + if (!res.ok) { + const err = await res.clone().json().catch(() => ({ detail: res.statusText })); + if (err && err.detail && typeof err.detail === 'object' && Array.isArray(err.detail.blocked_by)) { + const bb: BlockedByDetail = { + message: err.detail.message || `HTTP ${res.status}`, + entity: err.detail.entity || '', + blocked_by: err.detail.blocked_by, + }; + throw new ApiError(bb.message, res.status, bb); + } + const msg = typeof err.detail === 'string' ? err.detail : (err.detail?.message || `HTTP ${res.status}`); + throw new ApiError(msg, res.status); + } + + return res; + } finally { + clearTimeout(timeout); + } } diff --git a/packages/core/src/notify_bridge_core/notifications/telegram/client.py b/packages/core/src/notify_bridge_core/notifications/telegram/client.py index a9043da..d1aafa4 100644 --- a/packages/core/src/notify_bridge_core/notifications/telegram/client.py +++ b/packages/core/src/notify_bridge_core/notifications/telegram/client.py @@ -300,6 +300,16 @@ class TelegramClient: # Retry without parse_mode on parse errors desc = str(result.get("description", "")) if "parse" in desc.lower(): + # Log loudly: a parse failure means the template author (or + # an asset field) is producing malformed HTML. Silent + # fallback hides bugs and makes XSS-via-unescaped-field + # harder to spot. Do not log the full payload — it may + # contain secrets. + _LOGGER.warning( + "Telegram rejected parse_mode=%s (%r); retrying as plain text. " + "Check template output for unescaped characters.", + payload.get("parse_mode"), desc, + ) payload.pop("parse_mode", None) async with self._session.post(telegram_url, json=payload) as retry_resp: retry_result = await retry_resp.json() diff --git a/packages/core/src/notify_bridge_core/providers/immich/asset_utils.py b/packages/core/src/notify_bridge_core/providers/immich/asset_utils.py index 079bea5..86bc613 100644 --- a/packages/core/src/notify_bridge_core/providers/immich/asset_utils.py +++ b/packages/core/src/notify_bridge_core/providers/immich/asset_utils.py @@ -321,6 +321,12 @@ def collect_scheduled_assets( asset_album_map: dict[str, tuple[str, str]] = {} # asset_id → (album_id, public_url) collections_extra: list[dict[str, Any]] = [] + # limit=0 is the periodic-summary test path — the caller only needs + # per-album stats (name/url/counts), not a sample of assets. Skip the + # expensive ``filter_assets`` + sampling loop entirely; on a 50k-asset + # album the serial scan-then-discard pattern wasted seconds per test. + stats_only = limit <= 0 + for album_id, album in albums.items(): links = shared_links.get(album_id, []) album_public_url = get_public_url(external_url, links) or "" @@ -336,6 +342,9 @@ def collect_scheduled_assets( "owner": album.owner, }) + if stats_only: + continue + filtered = filter_assets( list(album.assets.values()), favorite_only=favorite_only, @@ -348,6 +357,9 @@ def collect_scheduled_assets( asset_album_map[asset.id] = (album_id, album_public_url) all_eligible.append(asset) + if stats_only: + return [], collections_extra + # Random sample if len(all_eligible) > limit: selected = random.sample(all_eligible, limit) diff --git a/packages/core/src/notify_bridge_core/providers/immich/client.py b/packages/core/src/notify_bridge_core/providers/immich/client.py index b509112..5a31f37 100644 --- a/packages/core/src/notify_bridge_core/providers/immich/client.py +++ b/packages/core/src/notify_bridge_core/providers/immich/client.py @@ -3,14 +3,47 @@ from __future__ import annotations import logging +import re from typing import Any import aiohttp +from ...notifications.ssrf import UnsafeURLError, validate_outbound_url from .models import ImmichAlbumData, SharedLinkInfo _LOGGER = logging.getLogger(__name__) +# Cap user-controlled Immich search parameters so a low-privileged command +# listener (e.g. an Immich ``/search`` command) cannot DoS the upstream. +MAX_SEARCH_QUERY_LEN = 256 +MAX_SEARCH_PERSON_IDS = 50 + +# User-facing error bodies — Immich responses may leak internal paths, +# hostnames, or headers injected by intermediary proxies. These helpers keep +# only a short, scrubbed summary; full bodies are logged server-side only. +_REDACTED_BODY_MAX = 120 +_SECRET_PATTERN = re.compile( + r"(?i)(bearer\s+\S+|x-api-key[:=]\s*\S+|authorization[:=]\s*\S+|cookie[:=]\s*\S+|" + r"password[:=]?\s*\S+|token[:=]?\s*[A-Za-z0-9._\-]+)" +) + + +def _redact_body(text: str) -> str: + """Return a short, credential-scrubbed snippet safe to surface to UI callers. + + Immich error responses are admin-configurable (via reverse proxies, custom + error pages) and may echo request headers or environment leak. Stripping + anything that looks like a credential + capping length keeps us from + persisting secrets into ``ActionExecution.error`` / ``EventLog.details`` + (both of which are returned through the dashboard API). + """ + if not text: + return "" + cleaned = _SECRET_PATTERN.sub("[redacted]", text) + if len(cleaned) > _REDACTED_BODY_MAX: + return cleaned[:_REDACTED_BODY_MAX] + "..." + return cleaned + class ImmichClient: """Async client for the Immich API.""" @@ -25,6 +58,18 @@ class ImmichClient: self._url = url.rstrip("/") self._api_key = api_key self._external_domain: str | None = None + # SSRF guard — admin-set Immich URLs are loaded from provider config + # which can be mutated via PATCH /api/providers or imported via + # prepare-restore, so we revalidate at construction time rather than + # trusting DB state. ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` bypasses + # for dev against localhost Immich. + if self._url: + try: + validate_outbound_url(self._url) + except UnsafeURLError as err: + raise UnsafeURLError( + f"Refusing to build ImmichClient for unsafe URL {self._url!r}: {err}" + ) from err @property def url(self) -> str: @@ -36,6 +81,15 @@ class ImmichClient: @external_domain.setter def external_domain(self, value: str | None) -> None: + # Mirror the constructor's SSRF guard — external_domain is used to + # build URLs that leak into rendered notifications, but any code path + # that eventually fetches this URL would otherwise bypass the check. + if value: + try: + validate_outbound_url(value) + except UnsafeURLError as err: + _LOGGER.warning("Ignoring unsafe external_domain %r: %s", value, err) + return self._external_domain = value @property @@ -237,9 +291,12 @@ class ImmichClient: limit: int = 10, page: int = 1, ) -> list[dict[str, Any]]: - payload: dict[str, Any] = {"query": query, "page": max(1, page), "size": limit} + # Cap user-controlled inputs — a low-privileged Telegram listener can + # craft arbitrarily long queries to DoS the upstream Immich. + query = (query or "")[:MAX_SEARCH_QUERY_LEN] + payload: dict[str, Any] = {"query": query, "page": max(1, page), "size": min(max(1, limit), 100)} if album_ids: - payload["albumIds"] = album_ids + payload["albumIds"] = album_ids[:MAX_SEARCH_PERSON_IDS] try: async with self._session.post( f"{self._url}/api/search/smart", @@ -261,9 +318,10 @@ class ImmichClient: limit: int = 10, page: int = 1, ) -> list[dict[str, Any]]: - payload: dict[str, Any] = {"originalFileName": query, "page": max(1, page), "size": limit} + query = (query or "")[:MAX_SEARCH_QUERY_LEN] + payload: dict[str, Any] = {"originalFileName": query, "page": max(1, page), "size": min(max(1, limit), 100)} if album_ids: - payload["albumIds"] = album_ids + payload["albumIds"] = album_ids[:MAX_SEARCH_PERSON_IDS] try: async with self._session.post( f"{self._url}/api/search/metadata", @@ -289,7 +347,7 @@ class ImmichClient: to return an empty list on current servers. """ payload: dict[str, Any] = { - "personIds": [person_id], + "personIds": [person_id][:MAX_SEARCH_PERSON_IDS], "page": 1, "size": max(1, min(limit, 100)), } @@ -373,9 +431,17 @@ class ImmichClient: if isinstance(parsed, dict): return parsed return {"raw": body_text} + # Log full body server-side (for operators), surface only a + # redacted snippet to the caller — this string ends up in + # ActionExecution.error / EventLog.details which are returned + # through the dashboard API. + _LOGGER.warning( + "add_assets_to_album failed: HTTP %s body=%s", + response.status, body_text[:512], + ) raise ImmichApiError( f"Failed to add assets to album {album_id}: " - f"HTTP {response.status} body={body_text[:512]}" + f"HTTP {response.status} {_redact_body(body_text)}" ) except aiohttp.ClientError as err: raise ImmichApiError(f"Error adding assets to album: {err}") from err @@ -399,9 +465,13 @@ class ImmichClient: if response.status in (200, 201, 204): return body_text = await response.text() + _LOGGER.warning( + "set_album_thumbnail failed: HTTP %s body=%s", + response.status, body_text[:512], + ) raise ImmichApiError( f"Failed to set album thumbnail for {album_id}: " - f"HTTP {response.status} body={body_text[:512]}" + f"HTTP {response.status} {_redact_body(body_text)}" ) except aiohttp.ClientError as err: raise ImmichApiError(f"Error setting album thumbnail: {err}") from err diff --git a/packages/core/src/notify_bridge_core/templates/context.py b/packages/core/src/notify_bridge_core/templates/context.py index 465cc34..560687a 100644 --- a/packages/core/src/notify_bridge_core/templates/context.py +++ b/packages/core/src/notify_bridge_core/templates/context.py @@ -2,16 +2,67 @@ from __future__ import annotations +import logging from datetime import datetime from typing import Any from notify_bridge_core.models.events import ServiceEvent +_LOGGER = logging.getLogger(__name__) + # Per-target maximum video size (bytes). None = no limit. _MAX_VIDEO_SIZE_BY_TARGET: dict[str, int] = { "telegram": 50 * 1024 * 1024, # 50 MB — Telegram Bot API hard limit } +# Keys that must NEVER flow into the Jinja2 template context, even if a +# provider stuffs them into ``event.extra`` (webhooks, Immich metadata, etc.). +# Templates that could reach a Telegram/Discord/etc. chat would otherwise +# expose operator credentials if a template author simply did ``{{ api_key }}``. +# Case-insensitive substring match — any ``extra`` key containing one of these +# tokens is dropped before the merge. +_SENSITIVE_EXTRA_TOKENS: tuple[str, ...] = ( + "api_key", + "apikey", + "token", + "secret", + "password", + "passwd", + "hashed_", + "authorization", + "cookie", + "session_id", + "bearer", + "private_key", + "access_key", +) + + +def _is_sensitive_key(key: str) -> bool: + lowered = str(key).lower() + return any(tok in lowered for tok in _SENSITIVE_EXTRA_TOKENS) + + +def _safe_merge_extras(ctx: dict[str, Any], extras: dict[str, Any]) -> None: + """Merge provider ``extras`` into ``ctx``, dropping sensitive keys. + + Dropped keys are logged once per event (DEBUG) so operators can spot + leaking providers without flooding the log. + """ + if not extras: + return + dropped: list[str] = [] + for key, value in extras.items(): + if _is_sensitive_key(key): + dropped.append(key) + continue + ctx[key] = value + if dropped: + _LOGGER.debug( + "Dropped %d sensitive key(s) from template context: %s", + len(dropped), ", ".join(sorted(dropped)), + ) + def build_template_context( event: ServiceEvent, @@ -61,8 +112,9 @@ def build_template_context( "preview_url": asset.preview_url or "", "full_url": asset.full_url or "", } - # Flatten extras into asset dict for template access - asset_dict.update(asset.extra) + # Flatten extras into asset dict for template access — same + # sensitive-key filtering applied as the top-level merge. + _safe_merge_extras(asset_dict, asset.extra) asset_dict.setdefault("oversized", False) asset_dict.setdefault("file_size", None) asset_dict.setdefault("playback_size", None) @@ -138,8 +190,11 @@ def build_template_context( if len(locations) == 1 and "" not in locations: ctx["common_location"] = locations.pop() - # Provider-specific extras merged at top level - ctx.update(event.extra) + # Provider-specific extras merged at top level. Sensitive keys (tokens, + # secrets, auth headers) are dropped — see ``_SENSITIVE_EXTRA_TOKENS``. + # Without this, a template author could exfiltrate provider credentials + # via ``{{ api_key }}`` in an outgoing notification body. + _safe_merge_extras(ctx, event.extra) # Ensure URL variables always exist (avoid Jinja2 undefined errors) ctx.setdefault("public_url", "") diff --git a/packages/server/src/notify_bridge_server/api/backup.py b/packages/server/src/notify_bridge_server/api/backup.py index fb0c5d2..52a6abd 100644 --- a/packages/server/src/notify_bridge_server/api/backup.py +++ b/packages/server/src/notify_bridge_server/api/backup.py @@ -1,13 +1,15 @@ """Configuration backup/restore API (admin only).""" import asyncio +import hashlib import json import logging import os import signal from datetime import datetime, timezone +from urllib.parse import urlparse -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File, Query +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, UploadFile, File, Query from fastapi.responses import JSONResponse from sqlmodel.ext.asyncio.session import AsyncSession @@ -28,6 +30,11 @@ PENDING_RESTORE_PATH_KEY = "pending_restore_path" PENDING_RESTORE_CONFLICT_KEY = "pending_restore_conflict_mode" PENDING_RESTORE_UPLOADED_AT_KEY = "pending_restore_uploaded_at" PENDING_RESTORE_UPLOADED_BY_KEY = "pending_restore_uploaded_by" +# SHA256 of the staged pending_restore.json, written atomically with the file. +# The startup hook refuses to apply if the on-disk file's hash does not match — +# defends against anyone dropping a tampered file into data/ between prepare +# and restart. +PENDING_RESTORE_SHA256_KEY = "pending_restore_sha256" def _pending_restore_path(): @@ -44,6 +51,69 @@ router = APIRouter(prefix="/api/backup", tags=["backup"]) MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB +async def _read_upload_bounded(file: UploadFile, max_bytes: int = MAX_UPLOAD_SIZE) -> bytes: + """Read an UploadFile into memory, failing fast if it exceeds ``max_bytes``. + + Rejects on ``content_length`` header up-front when available; always + stream-reads with a running byte counter so we never allocate more than + the limit even when the header is missing or lies. + """ + # Fast path: reject on header before we allocate anything. + cl = file.headers.get("content-length") if hasattr(file, "headers") else None + if cl: + try: + if int(cl) > max_bytes: + raise HTTPException(status_code=400, detail="File too large (max 10 MB)") + except ValueError: + pass + + chunks: list[bytes] = [] + total = 0 + while True: + chunk = await file.read(64 * 1024) + if not chunk: + break + total += len(chunk) + if total > max_bytes: + raise HTTPException(status_code=400, detail="File too large (max 10 MB)") + chunks.append(chunk) + return b"".join(chunks) + + +def _check_same_origin(request: Request) -> None: + """Reject cross-origin admin-write POSTs (CSRF defense). + + Bearer tokens in ``localStorage`` plus cookie-less CORS mean a malicious + page cannot technically submit our Authorization header from a victim's + session, BUT browser extensions and misconfigured CORS policies routinely + break this assumption. For endpoints whose blast radius is restart/RCE- + equivalent (restore apply), we additionally require the request to come + from our own origin. + """ + host = request.headers.get("host", "").lower() + if not host: + raise HTTPException(status_code=400, detail="Missing Host header") + + def _host_of(u: str | None) -> str: + if not u: + return "" + try: + return (urlparse(u).netloc or "").lower() + except Exception: # noqa: BLE001 + return "" + + origin_host = _host_of(request.headers.get("origin")) + referer_host = _host_of(request.headers.get("referer")) + # At least one of Origin/Referer must be present and match Host. + # Legitimate browser requests to this endpoint always ship Origin. + same = (origin_host and origin_host == host) or (referer_host and referer_host == host) + if not same: + raise HTTPException( + status_code=403, + detail="Cross-origin request rejected", + ) + + def _backup_dir(): return app_config.data_dir / "backups" @@ -104,9 +174,7 @@ async def validate_config( user: User = Depends(require_admin), ): """Validate a backup file without importing.""" - content = await file.read() - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException(status_code=400, detail="File too large (max 10 MB)") + content = await _read_upload_bounded(file) try: raw = json.loads(content) @@ -129,9 +197,7 @@ async def import_config( session: AsyncSession = Depends(get_session), ): """Import configuration from a backup file.""" - content = await file.read() - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException(status_code=400, detail="File too large (max 10 MB)") + content = await _read_upload_bounded(file) try: raw = json.loads(content) @@ -167,6 +233,7 @@ async def _clear_pending_restore_markers(session: AsyncSession) -> None: PENDING_RESTORE_CONFLICT_KEY, PENDING_RESTORE_UPLOADED_AT_KEY, PENDING_RESTORE_UPLOADED_BY_KEY, + PENDING_RESTORE_SHA256_KEY, ): row = await session.get(AppSetting, key) if row: @@ -185,9 +252,7 @@ async def prepare_restore( Validates the uploaded file, writes it to ``data/pending_restore.json``, and persists marker settings so startup will apply it atomically. """ - content = await file.read() - if len(content) > MAX_UPLOAD_SIZE: - raise HTTPException(status_code=400, detail="File too large (max 10 MB)") + content = await _read_upload_bounded(file) try: raw = json.loads(content) @@ -205,15 +270,25 @@ async def prepare_restore( pending_path.parent.mkdir(parents=True, exist_ok=True) # Atomic write: write to tmp then rename, so a crash mid-write never # leaves a truncated pending_restore.json that would break startup apply. + payload = json.dumps(raw).encode("utf-8") + digest = hashlib.sha256(payload).hexdigest() tmp_path = pending_path.with_suffix(pending_path.suffix + ".tmp") - tmp_path.write_text(json.dumps(raw), encoding="utf-8") + tmp_path.write_bytes(payload) os.replace(tmp_path, pending_path) + # Best-effort tighten perms so a non-root local user cannot swap the file + # for one they control between prepare and restart. On Windows this is a + # no-op; on POSIX we restrict to owner-only rw. + try: + os.chmod(pending_path, 0o600) + except OSError: + pass now_iso = datetime.now(timezone.utc).isoformat() await _set_app_setting(session, PENDING_RESTORE_PATH_KEY, str(pending_path)) await _set_app_setting(session, PENDING_RESTORE_CONFLICT_KEY, conflict_mode.value) await _set_app_setting(session, PENDING_RESTORE_UPLOADED_AT_KEY, now_iso) await _set_app_setting(session, PENDING_RESTORE_UPLOADED_BY_KEY, user.username) + await _set_app_setting(session, PENDING_RESTORE_SHA256_KEY, digest) await session.commit() return { @@ -292,6 +367,7 @@ def _is_supervised() -> bool: @router.post("/apply-restart") async def apply_and_restart( + request: Request, background_tasks: BackgroundTasks, user: User = Depends(require_admin), session: AsyncSession = Depends(get_session), @@ -299,7 +375,11 @@ async def apply_and_restart( """Trigger a graceful exit so the supervisor respawns and applies the pending restore. Only allowed when a pending restore is staged AND the process is supervised. + Requires same-origin Origin/Referer — this endpoint's blast radius is a + full config replace + restart, so an admin token alone (vulnerable to + XSS-driven CSRF) is not enough. """ + _check_same_origin(request) path_row = await session.get(AppSetting, PENDING_RESTORE_PATH_KEY) if not path_row or not path_row.value: raise HTTPException(status_code=409, detail="No pending restore to apply") diff --git a/packages/server/src/notify_bridge_server/api/users.py b/packages/server/src/notify_bridge_server/api/users.py index 756fda6..d9648bc 100644 --- a/packages/server/src/notify_bridge_server/api/users.py +++ b/packages/server/src/notify_bridge_server/api/users.py @@ -4,6 +4,7 @@ import logging from fastapi import APIRouter, Depends, HTTPException, status from pydantic import BaseModel +from sqlalchemy import func from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -81,6 +82,12 @@ async def update_user( if not user: raise HTTPException(status_code=404, detail="User not found") + # Track whether the identity that JWTs encode has changed. Any such change + # must bump ``token_version`` so already-issued tokens are rejected — a + # user demoted admin→user must not keep admin in their cached JWT until + # expiry, and a rename should invalidate prior sessions too. + identity_changed = False + if body.username is not None and body.username != user.username: new_username = body.username.strip() if not new_username: @@ -89,21 +96,51 @@ async def update_user( if dup.first(): raise HTTPException(status_code=409, detail="Username already exists") user.username = new_username + identity_changed = True if body.role is not None and body.role != user.role: if body.role not in ("admin", "user"): raise HTTPException(status_code=400, detail="Invalid role") - # Prevent demoting the last admin + # Prevent demoting the last admin. Done via a COUNT to avoid loading + # every admin row; more importantly, re-checked *after* the role + # change is staged (TOCTOU guard — two concurrent demotes can each + # see admin_count=2 and both proceed, dropping to 0). if user.role == "admin" and body.role != "admin": - admins = (await session.exec( - select(User).where(User.role == "admin") - )).all() - if len(admins) <= 1: + admin_count = (await session.exec( + select(func.count(User.id)).where(User.role == "admin") + )).one() + if isinstance(admin_count, tuple): + admin_count = admin_count[0] + if (admin_count or 0) <= 1: raise HTTPException(status_code=400, detail="Cannot demote the last admin") user.role = body.role + identity_changed = True + + if identity_changed: + user.token_version = (user.token_version or 1) + 1 session.add(user) - await session.commit() + try: + await session.commit() + except Exception: + await session.rollback() + raise + + # Final defense against admin-count race: if we just demoted the last admin + # due to a concurrent demote landing between our check and commit, undo. + if body.role is not None and body.role != "admin": + admin_count_after = (await session.exec( + select(func.count(User.id)).where(User.role == "admin") + )).one() + if isinstance(admin_count_after, tuple): + admin_count_after = admin_count_after[0] + if (admin_count_after or 0) < 1: + # Roll the user back to admin and re-commit. + user.role = "admin" + session.add(user) + await session.commit() + raise HTTPException(status_code=409, detail="Refused: would remove the last admin") + await session.refresh(user) return {"id": user.id, "username": user.username, "role": user.role} @@ -126,6 +163,9 @@ async def reset_user_password( if len(body.new_password) < 8: raise HTTPException(status_code=400, detail="Password must be at least 8 characters") user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode() + # Invalidate all prior JWTs issued for this user — matches the self-serve + # password-change path in auth/routes.py. + user.token_version = (user.token_version or 1) + 1 session.add(user) await session.commit() return {"success": True} diff --git a/packages/server/src/notify_bridge_server/database/migrations.py b/packages/server/src/notify_bridge_server/database/migrations.py index 9180ec7..e387e51 100644 --- a/packages/server/src/notify_bridge_server/database/migrations.py +++ b/packages/server/src/notify_bridge_server/database/migrations.py @@ -92,6 +92,21 @@ async def migrate_schema(engine: AsyncEngine) -> None: await conn.execute(text(sql)) logger.info("Added %s column to event_log table", col) + # Explicit indexes on the dashboard-query columns. SQLModel's + # ``index=True`` is emitted by ``create_all`` on *new* installs, + # but ALTER TABLE ADD COLUMN doesn't create them on upgrades — + # so the first boot after upgrade would leave these unindexed + # and status.py ``WHERE user_id=...`` would table-scan. The + # indexes are redundant-but-safe once create_all also runs. + for idx_name, col in [ + ("ix_event_log_user_id", "user_id"), + ("ix_event_log_action_id", "action_id"), + ("ix_event_log_provider_id", "provider_id"), + ]: + await conn.execute( + text(f"CREATE INDEX IF NOT EXISTS {idx_name} ON event_log ({col})") + ) + # Backfill user_id from notification_tracker for legacy rows. # Safe to run repeatedly: only touches rows where user_id is still NULL. await conn.execute(text(""" @@ -250,6 +265,21 @@ async def migrate_schema(engine: AsyncEngine) -> None: ) logger.info("Added track_webhook_received column to tracking_config table") + # Add quiet hours to tracking_config if missing. + # Start/end are nullable HH:MM strings; quiet_hours_enabled gates them. + if await _has_table(conn, "tracking_config"): + if not await _has_column(conn, "tracking_config", "quiet_hours_enabled"): + await conn.execute( + text("ALTER TABLE tracking_config ADD COLUMN quiet_hours_enabled INTEGER DEFAULT 0") + ) + logger.info("Added quiet_hours_enabled column to tracking_config table") + for col_name in ("quiet_hours_start", "quiet_hours_end"): + if not await _has_column(conn, "tracking_config", col_name): + await conn.execute( + text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} TEXT") + ) + logger.info("Added %s column to tracking_config table", col_name) + # Drop legacy template content columns from template_config # (template content moved to template_slot child rows) if await _has_table(conn, "template_config"): diff --git a/packages/server/src/notify_bridge_server/services/notifier.py b/packages/server/src/notify_bridge_server/services/notifier.py index 2d1a0a9..d5b669a 100644 --- a/packages/server/src/notify_bridge_server/services/notifier.py +++ b/packages/server/src/notify_bridge_server/services/notifier.py @@ -1,5 +1,6 @@ """Notification sender — unified send logic for all paths (dispatch + test).""" +import asyncio import logging from typing import Any @@ -11,6 +12,10 @@ from ..database.models import NotificationTarget, TargetReceiver _LOGGER = logging.getLogger(__name__) +# Cap on concurrent per-receiver test sends. Keeps us under Telegram's per-bot +# rate limit (~30 msg/s) while still saving ~N×RTT on multi-chat broadcasts. +_TEST_SEND_CONCURRENCY = 5 + _TEST_MESSAGES: dict[str, dict[str, str]] = { "en": { "telegram": "\u2705 Test message from Notify Bridge", @@ -358,19 +363,29 @@ async def _send_telegram_test_per_receiver( http = await get_http_session() client = TelegramClient(http, bot_token) - results: list[dict] = [] - for r in recv_rows: + + # Parallelize per-receiver sends with a small semaphore — broadcast to + # N chats now takes ~ceil(N / concurrency) × RTT instead of N × RTT, + # matching the dispatcher's bounded-concurrency pattern. Capped below + # Telegram's rate limit so we don't trigger 429s on large fleets. + sem = asyncio.Semaphore(_TEST_SEND_CONCURRENCY) + + async def _send_one(r: TargetReceiver) -> dict | None: chat_id = str(r.config.get("chat_id", "")) if not chat_id: - continue + return None explicit = getattr(r, "locale", "") or "" locale = explicit or chat_locale_map.get(chat_id) or default_locale message = _get_test_message(locale[:2].lower(), "telegram") - results.append(await client.send_message( - chat_id=chat_id, - text=message, - disable_web_page_preview=bool(disable_preview), - )) + async with sem: + return await client.send_message( + chat_id=chat_id, + text=message, + disable_web_page_preview=bool(disable_preview), + ) + + raw = await asyncio.gather(*(_send_one(r) for r in recv_rows)) + results = [r for r in raw if r is not None] return _aggregate(results) diff --git a/packages/server/src/notify_bridge_server/services/pending_restore.py b/packages/server/src/notify_bridge_server/services/pending_restore.py index 17f71d0..c1522cf 100644 --- a/packages/server/src/notify_bridge_server/services/pending_restore.py +++ b/packages/server/src/notify_bridge_server/services/pending_restore.py @@ -10,10 +10,19 @@ If the apply fails, the pending file is kept so the operator can inspect it and markers are updated to record the last error. On success, the staged file is archived under data/applied_restores/.json and markers are cleared. + +Integrity checks on startup: +- The on-disk file's SHA256 must match ``PENDING_RESTORE_SHA256_KEY`` + (written atomically with the staged file). Protects against tampering + between prepare and restart. +- The pending path must resolve *inside* ``app_config.data_dir``. Protects + against a rogue AppSetting pointing at an arbitrary file. """ from __future__ import annotations +import asyncio +import hashlib import json import logging import shutil @@ -24,11 +33,13 @@ from sqlmodel.ext.asyncio.session import AsyncSession from ..api.backup import ( PENDING_RESTORE_CONFLICT_KEY, PENDING_RESTORE_PATH_KEY, + PENDING_RESTORE_SHA256_KEY, PENDING_RESTORE_UPLOADED_AT_KEY, PENDING_RESTORE_UPLOADED_BY_KEY, _applied_restores_dir, _pending_restore_path, ) +from ..config import settings as app_config from ..database.engine import get_engine from ..database.models import AppSetting from .backup_schema import BackupFile, ConflictMode @@ -49,6 +60,23 @@ async def apply_pending_restore_if_any() -> None: return pending_path = _pending_restore_path() + + # Defensive: ensure the hard-coded path still lives inside data_dir. + # If future refactors let this be read from AppSetting, this check + # blocks arbitrary-file reads. + try: + resolved = pending_path.resolve() + data_root = app_config.data_dir.resolve() + resolved.relative_to(data_root) + except (ValueError, OSError): + _LOGGER.error( + "Pending-restore path %s is outside data_dir %s — refusing to apply", + pending_path, app_config.data_dir, + ) + await _record_error(session, "Pending path outside data_dir") + await session.commit() + return + if not pending_path.exists(): _LOGGER.warning( "Pending-restore marker present but file missing at %s — clearing marker", @@ -62,9 +90,42 @@ async def apply_pending_restore_if_any() -> None: conflict_mode = ConflictMode(conflict_row.value) if conflict_row and conflict_row.value else ConflictMode.SKIP uploaded_by_row = await session.get(AppSetting, PENDING_RESTORE_UPLOADED_BY_KEY) uploaded_by = uploaded_by_row.value if uploaded_by_row else "admin" + sha_row = await session.get(AppSetting, PENDING_RESTORE_SHA256_KEY) + expected_sha = (sha_row.value or "").strip().lower() if sha_row else "" try: - raw = json.loads(pending_path.read_text(encoding="utf-8")) + raw_bytes = await asyncio.to_thread(pending_path.read_bytes) + except OSError as err: + _LOGGER.exception("Pending-restore file unreadable") + await _record_error(session, f"Unreadable backup: {err}") + await session.commit() + return + + # Integrity: reject unless hash matches what prepare-restore stored. + # An attacker with write access to data/ (swapped file, bind-mount + # abuse) does not also have write access to the DB. + if not expected_sha: + _LOGGER.error("Pending-restore marker has no SHA256; refusing to apply") + await _record_error(session, "Missing integrity marker") + await session.commit() + return + actual_sha = hashlib.sha256(raw_bytes).hexdigest() + if actual_sha != expected_sha: + _LOGGER.error( + "Pending-restore SHA256 mismatch (expected %s, got %s) — refusing to apply", + expected_sha, actual_sha, + ) + await _record_error( + session, + "Integrity check failed: on-disk backup SHA256 does not match the hash " + "recorded at prepare time. File may have been tampered with; cancel and " + "re-upload.", + ) + await session.commit() + return + + try: + raw = json.loads(raw_bytes.decode("utf-8")) backup = BackupFile.model_validate(raw) except Exception as err: # noqa: BLE001 _LOGGER.exception("Pending-restore file unreadable") @@ -88,8 +149,14 @@ async def apply_pending_restore_if_any() -> None: result = await import_backup(session, admin_row.id, backup, conflict_mode) except Exception as err: # noqa: BLE001 _LOGGER.exception("Pending-restore apply failed") - await _record_error(session, str(err)) - await session.commit() + # Discard any partial inserts the importer made before raising — + # committing partial state would let a crafted failing backup + # selectively mutate entities. The error-record commit below + # happens on a *fresh* session. + await session.rollback() + async with AsyncSession(engine) as fresh: + await _record_error(fresh, str(err)) + await fresh.commit() return # Archive the file @@ -136,6 +203,7 @@ async def _clear_markers(session: AsyncSession) -> None: PENDING_RESTORE_CONFLICT_KEY, PENDING_RESTORE_UPLOADED_AT_KEY, PENDING_RESTORE_UPLOADED_BY_KEY, + PENDING_RESTORE_SHA256_KEY, ): row = await session.get(AppSetting, key) if row: diff --git a/packages/server/src/notify_bridge_server/services/scheduler.py b/packages/server/src/notify_bridge_server/services/scheduler.py index 2bc495d..18b89f2 100644 --- a/packages/server/src/notify_bridge_server/services/scheduler.py +++ b/packages/server/src/notify_bridge_server/services/scheduler.py @@ -166,30 +166,41 @@ async def _refresh_telegram_chat_titles() -> None: refreshed = 0 errors = 0 + # Bucket results first, then fetch all rows in one IN-query instead of + # per-row ``session.get`` — otherwise a 50-chat fleet issues 50 extra + # SELECTs before commit. + successes: dict[int, dict] = {} + for chat_id, info, err in results: + if err is not None or info is None: + errors += 1 + if err: + _LOGGER.debug("getChat failed for chat row %s: %s", chat_id, err) + continue + if chat_id is not None: + successes[chat_id] = info async with AsyncSession(engine) as session: - for chat_id, info, err in results: - if err is not None or info is None: - errors += 1 - if err: - _LOGGER.debug("getChat failed for chat row %s: %s", chat_id, err) - continue - merged = await session.get(TelegramChat, chat_id) - if not merged: - continue - title = info.get("title") or ( - (info.get("first_name", "") + " " + info.get("last_name", "")).strip() - ) - changed = False - if title and merged.title != title: - merged.title = title - changed = True - new_username = info.get("username") - if new_username is not None and merged.username != new_username: - merged.username = new_username - changed = True - if changed: - session.add(merged) - refreshed += 1 + if successes: + rows = (await session.exec( + select(TelegramChat).where(TelegramChat.id.in_(list(successes.keys()))) + )).all() + for merged in rows: + info = successes.get(merged.id) + if not info: + continue + title = info.get("title") or ( + (info.get("first_name", "") + " " + info.get("last_name", "")).strip() + ) + changed = False + if title and merged.title != title: + merged.title = title + changed = True + new_username = info.get("username") + if new_username is not None and merged.username != new_username: + merged.username = new_username + changed = True + if changed: + session.add(merged) + refreshed += 1 await session.commit() _LOGGER.info( "Telegram chat title refresh: %s updated, %s errors", refreshed, errors diff --git a/packages/server/src/notify_bridge_server/services/test_dispatch.py b/packages/server/src/notify_bridge_server/services/test_dispatch.py index a8676da..c8d4bf7 100644 --- a/packages/server/src/notify_bridge_server/services/test_dispatch.py +++ b/packages/server/src/notify_bridge_server/services/test_dispatch.py @@ -250,14 +250,23 @@ async def _build_immich_event( collection_ids, limit, asset_type, favorite_only, min_rating, ) - # Album-based path: use shared collect_scheduled_assets + # Album-based path: use shared collect_scheduled_assets. + # Fetch albums + shared links in parallel — on a 20-album tracker the old + # serial ``await`` loop took ~2 × 20 × RTT, now it's one round-trip. + import asyncio as _asyncio + album_tasks = [immich.client.get_album(aid) for aid in collection_ids] + link_tasks = [immich.client.get_shared_links(aid) for aid in collection_ids] + album_results, link_results = await _asyncio.gather( + _asyncio.gather(*album_tasks, return_exceptions=True), + _asyncio.gather(*link_tasks, return_exceptions=True), + ) albums: dict[str, ImmichAlbumData] = {} shared_links: dict[str, list[SharedLinkInfo]] = {} - for album_id in collection_ids: - album = await immich.client.get_album(album_id) - if album: - albums[album_id] = album - shared_links[album_id] = await immich.client.get_shared_links(album_id) + for album_id, album, links in zip(collection_ids, album_results, link_results): + if isinstance(album, Exception) or album is None: + continue + albums[album_id] = album + shared_links[album_id] = links if not isinstance(links, Exception) else [] assets, collections_extra = collect_scheduled_assets( albums, shared_links, ext_domain, @@ -320,13 +329,21 @@ async def _build_immich_periodic_event( ext_domain = provider_config.get("external_domain") or provider_config.get("url", "") + # Parallel fetch — see _build_immich_event above for the same rationale. + import asyncio as _asyncio + album_tasks = [immich.client.get_album(aid) for aid in collection_ids] + link_tasks = [immich.client.get_shared_links(aid) for aid in collection_ids] + album_results, link_results = await _asyncio.gather( + _asyncio.gather(*album_tasks, return_exceptions=True), + _asyncio.gather(*link_tasks, return_exceptions=True), + ) albums: dict[str, ImmichAlbumData] = {} shared_links: dict[str, list[SharedLinkInfo]] = {} - for album_id in collection_ids: - album = await immich.client.get_album(album_id) - if album: - albums[album_id] = album - shared_links[album_id] = await immich.client.get_shared_links(album_id) + for album_id, album, links in zip(collection_ids, album_results, link_results): + if isinstance(album, Exception) or album is None: + continue + albums[album_id] = album + shared_links[album_id] = links if not isinstance(links, Exception) else [] # limit=0 → returns ([], collections_extra) with full per-album stats. _assets, collections_extra = collect_scheduled_assets( diff --git a/scripts/restart-backend.sh b/scripts/restart-backend.sh index a5c6ea9..b846629 100644 --- a/scripts/restart-backend.sh +++ b/scripts/restart-backend.sh @@ -25,6 +25,9 @@ fi # Start backend export NOTIFY_BRIDGE_DATA_DIR=./test-data export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars +# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF +# guard rejects private addresses by default, which would make trackers fail. +export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1 nohup "$PYTHON" -m uvicorn notify_bridge_server.main:app \ --host 0.0.0.0 --port 8420 > .backend.log 2>&1 &