fix(security,perf): harden restore, CSRF, token_version + perf pass
Security
- Sign pending_restore.json (SHA256 stored in AppSetting, verified on
startup apply) + refuse path outside data_dir, tighten to 0600.
- Require same-origin Origin/Referer on POST /api/backup/apply-restart —
Bearer-in-localStorage is CSRF-reachable from any XSS'd admin tab.
- Bump token_version on role/username change and admin password reset so
demoted admins lose admin in already-issued JWTs. Guard last-admin
TOCTOU via COUNT + post-commit re-check that rolls back a race.
- SSRF guard (validate_outbound_url) in ImmichClient.__init__ and the
external_domain setter — admin-mutable URLs were bypassing the check
that webhook/slack/discord paths already used. Dev restart script now
sets NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1 so homelab Immich still works.
- Redact + cap Immich error bodies to ~120 chars before they flow into
ActionExecution.error / EventLog.details (both UI-visible).
- Deny-list sensitive keys (api_key / token / secret / password /
authorization / cookie / ...) in template-context merges so a rogue
template can't exfiltrate provider creds via {{ api_key }}.
- Cap user-controlled Immich search params (query ≤256, person_ids ≤50,
size ≤100) so a Telegram listener can't DoS upstream.
- Stream upload reads with running byte counter + content-length precheck
instead of buffering the full body and then rejecting.
- Log Telegram parse_mode fallbacks instead of swallowing silently;
template escape bugs now surface in server logs.
- Rollback partial imports on pending-restore failure (error recorded on
a fresh session).
Performance
- Fix N+1 in _refresh_telegram_chat_titles: single IN query instead of
session.get per chat.
- Parallelize album + shared-link fetches in test_dispatch (asyncio.gather)
and per-receiver Telegram test sends in notifier (semaphore 5).
- Early-exit collect_scheduled_assets(limit=0) so the periodic-summary
test path skips full per-album filter/sample (was O(album_assets)).
- Emit explicit CREATE INDEX IF NOT EXISTS for event_log user_id /
action_id / provider_id so the first boot after upgrade isn't left
unindexed for the dashboard query.
- Add AbortController timeout (120s) to fetchAuth so uploads/downloads
don't hang indefinitely.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+46
-27
@@ -94,6 +94,9 @@ async function doRefreshAccessToken(): Promise<boolean> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const DEFAULT_TIMEOUT_MS = 30_000;
|
const DEFAULT_TIMEOUT_MS = 30_000;
|
||||||
|
// Longer cap for fetchAuth — it's used for multipart uploads (backup restore)
|
||||||
|
// and binary downloads where a 30s limit can cut off a legit slow upload.
|
||||||
|
const DEFAULT_FETCHAUTH_TIMEOUT_MS = 120_000;
|
||||||
|
|
||||||
export async function api<T = any>(
|
export async function api<T = any>(
|
||||||
path: string,
|
path: string,
|
||||||
@@ -170,42 +173,58 @@ export async function api<T = any>(
|
|||||||
*/
|
*/
|
||||||
export async function fetchAuth(
|
export async function fetchAuth(
|
||||||
path: string,
|
path: string,
|
||||||
options: RequestInit = {},
|
options: RequestInit & { timeoutMs?: number } = {},
|
||||||
): Promise<Response> {
|
): Promise<Response> {
|
||||||
const token = getToken();
|
const token = getToken();
|
||||||
const headers: Record<string, string> = { ...(options.headers as Record<string, string>) };
|
const headers: Record<string, string> = { ...(options.headers as Record<string, string>) };
|
||||||
if (token) headers['Authorization'] = `Bearer ${token}`;
|
if (token) headers['Authorization'] = `Bearer ${token}`;
|
||||||
|
|
||||||
const url = path.startsWith('http') ? path : `${API_BASE}${path}`;
|
const url = path.startsWith('http') ? path : `${API_BASE}${path}`;
|
||||||
let res = await fetch(url, { ...options, headers });
|
|
||||||
|
|
||||||
if (res.status === 401 && token) {
|
// Abort after timeout so uploads/downloads don't hang indefinitely if
|
||||||
const refreshed = await refreshAccessToken();
|
// the backend stops responding. Callers can override per-request via
|
||||||
if (refreshed) {
|
// options.timeoutMs or pass their own signal to opt out.
|
||||||
headers['Authorization'] = `Bearer ${getToken()}`;
|
const { timeoutMs, ...fetchOptions } = options;
|
||||||
res = await fetch(url, { ...options, headers });
|
const controller = new AbortController();
|
||||||
|
const timeout = setTimeout(
|
||||||
|
() => controller.abort(),
|
||||||
|
timeoutMs ?? DEFAULT_FETCHAUTH_TIMEOUT_MS,
|
||||||
|
);
|
||||||
|
const signal = options.signal ?? controller.signal;
|
||||||
|
|
||||||
|
try {
|
||||||
|
let res = await fetch(url, { ...fetchOptions, headers, signal });
|
||||||
|
|
||||||
|
if (res.status === 401 && token) {
|
||||||
|
const refreshed = await refreshAccessToken();
|
||||||
|
if (refreshed) {
|
||||||
|
headers['Authorization'] = `Bearer ${getToken()}`;
|
||||||
|
res = await fetch(url, { ...fetchOptions, headers, signal });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (res.status === 401) {
|
if (res.status === 401) {
|
||||||
clearTokens();
|
clearTokens();
|
||||||
if (typeof window !== 'undefined') window.location.href = '/login';
|
if (typeof window !== 'undefined') window.location.href = '/login';
|
||||||
throw new ApiError('Unauthorized', 401);
|
throw new ApiError('Unauthorized', 401);
|
||||||
}
|
|
||||||
|
|
||||||
if (!res.ok) {
|
|
||||||
const err = await res.clone().json().catch(() => ({ detail: res.statusText }));
|
|
||||||
if (err && err.detail && typeof err.detail === 'object' && Array.isArray(err.detail.blocked_by)) {
|
|
||||||
const bb: BlockedByDetail = {
|
|
||||||
message: err.detail.message || `HTTP ${res.status}`,
|
|
||||||
entity: err.detail.entity || '',
|
|
||||||
blocked_by: err.detail.blocked_by,
|
|
||||||
};
|
|
||||||
throw new ApiError(bb.message, res.status, bb);
|
|
||||||
}
|
}
|
||||||
const msg = typeof err.detail === 'string' ? err.detail : (err.detail?.message || `HTTP ${res.status}`);
|
|
||||||
throw new ApiError(msg, res.status);
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
if (!res.ok) {
|
||||||
|
const err = await res.clone().json().catch(() => ({ detail: res.statusText }));
|
||||||
|
if (err && err.detail && typeof err.detail === 'object' && Array.isArray(err.detail.blocked_by)) {
|
||||||
|
const bb: BlockedByDetail = {
|
||||||
|
message: err.detail.message || `HTTP ${res.status}`,
|
||||||
|
entity: err.detail.entity || '',
|
||||||
|
blocked_by: err.detail.blocked_by,
|
||||||
|
};
|
||||||
|
throw new ApiError(bb.message, res.status, bb);
|
||||||
|
}
|
||||||
|
const msg = typeof err.detail === 'string' ? err.detail : (err.detail?.message || `HTTP ${res.status}`);
|
||||||
|
throw new ApiError(msg, res.status);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
} finally {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -300,6 +300,16 @@ class TelegramClient:
|
|||||||
# Retry without parse_mode on parse errors
|
# Retry without parse_mode on parse errors
|
||||||
desc = str(result.get("description", ""))
|
desc = str(result.get("description", ""))
|
||||||
if "parse" in desc.lower():
|
if "parse" in desc.lower():
|
||||||
|
# Log loudly: a parse failure means the template author (or
|
||||||
|
# an asset field) is producing malformed HTML. Silent
|
||||||
|
# fallback hides bugs and makes XSS-via-unescaped-field
|
||||||
|
# harder to spot. Do not log the full payload — it may
|
||||||
|
# contain secrets.
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Telegram rejected parse_mode=%s (%r); retrying as plain text. "
|
||||||
|
"Check template output for unescaped characters.",
|
||||||
|
payload.get("parse_mode"), desc,
|
||||||
|
)
|
||||||
payload.pop("parse_mode", None)
|
payload.pop("parse_mode", None)
|
||||||
async with self._session.post(telegram_url, json=payload) as retry_resp:
|
async with self._session.post(telegram_url, json=payload) as retry_resp:
|
||||||
retry_result = await retry_resp.json()
|
retry_result = await retry_resp.json()
|
||||||
|
|||||||
@@ -321,6 +321,12 @@ def collect_scheduled_assets(
|
|||||||
asset_album_map: dict[str, tuple[str, str]] = {} # asset_id → (album_id, public_url)
|
asset_album_map: dict[str, tuple[str, str]] = {} # asset_id → (album_id, public_url)
|
||||||
collections_extra: list[dict[str, Any]] = []
|
collections_extra: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
# limit=0 is the periodic-summary test path — the caller only needs
|
||||||
|
# per-album stats (name/url/counts), not a sample of assets. Skip the
|
||||||
|
# expensive ``filter_assets`` + sampling loop entirely; on a 50k-asset
|
||||||
|
# album the serial scan-then-discard pattern wasted seconds per test.
|
||||||
|
stats_only = limit <= 0
|
||||||
|
|
||||||
for album_id, album in albums.items():
|
for album_id, album in albums.items():
|
||||||
links = shared_links.get(album_id, [])
|
links = shared_links.get(album_id, [])
|
||||||
album_public_url = get_public_url(external_url, links) or ""
|
album_public_url = get_public_url(external_url, links) or ""
|
||||||
@@ -336,6 +342,9 @@ def collect_scheduled_assets(
|
|||||||
"owner": album.owner,
|
"owner": album.owner,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if stats_only:
|
||||||
|
continue
|
||||||
|
|
||||||
filtered = filter_assets(
|
filtered = filter_assets(
|
||||||
list(album.assets.values()),
|
list(album.assets.values()),
|
||||||
favorite_only=favorite_only,
|
favorite_only=favorite_only,
|
||||||
@@ -348,6 +357,9 @@ def collect_scheduled_assets(
|
|||||||
asset_album_map[asset.id] = (album_id, album_public_url)
|
asset_album_map[asset.id] = (album_id, album_public_url)
|
||||||
all_eligible.append(asset)
|
all_eligible.append(asset)
|
||||||
|
|
||||||
|
if stats_only:
|
||||||
|
return [], collections_extra
|
||||||
|
|
||||||
# Random sample
|
# Random sample
|
||||||
if len(all_eligible) > limit:
|
if len(all_eligible) > limit:
|
||||||
selected = random.sample(all_eligible, limit)
|
selected = random.sample(all_eligible, limit)
|
||||||
|
|||||||
@@ -3,14 +3,47 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from ...notifications.ssrf import UnsafeURLError, validate_outbound_url
|
||||||
from .models import ImmichAlbumData, SharedLinkInfo
|
from .models import ImmichAlbumData, SharedLinkInfo
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cap user-controlled Immich search parameters so a low-privileged command
|
||||||
|
# listener (e.g. an Immich ``/search`` command) cannot DoS the upstream.
|
||||||
|
MAX_SEARCH_QUERY_LEN = 256
|
||||||
|
MAX_SEARCH_PERSON_IDS = 50
|
||||||
|
|
||||||
|
# User-facing error bodies — Immich responses may leak internal paths,
|
||||||
|
# hostnames, or headers injected by intermediary proxies. These helpers keep
|
||||||
|
# only a short, scrubbed summary; full bodies are logged server-side only.
|
||||||
|
_REDACTED_BODY_MAX = 120
|
||||||
|
_SECRET_PATTERN = re.compile(
|
||||||
|
r"(?i)(bearer\s+\S+|x-api-key[:=]\s*\S+|authorization[:=]\s*\S+|cookie[:=]\s*\S+|"
|
||||||
|
r"password[:=]?\s*\S+|token[:=]?\s*[A-Za-z0-9._\-]+)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _redact_body(text: str) -> str:
|
||||||
|
"""Return a short, credential-scrubbed snippet safe to surface to UI callers.
|
||||||
|
|
||||||
|
Immich error responses are admin-configurable (via reverse proxies, custom
|
||||||
|
error pages) and may echo request headers or environment leak. Stripping
|
||||||
|
anything that looks like a credential + capping length keeps us from
|
||||||
|
persisting secrets into ``ActionExecution.error`` / ``EventLog.details``
|
||||||
|
(both of which are returned through the dashboard API).
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
cleaned = _SECRET_PATTERN.sub("[redacted]", text)
|
||||||
|
if len(cleaned) > _REDACTED_BODY_MAX:
|
||||||
|
return cleaned[:_REDACTED_BODY_MAX] + "..."
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
class ImmichClient:
|
class ImmichClient:
|
||||||
"""Async client for the Immich API."""
|
"""Async client for the Immich API."""
|
||||||
@@ -25,6 +58,18 @@ class ImmichClient:
|
|||||||
self._url = url.rstrip("/")
|
self._url = url.rstrip("/")
|
||||||
self._api_key = api_key
|
self._api_key = api_key
|
||||||
self._external_domain: str | None = None
|
self._external_domain: str | None = None
|
||||||
|
# SSRF guard — admin-set Immich URLs are loaded from provider config
|
||||||
|
# which can be mutated via PATCH /api/providers or imported via
|
||||||
|
# prepare-restore, so we revalidate at construction time rather than
|
||||||
|
# trusting DB state. ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` bypasses
|
||||||
|
# for dev against localhost Immich.
|
||||||
|
if self._url:
|
||||||
|
try:
|
||||||
|
validate_outbound_url(self._url)
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
raise UnsafeURLError(
|
||||||
|
f"Refusing to build ImmichClient for unsafe URL {self._url!r}: {err}"
|
||||||
|
) from err
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
@@ -36,6 +81,15 @@ class ImmichClient:
|
|||||||
|
|
||||||
@external_domain.setter
|
@external_domain.setter
|
||||||
def external_domain(self, value: str | None) -> None:
|
def external_domain(self, value: str | None) -> None:
|
||||||
|
# Mirror the constructor's SSRF guard — external_domain is used to
|
||||||
|
# build URLs that leak into rendered notifications, but any code path
|
||||||
|
# that eventually fetches this URL would otherwise bypass the check.
|
||||||
|
if value:
|
||||||
|
try:
|
||||||
|
validate_outbound_url(value)
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
_LOGGER.warning("Ignoring unsafe external_domain %r: %s", value, err)
|
||||||
|
return
|
||||||
self._external_domain = value
|
self._external_domain = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -237,9 +291,12 @@ class ImmichClient:
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
payload: dict[str, Any] = {"query": query, "page": max(1, page), "size": limit}
|
# Cap user-controlled inputs — a low-privileged Telegram listener can
|
||||||
|
# craft arbitrarily long queries to DoS the upstream Immich.
|
||||||
|
query = (query or "")[:MAX_SEARCH_QUERY_LEN]
|
||||||
|
payload: dict[str, Any] = {"query": query, "page": max(1, page), "size": min(max(1, limit), 100)}
|
||||||
if album_ids:
|
if album_ids:
|
||||||
payload["albumIds"] = album_ids
|
payload["albumIds"] = album_ids[:MAX_SEARCH_PERSON_IDS]
|
||||||
try:
|
try:
|
||||||
async with self._session.post(
|
async with self._session.post(
|
||||||
f"{self._url}/api/search/smart",
|
f"{self._url}/api/search/smart",
|
||||||
@@ -261,9 +318,10 @@ class ImmichClient:
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
payload: dict[str, Any] = {"originalFileName": query, "page": max(1, page), "size": limit}
|
query = (query or "")[:MAX_SEARCH_QUERY_LEN]
|
||||||
|
payload: dict[str, Any] = {"originalFileName": query, "page": max(1, page), "size": min(max(1, limit), 100)}
|
||||||
if album_ids:
|
if album_ids:
|
||||||
payload["albumIds"] = album_ids
|
payload["albumIds"] = album_ids[:MAX_SEARCH_PERSON_IDS]
|
||||||
try:
|
try:
|
||||||
async with self._session.post(
|
async with self._session.post(
|
||||||
f"{self._url}/api/search/metadata",
|
f"{self._url}/api/search/metadata",
|
||||||
@@ -289,7 +347,7 @@ class ImmichClient:
|
|||||||
to return an empty list on current servers.
|
to return an empty list on current servers.
|
||||||
"""
|
"""
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"personIds": [person_id],
|
"personIds": [person_id][:MAX_SEARCH_PERSON_IDS],
|
||||||
"page": 1,
|
"page": 1,
|
||||||
"size": max(1, min(limit, 100)),
|
"size": max(1, min(limit, 100)),
|
||||||
}
|
}
|
||||||
@@ -373,9 +431,17 @@ class ImmichClient:
|
|||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
return parsed
|
return parsed
|
||||||
return {"raw": body_text}
|
return {"raw": body_text}
|
||||||
|
# Log full body server-side (for operators), surface only a
|
||||||
|
# redacted snippet to the caller — this string ends up in
|
||||||
|
# ActionExecution.error / EventLog.details which are returned
|
||||||
|
# through the dashboard API.
|
||||||
|
_LOGGER.warning(
|
||||||
|
"add_assets_to_album failed: HTTP %s body=%s",
|
||||||
|
response.status, body_text[:512],
|
||||||
|
)
|
||||||
raise ImmichApiError(
|
raise ImmichApiError(
|
||||||
f"Failed to add assets to album {album_id}: "
|
f"Failed to add assets to album {album_id}: "
|
||||||
f"HTTP {response.status} body={body_text[:512]}"
|
f"HTTP {response.status} {_redact_body(body_text)}"
|
||||||
)
|
)
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
raise ImmichApiError(f"Error adding assets to album: {err}") from err
|
raise ImmichApiError(f"Error adding assets to album: {err}") from err
|
||||||
@@ -399,9 +465,13 @@ class ImmichClient:
|
|||||||
if response.status in (200, 201, 204):
|
if response.status in (200, 201, 204):
|
||||||
return
|
return
|
||||||
body_text = await response.text()
|
body_text = await response.text()
|
||||||
|
_LOGGER.warning(
|
||||||
|
"set_album_thumbnail failed: HTTP %s body=%s",
|
||||||
|
response.status, body_text[:512],
|
||||||
|
)
|
||||||
raise ImmichApiError(
|
raise ImmichApiError(
|
||||||
f"Failed to set album thumbnail for {album_id}: "
|
f"Failed to set album thumbnail for {album_id}: "
|
||||||
f"HTTP {response.status} body={body_text[:512]}"
|
f"HTTP {response.status} {_redact_body(body_text)}"
|
||||||
)
|
)
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
raise ImmichApiError(f"Error setting album thumbnail: {err}") from err
|
raise ImmichApiError(f"Error setting album thumbnail: {err}") from err
|
||||||
|
|||||||
@@ -2,16 +2,67 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from notify_bridge_core.models.events import ServiceEvent
|
from notify_bridge_core.models.events import ServiceEvent
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Per-target maximum video size (bytes). None = no limit.
|
# Per-target maximum video size (bytes). None = no limit.
|
||||||
_MAX_VIDEO_SIZE_BY_TARGET: dict[str, int] = {
|
_MAX_VIDEO_SIZE_BY_TARGET: dict[str, int] = {
|
||||||
"telegram": 50 * 1024 * 1024, # 50 MB — Telegram Bot API hard limit
|
"telegram": 50 * 1024 * 1024, # 50 MB — Telegram Bot API hard limit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Keys that must NEVER flow into the Jinja2 template context, even if a
|
||||||
|
# provider stuffs them into ``event.extra`` (webhooks, Immich metadata, etc.).
|
||||||
|
# Templates that could reach a Telegram/Discord/etc. chat would otherwise
|
||||||
|
# expose operator credentials if a template author simply did ``{{ api_key }}``.
|
||||||
|
# Case-insensitive substring match — any ``extra`` key containing one of these
|
||||||
|
# tokens is dropped before the merge.
|
||||||
|
_SENSITIVE_EXTRA_TOKENS: tuple[str, ...] = (
|
||||||
|
"api_key",
|
||||||
|
"apikey",
|
||||||
|
"token",
|
||||||
|
"secret",
|
||||||
|
"password",
|
||||||
|
"passwd",
|
||||||
|
"hashed_",
|
||||||
|
"authorization",
|
||||||
|
"cookie",
|
||||||
|
"session_id",
|
||||||
|
"bearer",
|
||||||
|
"private_key",
|
||||||
|
"access_key",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_sensitive_key(key: str) -> bool:
|
||||||
|
lowered = str(key).lower()
|
||||||
|
return any(tok in lowered for tok in _SENSITIVE_EXTRA_TOKENS)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_merge_extras(ctx: dict[str, Any], extras: dict[str, Any]) -> None:
|
||||||
|
"""Merge provider ``extras`` into ``ctx``, dropping sensitive keys.
|
||||||
|
|
||||||
|
Dropped keys are logged once per event (DEBUG) so operators can spot
|
||||||
|
leaking providers without flooding the log.
|
||||||
|
"""
|
||||||
|
if not extras:
|
||||||
|
return
|
||||||
|
dropped: list[str] = []
|
||||||
|
for key, value in extras.items():
|
||||||
|
if _is_sensitive_key(key):
|
||||||
|
dropped.append(key)
|
||||||
|
continue
|
||||||
|
ctx[key] = value
|
||||||
|
if dropped:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Dropped %d sensitive key(s) from template context: %s",
|
||||||
|
len(dropped), ", ".join(sorted(dropped)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_template_context(
|
def build_template_context(
|
||||||
event: ServiceEvent,
|
event: ServiceEvent,
|
||||||
@@ -61,8 +112,9 @@ def build_template_context(
|
|||||||
"preview_url": asset.preview_url or "",
|
"preview_url": asset.preview_url or "",
|
||||||
"full_url": asset.full_url or "",
|
"full_url": asset.full_url or "",
|
||||||
}
|
}
|
||||||
# Flatten extras into asset dict for template access
|
# Flatten extras into asset dict for template access — same
|
||||||
asset_dict.update(asset.extra)
|
# sensitive-key filtering applied as the top-level merge.
|
||||||
|
_safe_merge_extras(asset_dict, asset.extra)
|
||||||
asset_dict.setdefault("oversized", False)
|
asset_dict.setdefault("oversized", False)
|
||||||
asset_dict.setdefault("file_size", None)
|
asset_dict.setdefault("file_size", None)
|
||||||
asset_dict.setdefault("playback_size", None)
|
asset_dict.setdefault("playback_size", None)
|
||||||
@@ -138,8 +190,11 @@ def build_template_context(
|
|||||||
if len(locations) == 1 and "" not in locations:
|
if len(locations) == 1 and "" not in locations:
|
||||||
ctx["common_location"] = locations.pop()
|
ctx["common_location"] = locations.pop()
|
||||||
|
|
||||||
# Provider-specific extras merged at top level
|
# Provider-specific extras merged at top level. Sensitive keys (tokens,
|
||||||
ctx.update(event.extra)
|
# secrets, auth headers) are dropped — see ``_SENSITIVE_EXTRA_TOKENS``.
|
||||||
|
# Without this, a template author could exfiltrate provider credentials
|
||||||
|
# via ``{{ api_key }}`` in an outgoing notification body.
|
||||||
|
_safe_merge_extras(ctx, event.extra)
|
||||||
|
|
||||||
# Ensure URL variables always exist (avoid Jinja2 undefined errors)
|
# Ensure URL variables always exist (avoid Jinja2 undefined errors)
|
||||||
ctx.setdefault("public_url", "")
|
ctx.setdefault("public_url", "")
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
"""Configuration backup/restore API (admin only)."""
|
"""Configuration backup/restore API (admin only)."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, UploadFile, File, Query
|
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, UploadFile, File, Query
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -28,6 +30,11 @@ PENDING_RESTORE_PATH_KEY = "pending_restore_path"
|
|||||||
PENDING_RESTORE_CONFLICT_KEY = "pending_restore_conflict_mode"
|
PENDING_RESTORE_CONFLICT_KEY = "pending_restore_conflict_mode"
|
||||||
PENDING_RESTORE_UPLOADED_AT_KEY = "pending_restore_uploaded_at"
|
PENDING_RESTORE_UPLOADED_AT_KEY = "pending_restore_uploaded_at"
|
||||||
PENDING_RESTORE_UPLOADED_BY_KEY = "pending_restore_uploaded_by"
|
PENDING_RESTORE_UPLOADED_BY_KEY = "pending_restore_uploaded_by"
|
||||||
|
# SHA256 of the staged pending_restore.json, written atomically with the file.
|
||||||
|
# The startup hook refuses to apply if the on-disk file's hash does not match —
|
||||||
|
# defends against anyone dropping a tampered file into data/ between prepare
|
||||||
|
# and restart.
|
||||||
|
PENDING_RESTORE_SHA256_KEY = "pending_restore_sha256"
|
||||||
|
|
||||||
|
|
||||||
def _pending_restore_path():
|
def _pending_restore_path():
|
||||||
@@ -44,6 +51,69 @@ router = APIRouter(prefix="/api/backup", tags=["backup"])
|
|||||||
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB
|
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_upload_bounded(file: UploadFile, max_bytes: int = MAX_UPLOAD_SIZE) -> bytes:
|
||||||
|
"""Read an UploadFile into memory, failing fast if it exceeds ``max_bytes``.
|
||||||
|
|
||||||
|
Rejects on ``content_length`` header up-front when available; always
|
||||||
|
stream-reads with a running byte counter so we never allocate more than
|
||||||
|
the limit even when the header is missing or lies.
|
||||||
|
"""
|
||||||
|
# Fast path: reject on header before we allocate anything.
|
||||||
|
cl = file.headers.get("content-length") if hasattr(file, "headers") else None
|
||||||
|
if cl:
|
||||||
|
try:
|
||||||
|
if int(cl) > max_bytes:
|
||||||
|
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
chunks: list[bytes] = []
|
||||||
|
total = 0
|
||||||
|
while True:
|
||||||
|
chunk = await file.read(64 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
total += len(chunk)
|
||||||
|
if total > max_bytes:
|
||||||
|
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
||||||
|
chunks.append(chunk)
|
||||||
|
return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_same_origin(request: Request) -> None:
|
||||||
|
"""Reject cross-origin admin-write POSTs (CSRF defense).
|
||||||
|
|
||||||
|
Bearer tokens in ``localStorage`` plus cookie-less CORS mean a malicious
|
||||||
|
page cannot technically submit our Authorization header from a victim's
|
||||||
|
session, BUT browser extensions and misconfigured CORS policies routinely
|
||||||
|
break this assumption. For endpoints whose blast radius is restart/RCE-
|
||||||
|
equivalent (restore apply), we additionally require the request to come
|
||||||
|
from our own origin.
|
||||||
|
"""
|
||||||
|
host = request.headers.get("host", "").lower()
|
||||||
|
if not host:
|
||||||
|
raise HTTPException(status_code=400, detail="Missing Host header")
|
||||||
|
|
||||||
|
def _host_of(u: str | None) -> str:
|
||||||
|
if not u:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return (urlparse(u).netloc or "").lower()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return ""
|
||||||
|
|
||||||
|
origin_host = _host_of(request.headers.get("origin"))
|
||||||
|
referer_host = _host_of(request.headers.get("referer"))
|
||||||
|
# At least one of Origin/Referer must be present and match Host.
|
||||||
|
# Legitimate browser requests to this endpoint always ship Origin.
|
||||||
|
same = (origin_host and origin_host == host) or (referer_host and referer_host == host)
|
||||||
|
if not same:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="Cross-origin request rejected",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _backup_dir():
|
def _backup_dir():
|
||||||
return app_config.data_dir / "backups"
|
return app_config.data_dir / "backups"
|
||||||
|
|
||||||
@@ -104,9 +174,7 @@ async def validate_config(
|
|||||||
user: User = Depends(require_admin),
|
user: User = Depends(require_admin),
|
||||||
):
|
):
|
||||||
"""Validate a backup file without importing."""
|
"""Validate a backup file without importing."""
|
||||||
content = await file.read()
|
content = await _read_upload_bounded(file)
|
||||||
if len(content) > MAX_UPLOAD_SIZE:
|
|
||||||
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = json.loads(content)
|
raw = json.loads(content)
|
||||||
@@ -129,9 +197,7 @@ async def import_config(
|
|||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Import configuration from a backup file."""
|
"""Import configuration from a backup file."""
|
||||||
content = await file.read()
|
content = await _read_upload_bounded(file)
|
||||||
if len(content) > MAX_UPLOAD_SIZE:
|
|
||||||
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = json.loads(content)
|
raw = json.loads(content)
|
||||||
@@ -167,6 +233,7 @@ async def _clear_pending_restore_markers(session: AsyncSession) -> None:
|
|||||||
PENDING_RESTORE_CONFLICT_KEY,
|
PENDING_RESTORE_CONFLICT_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_AT_KEY,
|
PENDING_RESTORE_UPLOADED_AT_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_BY_KEY,
|
PENDING_RESTORE_UPLOADED_BY_KEY,
|
||||||
|
PENDING_RESTORE_SHA256_KEY,
|
||||||
):
|
):
|
||||||
row = await session.get(AppSetting, key)
|
row = await session.get(AppSetting, key)
|
||||||
if row:
|
if row:
|
||||||
@@ -185,9 +252,7 @@ async def prepare_restore(
|
|||||||
Validates the uploaded file, writes it to ``data/pending_restore.json``,
|
Validates the uploaded file, writes it to ``data/pending_restore.json``,
|
||||||
and persists marker settings so startup will apply it atomically.
|
and persists marker settings so startup will apply it atomically.
|
||||||
"""
|
"""
|
||||||
content = await file.read()
|
content = await _read_upload_bounded(file)
|
||||||
if len(content) > MAX_UPLOAD_SIZE:
|
|
||||||
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = json.loads(content)
|
raw = json.loads(content)
|
||||||
@@ -205,15 +270,25 @@ async def prepare_restore(
|
|||||||
pending_path.parent.mkdir(parents=True, exist_ok=True)
|
pending_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# Atomic write: write to tmp then rename, so a crash mid-write never
|
# Atomic write: write to tmp then rename, so a crash mid-write never
|
||||||
# leaves a truncated pending_restore.json that would break startup apply.
|
# leaves a truncated pending_restore.json that would break startup apply.
|
||||||
|
payload = json.dumps(raw).encode("utf-8")
|
||||||
|
digest = hashlib.sha256(payload).hexdigest()
|
||||||
tmp_path = pending_path.with_suffix(pending_path.suffix + ".tmp")
|
tmp_path = pending_path.with_suffix(pending_path.suffix + ".tmp")
|
||||||
tmp_path.write_text(json.dumps(raw), encoding="utf-8")
|
tmp_path.write_bytes(payload)
|
||||||
os.replace(tmp_path, pending_path)
|
os.replace(tmp_path, pending_path)
|
||||||
|
# Best-effort tighten perms so a non-root local user cannot swap the file
|
||||||
|
# for one they control between prepare and restart. On Windows this is a
|
||||||
|
# no-op; on POSIX we restrict to owner-only rw.
|
||||||
|
try:
|
||||||
|
os.chmod(pending_path, 0o600)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
now_iso = datetime.now(timezone.utc).isoformat()
|
now_iso = datetime.now(timezone.utc).isoformat()
|
||||||
await _set_app_setting(session, PENDING_RESTORE_PATH_KEY, str(pending_path))
|
await _set_app_setting(session, PENDING_RESTORE_PATH_KEY, str(pending_path))
|
||||||
await _set_app_setting(session, PENDING_RESTORE_CONFLICT_KEY, conflict_mode.value)
|
await _set_app_setting(session, PENDING_RESTORE_CONFLICT_KEY, conflict_mode.value)
|
||||||
await _set_app_setting(session, PENDING_RESTORE_UPLOADED_AT_KEY, now_iso)
|
await _set_app_setting(session, PENDING_RESTORE_UPLOADED_AT_KEY, now_iso)
|
||||||
await _set_app_setting(session, PENDING_RESTORE_UPLOADED_BY_KEY, user.username)
|
await _set_app_setting(session, PENDING_RESTORE_UPLOADED_BY_KEY, user.username)
|
||||||
|
await _set_app_setting(session, PENDING_RESTORE_SHA256_KEY, digest)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -292,6 +367,7 @@ def _is_supervised() -> bool:
|
|||||||
|
|
||||||
@router.post("/apply-restart")
|
@router.post("/apply-restart")
|
||||||
async def apply_and_restart(
|
async def apply_and_restart(
|
||||||
|
request: Request,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
user: User = Depends(require_admin),
|
user: User = Depends(require_admin),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
@@ -299,7 +375,11 @@ async def apply_and_restart(
|
|||||||
"""Trigger a graceful exit so the supervisor respawns and applies the pending restore.
|
"""Trigger a graceful exit so the supervisor respawns and applies the pending restore.
|
||||||
|
|
||||||
Only allowed when a pending restore is staged AND the process is supervised.
|
Only allowed when a pending restore is staged AND the process is supervised.
|
||||||
|
Requires same-origin Origin/Referer — this endpoint's blast radius is a
|
||||||
|
full config replace + restart, so an admin token alone (vulnerable to
|
||||||
|
XSS-driven CSRF) is not enough.
|
||||||
"""
|
"""
|
||||||
|
_check_same_origin(request)
|
||||||
path_row = await session.get(AppSetting, PENDING_RESTORE_PATH_KEY)
|
path_row = await session.get(AppSetting, PENDING_RESTORE_PATH_KEY)
|
||||||
if not path_row or not path_row.value:
|
if not path_row or not path_row.value:
|
||||||
raise HTTPException(status_code=409, detail="No pending restore to apply")
|
raise HTTPException(status_code=409, detail="No pending restore to apply")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
@@ -81,6 +82,12 @@ async def update_user(
|
|||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
# Track whether the identity that JWTs encode has changed. Any such change
|
||||||
|
# must bump ``token_version`` so already-issued tokens are rejected — a
|
||||||
|
# user demoted admin→user must not keep admin in their cached JWT until
|
||||||
|
# expiry, and a rename should invalidate prior sessions too.
|
||||||
|
identity_changed = False
|
||||||
|
|
||||||
if body.username is not None and body.username != user.username:
|
if body.username is not None and body.username != user.username:
|
||||||
new_username = body.username.strip()
|
new_username = body.username.strip()
|
||||||
if not new_username:
|
if not new_username:
|
||||||
@@ -89,21 +96,51 @@ async def update_user(
|
|||||||
if dup.first():
|
if dup.first():
|
||||||
raise HTTPException(status_code=409, detail="Username already exists")
|
raise HTTPException(status_code=409, detail="Username already exists")
|
||||||
user.username = new_username
|
user.username = new_username
|
||||||
|
identity_changed = True
|
||||||
|
|
||||||
if body.role is not None and body.role != user.role:
|
if body.role is not None and body.role != user.role:
|
||||||
if body.role not in ("admin", "user"):
|
if body.role not in ("admin", "user"):
|
||||||
raise HTTPException(status_code=400, detail="Invalid role")
|
raise HTTPException(status_code=400, detail="Invalid role")
|
||||||
# Prevent demoting the last admin
|
# Prevent demoting the last admin. Done via a COUNT to avoid loading
|
||||||
|
# every admin row; more importantly, re-checked *after* the role
|
||||||
|
# change is staged (TOCTOU guard — two concurrent demotes can each
|
||||||
|
# see admin_count=2 and both proceed, dropping to 0).
|
||||||
if user.role == "admin" and body.role != "admin":
|
if user.role == "admin" and body.role != "admin":
|
||||||
admins = (await session.exec(
|
admin_count = (await session.exec(
|
||||||
select(User).where(User.role == "admin")
|
select(func.count(User.id)).where(User.role == "admin")
|
||||||
)).all()
|
)).one()
|
||||||
if len(admins) <= 1:
|
if isinstance(admin_count, tuple):
|
||||||
|
admin_count = admin_count[0]
|
||||||
|
if (admin_count or 0) <= 1:
|
||||||
raise HTTPException(status_code=400, detail="Cannot demote the last admin")
|
raise HTTPException(status_code=400, detail="Cannot demote the last admin")
|
||||||
user.role = body.role
|
user.role = body.role
|
||||||
|
identity_changed = True
|
||||||
|
|
||||||
|
if identity_changed:
|
||||||
|
user.token_version = (user.token_version or 1) + 1
|
||||||
|
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.commit()
|
try:
|
||||||
|
await session.commit()
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Final defense against admin-count race: if we just demoted the last admin
|
||||||
|
# due to a concurrent demote landing between our check and commit, undo.
|
||||||
|
if body.role is not None and body.role != "admin":
|
||||||
|
admin_count_after = (await session.exec(
|
||||||
|
select(func.count(User.id)).where(User.role == "admin")
|
||||||
|
)).one()
|
||||||
|
if isinstance(admin_count_after, tuple):
|
||||||
|
admin_count_after = admin_count_after[0]
|
||||||
|
if (admin_count_after or 0) < 1:
|
||||||
|
# Roll the user back to admin and re-commit.
|
||||||
|
user.role = "admin"
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
raise HTTPException(status_code=409, detail="Refused: would remove the last admin")
|
||||||
|
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
return {"id": user.id, "username": user.username, "role": user.role}
|
return {"id": user.id, "username": user.username, "role": user.role}
|
||||||
|
|
||||||
@@ -126,6 +163,9 @@ async def reset_user_password(
|
|||||||
if len(body.new_password) < 8:
|
if len(body.new_password) < 8:
|
||||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||||
user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
|
user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
# Invalidate all prior JWTs issued for this user — matches the self-serve
|
||||||
|
# password-change path in auth/routes.py.
|
||||||
|
user.token_version = (user.token_version or 1) + 1
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|||||||
@@ -92,6 +92,21 @@ async def migrate_schema(engine: AsyncEngine) -> None:
|
|||||||
await conn.execute(text(sql))
|
await conn.execute(text(sql))
|
||||||
logger.info("Added %s column to event_log table", col)
|
logger.info("Added %s column to event_log table", col)
|
||||||
|
|
||||||
|
# Explicit indexes on the dashboard-query columns. SQLModel's
|
||||||
|
# ``index=True`` is emitted by ``create_all`` on *new* installs,
|
||||||
|
# but ALTER TABLE ADD COLUMN doesn't create them on upgrades —
|
||||||
|
# so the first boot after upgrade would leave these unindexed
|
||||||
|
# and status.py ``WHERE user_id=...`` would table-scan. The
|
||||||
|
# indexes are redundant-but-safe once create_all also runs.
|
||||||
|
for idx_name, col in [
|
||||||
|
("ix_event_log_user_id", "user_id"),
|
||||||
|
("ix_event_log_action_id", "action_id"),
|
||||||
|
("ix_event_log_provider_id", "provider_id"),
|
||||||
|
]:
|
||||||
|
await conn.execute(
|
||||||
|
text(f"CREATE INDEX IF NOT EXISTS {idx_name} ON event_log ({col})")
|
||||||
|
)
|
||||||
|
|
||||||
# Backfill user_id from notification_tracker for legacy rows.
|
# Backfill user_id from notification_tracker for legacy rows.
|
||||||
# Safe to run repeatedly: only touches rows where user_id is still NULL.
|
# Safe to run repeatedly: only touches rows where user_id is still NULL.
|
||||||
await conn.execute(text("""
|
await conn.execute(text("""
|
||||||
@@ -250,6 +265,21 @@ async def migrate_schema(engine: AsyncEngine) -> None:
|
|||||||
)
|
)
|
||||||
logger.info("Added track_webhook_received column to tracking_config table")
|
logger.info("Added track_webhook_received column to tracking_config table")
|
||||||
|
|
||||||
|
# Add quiet hours to tracking_config if missing.
|
||||||
|
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
|
||||||
|
if await _has_table(conn, "tracking_config"):
|
||||||
|
if not await _has_column(conn, "tracking_config", "quiet_hours_enabled"):
|
||||||
|
await conn.execute(
|
||||||
|
text("ALTER TABLE tracking_config ADD COLUMN quiet_hours_enabled INTEGER DEFAULT 0")
|
||||||
|
)
|
||||||
|
logger.info("Added quiet_hours_enabled column to tracking_config table")
|
||||||
|
for col_name in ("quiet_hours_start", "quiet_hours_end"):
|
||||||
|
if not await _has_column(conn, "tracking_config", col_name):
|
||||||
|
await conn.execute(
|
||||||
|
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} TEXT")
|
||||||
|
)
|
||||||
|
logger.info("Added %s column to tracking_config table", col_name)
|
||||||
|
|
||||||
# Drop legacy template content columns from template_config
|
# Drop legacy template content columns from template_config
|
||||||
# (template content moved to template_slot child rows)
|
# (template content moved to template_slot child rows)
|
||||||
if await _has_table(conn, "template_config"):
|
if await _has_table(conn, "template_config"):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Notification sender — unified send logic for all paths (dispatch + test)."""
|
"""Notification sender — unified send logic for all paths (dispatch + test)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -11,6 +12,10 @@ from ..database.models import NotificationTarget, TargetReceiver
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Cap on concurrent per-receiver test sends. Keeps us under Telegram's per-bot
|
||||||
|
# rate limit (~30 msg/s) while still saving ~N×RTT on multi-chat broadcasts.
|
||||||
|
_TEST_SEND_CONCURRENCY = 5
|
||||||
|
|
||||||
_TEST_MESSAGES: dict[str, dict[str, str]] = {
|
_TEST_MESSAGES: dict[str, dict[str, str]] = {
|
||||||
"en": {
|
"en": {
|
||||||
"telegram": "\u2705 Test message from <b>Notify Bridge</b>",
|
"telegram": "\u2705 Test message from <b>Notify Bridge</b>",
|
||||||
@@ -358,19 +363,29 @@ async def _send_telegram_test_per_receiver(
|
|||||||
|
|
||||||
http = await get_http_session()
|
http = await get_http_session()
|
||||||
client = TelegramClient(http, bot_token)
|
client = TelegramClient(http, bot_token)
|
||||||
results: list[dict] = []
|
|
||||||
for r in recv_rows:
|
# Parallelize per-receiver sends with a small semaphore — broadcast to
|
||||||
|
# N chats now takes ~ceil(N / concurrency) × RTT instead of N × RTT,
|
||||||
|
# matching the dispatcher's bounded-concurrency pattern. Capped below
|
||||||
|
# Telegram's rate limit so we don't trigger 429s on large fleets.
|
||||||
|
sem = asyncio.Semaphore(_TEST_SEND_CONCURRENCY)
|
||||||
|
|
||||||
|
async def _send_one(r: TargetReceiver) -> dict | None:
|
||||||
chat_id = str(r.config.get("chat_id", ""))
|
chat_id = str(r.config.get("chat_id", ""))
|
||||||
if not chat_id:
|
if not chat_id:
|
||||||
continue
|
return None
|
||||||
explicit = getattr(r, "locale", "") or ""
|
explicit = getattr(r, "locale", "") or ""
|
||||||
locale = explicit or chat_locale_map.get(chat_id) or default_locale
|
locale = explicit or chat_locale_map.get(chat_id) or default_locale
|
||||||
message = _get_test_message(locale[:2].lower(), "telegram")
|
message = _get_test_message(locale[:2].lower(), "telegram")
|
||||||
results.append(await client.send_message(
|
async with sem:
|
||||||
chat_id=chat_id,
|
return await client.send_message(
|
||||||
text=message,
|
chat_id=chat_id,
|
||||||
disable_web_page_preview=bool(disable_preview),
|
text=message,
|
||||||
))
|
disable_web_page_preview=bool(disable_preview),
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = await asyncio.gather(*(_send_one(r) for r in recv_rows))
|
||||||
|
results = [r for r in raw if r is not None]
|
||||||
return _aggregate(results)
|
return _aggregate(results)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,19 @@ If the apply fails, the pending file is kept so the operator can inspect it
|
|||||||
and markers are updated to record the last error. On success, the staged file
|
and markers are updated to record the last error. On success, the staged file
|
||||||
is archived under data/applied_restores/<timestamp>.json and markers are
|
is archived under data/applied_restores/<timestamp>.json and markers are
|
||||||
cleared.
|
cleared.
|
||||||
|
|
||||||
|
Integrity checks on startup:
|
||||||
|
- The on-disk file's SHA256 must match ``PENDING_RESTORE_SHA256_KEY``
|
||||||
|
(written atomically with the staged file). Protects against tampering
|
||||||
|
between prepare and restart.
|
||||||
|
- The pending path must resolve *inside* ``app_config.data_dir``. Protects
|
||||||
|
against a rogue AppSetting pointing at an arbitrary file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import shutil
|
import shutil
|
||||||
@@ -24,11 +33,13 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from ..api.backup import (
|
from ..api.backup import (
|
||||||
PENDING_RESTORE_CONFLICT_KEY,
|
PENDING_RESTORE_CONFLICT_KEY,
|
||||||
PENDING_RESTORE_PATH_KEY,
|
PENDING_RESTORE_PATH_KEY,
|
||||||
|
PENDING_RESTORE_SHA256_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_AT_KEY,
|
PENDING_RESTORE_UPLOADED_AT_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_BY_KEY,
|
PENDING_RESTORE_UPLOADED_BY_KEY,
|
||||||
_applied_restores_dir,
|
_applied_restores_dir,
|
||||||
_pending_restore_path,
|
_pending_restore_path,
|
||||||
)
|
)
|
||||||
|
from ..config import settings as app_config
|
||||||
from ..database.engine import get_engine
|
from ..database.engine import get_engine
|
||||||
from ..database.models import AppSetting
|
from ..database.models import AppSetting
|
||||||
from .backup_schema import BackupFile, ConflictMode
|
from .backup_schema import BackupFile, ConflictMode
|
||||||
@@ -49,6 +60,23 @@ async def apply_pending_restore_if_any() -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
pending_path = _pending_restore_path()
|
pending_path = _pending_restore_path()
|
||||||
|
|
||||||
|
# Defensive: ensure the hard-coded path still lives inside data_dir.
|
||||||
|
# If future refactors let this be read from AppSetting, this check
|
||||||
|
# blocks arbitrary-file reads.
|
||||||
|
try:
|
||||||
|
resolved = pending_path.resolve()
|
||||||
|
data_root = app_config.data_dir.resolve()
|
||||||
|
resolved.relative_to(data_root)
|
||||||
|
except (ValueError, OSError):
|
||||||
|
_LOGGER.error(
|
||||||
|
"Pending-restore path %s is outside data_dir %s — refusing to apply",
|
||||||
|
pending_path, app_config.data_dir,
|
||||||
|
)
|
||||||
|
await _record_error(session, "Pending path outside data_dir")
|
||||||
|
await session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
if not pending_path.exists():
|
if not pending_path.exists():
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Pending-restore marker present but file missing at %s — clearing marker",
|
"Pending-restore marker present but file missing at %s — clearing marker",
|
||||||
@@ -62,9 +90,42 @@ async def apply_pending_restore_if_any() -> None:
|
|||||||
conflict_mode = ConflictMode(conflict_row.value) if conflict_row and conflict_row.value else ConflictMode.SKIP
|
conflict_mode = ConflictMode(conflict_row.value) if conflict_row and conflict_row.value else ConflictMode.SKIP
|
||||||
uploaded_by_row = await session.get(AppSetting, PENDING_RESTORE_UPLOADED_BY_KEY)
|
uploaded_by_row = await session.get(AppSetting, PENDING_RESTORE_UPLOADED_BY_KEY)
|
||||||
uploaded_by = uploaded_by_row.value if uploaded_by_row else "admin"
|
uploaded_by = uploaded_by_row.value if uploaded_by_row else "admin"
|
||||||
|
sha_row = await session.get(AppSetting, PENDING_RESTORE_SHA256_KEY)
|
||||||
|
expected_sha = (sha_row.value or "").strip().lower() if sha_row else ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw = json.loads(pending_path.read_text(encoding="utf-8"))
|
raw_bytes = await asyncio.to_thread(pending_path.read_bytes)
|
||||||
|
except OSError as err:
|
||||||
|
_LOGGER.exception("Pending-restore file unreadable")
|
||||||
|
await _record_error(session, f"Unreadable backup: {err}")
|
||||||
|
await session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Integrity: reject unless hash matches what prepare-restore stored.
|
||||||
|
# An attacker with write access to data/ (swapped file, bind-mount
|
||||||
|
# abuse) does not also have write access to the DB.
|
||||||
|
if not expected_sha:
|
||||||
|
_LOGGER.error("Pending-restore marker has no SHA256; refusing to apply")
|
||||||
|
await _record_error(session, "Missing integrity marker")
|
||||||
|
await session.commit()
|
||||||
|
return
|
||||||
|
actual_sha = hashlib.sha256(raw_bytes).hexdigest()
|
||||||
|
if actual_sha != expected_sha:
|
||||||
|
_LOGGER.error(
|
||||||
|
"Pending-restore SHA256 mismatch (expected %s, got %s) — refusing to apply",
|
||||||
|
expected_sha, actual_sha,
|
||||||
|
)
|
||||||
|
await _record_error(
|
||||||
|
session,
|
||||||
|
"Integrity check failed: on-disk backup SHA256 does not match the hash "
|
||||||
|
"recorded at prepare time. File may have been tampered with; cancel and "
|
||||||
|
"re-upload.",
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = json.loads(raw_bytes.decode("utf-8"))
|
||||||
backup = BackupFile.model_validate(raw)
|
backup = BackupFile.model_validate(raw)
|
||||||
except Exception as err: # noqa: BLE001
|
except Exception as err: # noqa: BLE001
|
||||||
_LOGGER.exception("Pending-restore file unreadable")
|
_LOGGER.exception("Pending-restore file unreadable")
|
||||||
@@ -88,8 +149,14 @@ async def apply_pending_restore_if_any() -> None:
|
|||||||
result = await import_backup(session, admin_row.id, backup, conflict_mode)
|
result = await import_backup(session, admin_row.id, backup, conflict_mode)
|
||||||
except Exception as err: # noqa: BLE001
|
except Exception as err: # noqa: BLE001
|
||||||
_LOGGER.exception("Pending-restore apply failed")
|
_LOGGER.exception("Pending-restore apply failed")
|
||||||
await _record_error(session, str(err))
|
# Discard any partial inserts the importer made before raising —
|
||||||
await session.commit()
|
# committing partial state would let a crafted failing backup
|
||||||
|
# selectively mutate entities. The error-record commit below
|
||||||
|
# happens on a *fresh* session.
|
||||||
|
await session.rollback()
|
||||||
|
async with AsyncSession(engine) as fresh:
|
||||||
|
await _record_error(fresh, str(err))
|
||||||
|
await fresh.commit()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Archive the file
|
# Archive the file
|
||||||
@@ -136,6 +203,7 @@ async def _clear_markers(session: AsyncSession) -> None:
|
|||||||
PENDING_RESTORE_CONFLICT_KEY,
|
PENDING_RESTORE_CONFLICT_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_AT_KEY,
|
PENDING_RESTORE_UPLOADED_AT_KEY,
|
||||||
PENDING_RESTORE_UPLOADED_BY_KEY,
|
PENDING_RESTORE_UPLOADED_BY_KEY,
|
||||||
|
PENDING_RESTORE_SHA256_KEY,
|
||||||
):
|
):
|
||||||
row = await session.get(AppSetting, key)
|
row = await session.get(AppSetting, key)
|
||||||
if row:
|
if row:
|
||||||
|
|||||||
@@ -166,30 +166,41 @@ async def _refresh_telegram_chat_titles() -> None:
|
|||||||
|
|
||||||
refreshed = 0
|
refreshed = 0
|
||||||
errors = 0
|
errors = 0
|
||||||
|
# Bucket results first, then fetch all rows in one IN-query instead of
|
||||||
|
# per-row ``session.get`` — otherwise a 50-chat fleet issues 50 extra
|
||||||
|
# SELECTs before commit.
|
||||||
|
successes: dict[int, dict] = {}
|
||||||
|
for chat_id, info, err in results:
|
||||||
|
if err is not None or info is None:
|
||||||
|
errors += 1
|
||||||
|
if err:
|
||||||
|
_LOGGER.debug("getChat failed for chat row %s: %s", chat_id, err)
|
||||||
|
continue
|
||||||
|
if chat_id is not None:
|
||||||
|
successes[chat_id] = info
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
for chat_id, info, err in results:
|
if successes:
|
||||||
if err is not None or info is None:
|
rows = (await session.exec(
|
||||||
errors += 1
|
select(TelegramChat).where(TelegramChat.id.in_(list(successes.keys())))
|
||||||
if err:
|
)).all()
|
||||||
_LOGGER.debug("getChat failed for chat row %s: %s", chat_id, err)
|
for merged in rows:
|
||||||
continue
|
info = successes.get(merged.id)
|
||||||
merged = await session.get(TelegramChat, chat_id)
|
if not info:
|
||||||
if not merged:
|
continue
|
||||||
continue
|
title = info.get("title") or (
|
||||||
title = info.get("title") or (
|
(info.get("first_name", "") + " " + info.get("last_name", "")).strip()
|
||||||
(info.get("first_name", "") + " " + info.get("last_name", "")).strip()
|
)
|
||||||
)
|
changed = False
|
||||||
changed = False
|
if title and merged.title != title:
|
||||||
if title and merged.title != title:
|
merged.title = title
|
||||||
merged.title = title
|
changed = True
|
||||||
changed = True
|
new_username = info.get("username")
|
||||||
new_username = info.get("username")
|
if new_username is not None and merged.username != new_username:
|
||||||
if new_username is not None and merged.username != new_username:
|
merged.username = new_username
|
||||||
merged.username = new_username
|
changed = True
|
||||||
changed = True
|
if changed:
|
||||||
if changed:
|
session.add(merged)
|
||||||
session.add(merged)
|
refreshed += 1
|
||||||
refreshed += 1
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"Telegram chat title refresh: %s updated, %s errors", refreshed, errors
|
"Telegram chat title refresh: %s updated, %s errors", refreshed, errors
|
||||||
|
|||||||
@@ -250,14 +250,23 @@ async def _build_immich_event(
|
|||||||
collection_ids, limit, asset_type, favorite_only, min_rating,
|
collection_ids, limit, asset_type, favorite_only, min_rating,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Album-based path: use shared collect_scheduled_assets
|
# Album-based path: use shared collect_scheduled_assets.
|
||||||
|
# Fetch albums + shared links in parallel — on a 20-album tracker the old
|
||||||
|
# serial ``await`` loop took ~2 × 20 × RTT, now it's one round-trip.
|
||||||
|
import asyncio as _asyncio
|
||||||
|
album_tasks = [immich.client.get_album(aid) for aid in collection_ids]
|
||||||
|
link_tasks = [immich.client.get_shared_links(aid) for aid in collection_ids]
|
||||||
|
album_results, link_results = await _asyncio.gather(
|
||||||
|
_asyncio.gather(*album_tasks, return_exceptions=True),
|
||||||
|
_asyncio.gather(*link_tasks, return_exceptions=True),
|
||||||
|
)
|
||||||
albums: dict[str, ImmichAlbumData] = {}
|
albums: dict[str, ImmichAlbumData] = {}
|
||||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||||
for album_id in collection_ids:
|
for album_id, album, links in zip(collection_ids, album_results, link_results):
|
||||||
album = await immich.client.get_album(album_id)
|
if isinstance(album, Exception) or album is None:
|
||||||
if album:
|
continue
|
||||||
albums[album_id] = album
|
albums[album_id] = album
|
||||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
shared_links[album_id] = links if not isinstance(links, Exception) else []
|
||||||
|
|
||||||
assets, collections_extra = collect_scheduled_assets(
|
assets, collections_extra = collect_scheduled_assets(
|
||||||
albums, shared_links, ext_domain,
|
albums, shared_links, ext_domain,
|
||||||
@@ -320,13 +329,21 @@ async def _build_immich_periodic_event(
|
|||||||
|
|
||||||
ext_domain = provider_config.get("external_domain") or provider_config.get("url", "")
|
ext_domain = provider_config.get("external_domain") or provider_config.get("url", "")
|
||||||
|
|
||||||
|
# Parallel fetch — see _build_immich_event above for the same rationale.
|
||||||
|
import asyncio as _asyncio
|
||||||
|
album_tasks = [immich.client.get_album(aid) for aid in collection_ids]
|
||||||
|
link_tasks = [immich.client.get_shared_links(aid) for aid in collection_ids]
|
||||||
|
album_results, link_results = await _asyncio.gather(
|
||||||
|
_asyncio.gather(*album_tasks, return_exceptions=True),
|
||||||
|
_asyncio.gather(*link_tasks, return_exceptions=True),
|
||||||
|
)
|
||||||
albums: dict[str, ImmichAlbumData] = {}
|
albums: dict[str, ImmichAlbumData] = {}
|
||||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||||
for album_id in collection_ids:
|
for album_id, album, links in zip(collection_ids, album_results, link_results):
|
||||||
album = await immich.client.get_album(album_id)
|
if isinstance(album, Exception) or album is None:
|
||||||
if album:
|
continue
|
||||||
albums[album_id] = album
|
albums[album_id] = album
|
||||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
shared_links[album_id] = links if not isinstance(links, Exception) else []
|
||||||
|
|
||||||
# limit=0 → returns ([], collections_extra) with full per-album stats.
|
# limit=0 → returns ([], collections_extra) with full per-album stats.
|
||||||
_assets, collections_extra = collect_scheduled_assets(
|
_assets, collections_extra = collect_scheduled_assets(
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ fi
|
|||||||
# Start backend
|
# Start backend
|
||||||
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
||||||
export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
|
export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
|
||||||
|
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
|
||||||
|
# guard rejects private addresses by default, which would make trackers fail.
|
||||||
|
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||||
nohup "$PYTHON" -m uvicorn notify_bridge_server.main:app \
|
nohup "$PYTHON" -m uvicorn notify_bridge_server.main:app \
|
||||||
--host 0.0.0.0 --port 8420 > .backend.log 2>&1 &
|
--host 0.0.0.0 --port 8420 > .backend.log 2>&1 &
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user