From 0eb899afb93201210f7c8cf6d7f9d1616f1cf342 Mon Sep 17 00:00:00 2001 From: "alexei.dolgolyov" Date: Thu, 7 May 2026 13:53:26 +0300 Subject: [PATCH] feat: harden notification stack and switch logging selectors to icon grid Notifications: - Add shared http_base, redact, and SSRF hardening modules - Refactor dispatcher, queue, receiver and per-provider clients (telegram, discord, email, matrix, ntfy, slack, webhook) to use the shared base, with bounded queue and redacted error logs - Tests for ssrf, redact, http_base, queue bounds, dispatcher aggregation, telegram media partition, email and matrix clients Frontend: - Settings: log level / log format selectors now use IconGridSelect with per-option icons and i18n descriptions - Minor providers page and entity-cache store updates Tooling: - Document code-review-graph MCP usage in CLAUDE.md - Ignore .code-review-graph/, register .mcp.json --- .gitignore | 2 + .mcp.json | 12 + CLAUDE.md | 39 + frontend/src/lib/grid-items.ts | 16 + frontend/src/lib/i18n/en.json | 9 +- frontend/src/lib/i18n/ru.json | 9 +- frontend/src/lib/stores/caches.svelte.ts | 28 + frontend/src/routes/providers/+page.svelte | 46 +- frontend/src/routes/settings/+page.svelte | 18 +- .../notifications/discord/client.py | 88 +- .../notifications/dispatcher.py | 456 ++++--- .../notifications/email/client.py | 158 ++- .../notifications/http_base.py | 196 +++ .../notifications/matrix/client.py | 110 +- .../notifications/ntfy/client.py | 66 +- .../notify_bridge_core/notifications/queue.py | 73 +- .../notifications/receiver.py | 109 +- .../notifications/redact.py | 64 + .../notifications/slack/client.py | 23 +- .../notify_bridge_core/notifications/ssrf.py | 221 +++- .../notifications/telegram/cache.py | 184 +-- .../notifications/telegram/client.py | 1099 ++++++++++------- .../notifications/telegram/media.py | 28 +- .../notifications/webhook/client.py | 42 +- .../notify_bridge_server/api/app_settings.py | 13 + .../tests/test_dispatcher_aggregation.py | 46 + packages/server/tests/test_email_client.py | 77 ++ packages/server/tests/test_http_base.py | 53 + packages/server/tests/test_matrix_client.py | 84 ++ packages/server/tests/test_queue_bound.py | 84 ++ packages/server/tests/test_redact.py | 74 ++ packages/server/tests/test_ssrf_hardening.py | 73 ++ .../tests/test_telegram_media_partition.py | 56 + 33 files changed, 2623 insertions(+), 1033 deletions(-) create mode 100644 .mcp.json create mode 100644 packages/core/src/notify_bridge_core/notifications/http_base.py create mode 100644 packages/core/src/notify_bridge_core/notifications/redact.py create mode 100644 packages/server/tests/test_dispatcher_aggregation.py create mode 100644 packages/server/tests/test_email_client.py create mode 100644 packages/server/tests/test_http_base.py create mode 100644 packages/server/tests/test_matrix_client.py create mode 100644 packages/server/tests/test_queue_bound.py create mode 100644 packages/server/tests/test_redact.py create mode 100644 packages/server/tests/test_ssrf_hardening.py create mode 100644 packages/server/tests/test_telegram_media_partition.py diff --git a/.gitignore b/.gitignore index e7d0448..db08a1c 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ frontend/.svelte-kit/ # Logs *.log +# Added by code-review-graph +.code-review-graph/ diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 0000000..c942808 --- /dev/null +++ b/.mcp.json @@ -0,0 +1,12 @@ +{ + "mcpServers": { + "code-review-graph": { + "command": "uvx", + "args": [ + "code-review-graph", + "serve" + ], + "type": "stdio" + } + } +} diff --git a/CLAUDE.md b/CLAUDE.md index a89dc33..1766996 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,3 +43,42 @@ Detailed context is split into focused documents under `.claude/docs/`. Read the - Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`) - Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`) - Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py` + + +## MCP Tools: code-review-graph + +**IMPORTANT: This project has a knowledge graph. ALWAYS use the +code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore +the codebase.** The graph is faster, cheaper (fewer tokens), and gives +you structural context (callers, dependents, test coverage) that file +scanning cannot. + +### When to use graph tools FIRST + +- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep +- **Understanding impact**: `get_impact_radius` instead of manually tracing imports +- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files +- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for +- **Architecture questions**: `get_architecture_overview` + `list_communities` + +Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need. + +### Key Tools + +| Tool | Use when | +|------|----------| +| `detect_changes` | Reviewing code changes — gives risk-scored analysis | +| `get_review_context` | Need source snippets for review — token-efficient | +| `get_impact_radius` | Understanding blast radius of a change | +| `get_affected_flows` | Finding which execution paths are impacted | +| `query_graph` | Tracing callers, callees, imports, tests, dependencies | +| `semantic_search_nodes` | Finding functions/classes by name or keyword | +| `get_architecture_overview` | Understanding high-level codebase structure | +| `refactor_tool` | Planning renames, finding dead code | + +### Workflow + +1. The graph auto-updates on file changes (via hooks). +2. Use `detect_changes` for code review. +3. Use `get_affected_flows` to understand impact. +4. Use `query_graph` pattern="tests_for" to check coverage. diff --git a/frontend/src/lib/grid-items.ts b/frontend/src/lib/grid-items.ts index 1c61b89..0102a96 100644 --- a/frontend/src/lib/grid-items.ts +++ b/frontend/src/lib/grid-items.ts @@ -73,6 +73,22 @@ export const localeItems = (): GridItem[] => [ { value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') }, ]; +// --- Log level --- + +export const logLevelItems = (): GridItem[] => [ + { value: 'DEBUG', icon: 'mdiBugOutline', label: 'DEBUG', desc: t('gridDesc.logLevelDebug') }, + { value: 'INFO', icon: 'mdiInformationOutline', label: 'INFO', desc: t('gridDesc.logLevelInfo') }, + { value: 'WARNING', icon: 'mdiAlertOutline', label: 'WARNING', desc: t('gridDesc.logLevelWarning') }, + { value: 'ERROR', icon: 'mdiAlertOctagonOutline', label: 'ERROR', desc: t('gridDesc.logLevelError') }, +]; + +// --- Log format --- + +export const logFormatItems = (): GridItem[] => [ + { value: 'text', icon: 'mdiFormatText', label: 'text', desc: t('gridDesc.logFormatText') }, + { value: 'json', icon: 'mdiCodeJson', label: 'json', desc: t('gridDesc.logFormatJson') }, +]; + // --- Response mode --- export const responseModeItems = (tFn: typeof t): GridItem[] => [ diff --git a/frontend/src/lib/i18n/en.json b/frontend/src/lib/i18n/en.json index 24965e5..2bf0630 100644 --- a/frontend/src/lib/i18n/en.json +++ b/frontend/src/lib/i18n/en.json @@ -192,7 +192,8 @@ "apiToken": "API Token", "apiTokenHint": "Optional. Needed for connection testing and repository listing.", "webhookUrl": "Webhook URL", - "webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).", + "webhookUrlHint": "Set this as the Target URL in Gitea webhook settings. The full URL is shown when an external base URL is configured in Settings; otherwise it is relative to your bridge host.", + "webhookUrlCopyTitle": "Click to copy", "nutHost": "NUT Server Host", "nutHostPlaceholder": "192.168.1.100 or ups.local", "nutPort": "NUT Server Port", @@ -1131,6 +1132,12 @@ "memorySourceNative": "Use Immich native memories API", "localeEn": "English interface", "localeRu": "Russian interface", + "logLevelDebug": "Verbose — show every step", + "logLevelInfo": "Default — high-level events", + "logLevelWarning": "Warnings and errors only", + "logLevelError": "Errors only — quietest", + "logFormatText": "Human-readable plain text", + "logFormatJson": "One JSON object per line", "modeMedia": "Send actual photo/video files", "modeText": "Send file names and links only", "allEvents": "Show all event types", diff --git a/frontend/src/lib/i18n/ru.json b/frontend/src/lib/i18n/ru.json index 12acfd8..a569a6c 100644 --- a/frontend/src/lib/i18n/ru.json +++ b/frontend/src/lib/i18n/ru.json @@ -192,7 +192,8 @@ "apiToken": "API токен", "apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.", "webhookUrl": "URL вебхука", - "webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).", + "webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea. Полный URL показывается, если в настройках задан внешний адрес; иначе путь указан относительно хоста bridge.", + "webhookUrlCopyTitle": "Нажмите, чтобы скопировать", "nutHost": "Хост NUT-сервера", "nutHostPlaceholder": "192.168.1.100 или ups.local", "nutPort": "Порт NUT-сервера", @@ -1131,6 +1132,12 @@ "memorySourceNative": "Использовать API воспоминаний Immich", "localeEn": "Английский интерфейс", "localeRu": "Русский интерфейс", + "logLevelDebug": "Подробный — каждый шаг", + "logLevelInfo": "По умолчанию — ключевые события", + "logLevelWarning": "Только предупреждения и ошибки", + "logLevelError": "Только ошибки — самый тихий", + "logFormatText": "Читаемый человеком текст", + "logFormatJson": "Один JSON-объект на строку", "modeMedia": "Отправка файлов фото/видео", "modeText": "Только имена файлов и ссылки", "allEvents": "Показать все типы событий", diff --git a/frontend/src/lib/stores/caches.svelte.ts b/frontend/src/lib/stores/caches.svelte.ts index 1892040..2eef22e 100644 --- a/frontend/src/lib/stores/caches.svelte.ts +++ b/frontend/src/lib/stores/caches.svelte.ts @@ -112,6 +112,34 @@ export const capabilitiesCache = (() => { }; })(); +/** Configured external base URL — used to render absolute webhook URLs. + * Available to all authenticated users. Empty string when unset. */ +export const externalUrlCache = (() => { + let data = $state(''); + let fetchedAt = $state(0); + let inflight: Promise | null = null; + const TTL = 300_000; + return { + get value() { return data; }, + invalidate() { fetchedAt = 0; }, + async fetch(force = false): Promise { + if (!force && fetchedAt > 0 && Date.now() - fetchedAt < TTL) return data; + if (inflight) return inflight; + inflight = (async () => { + try { + const res = await api<{ external_url: string }>('/settings/external-url'); + data = (res?.external_url || '').replace(/\/+$/, ''); + fetchedAt = Date.now(); + return data; + } finally { + inflight = null; + } + })(); + return inflight; + }, + }; +})(); + /** Supported template locales — fetched from app settings. */ export const supportedLocalesCache = (() => { let data = $state(['en', 'ru']); diff --git a/frontend/src/routes/providers/+page.svelte b/frontend/src/routes/providers/+page.svelte index 66a333e..3c61c03 100644 --- a/frontend/src/routes/providers/+page.svelte +++ b/frontend/src/routes/providers/+page.svelte @@ -3,7 +3,7 @@ import { slide } from 'svelte/transition'; import { api, getBlockedBy, type BlockedByDetail } from '$lib/api'; import { t } from '$lib/i18n'; - import { providersCache } from '$lib/stores/caches.svelte'; + import { providersCache, externalUrlCache } from '$lib/stores/caches.svelte'; import PageHeader from '$lib/components/PageHeader.svelte'; import Card from '$lib/components/Card.svelte'; import Loading from '$lib/components/Loading.svelte'; @@ -21,7 +21,7 @@ import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { topbarAction } from '$lib/stores/topbar-action.svelte'; import { onDestroy } from 'svelte'; - import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; + import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte'; import { highlightFromUrl } from '$lib/highlight'; import { getDescriptor, buildProviderFormDefaults } from '$lib/providers'; import Button from '$lib/components/Button.svelte'; @@ -45,6 +45,30 @@ let confirmDelete = $state(null); let descriptor = $derived(getDescriptor(form.type)); + let externalUrl = $derived(externalUrlCache.value); + + function buildWebhookUrl(pattern: string, token: string): string { + const path = pattern.replace('{token}', token ?? ''); + return externalUrl ? `${externalUrl}${path}` : path; + } + + function copyWebhookUrl(e: Event, url: string) { + e.preventDefault(); + e.stopPropagation(); + if (navigator.clipboard?.writeText) { + navigator.clipboard.writeText(url); + } else { + const ta = document.createElement('textarea'); + ta.value = url; + ta.style.position = 'fixed'; + ta.style.opacity = '0'; + document.body.appendChild(ta); + ta.select(); + document.execCommand('copy'); + document.body.removeChild(ta); + } + snackInfo(`${t('snack.copied')}: ${url}`); + } // Auto-update name when provider type changes (unless user manually edited) $effect(() => { @@ -76,6 +100,7 @@ onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); }, }); load(); + externalUrlCache.fetch().catch(() => { /* fall back to relative URLs */ }); }); onDestroy(() => topbarAction.clear()); async function load() { @@ -246,9 +271,15 @@ {/each} {#if descriptor?.webhookUrlPattern && editing} + {@const editingWebhookUrl = buildWebhookUrl(descriptor.webhookUrlPattern, providers.find(p => p.id === editing)?.webhook_token ?? '')}
{t('providers.webhookUrl')}
- {descriptor.webhookUrlPattern.replace('{token}', providers.find(p => p.id === editing)?.webhook_token ?? '')} +

{t('providers.webhookUrlHint')}

{/if} @@ -295,7 +326,14 @@

{provider.config.host}:{provider.config.port || 3493}

{/if} {#if provDesc?.webhookUrlPattern} -

{t('providers.webhookUrl')}: {provDesc.webhookUrlPattern.replace('{token}', provider.webhook_token)}

+ {@const webhookUrl = buildWebhookUrl(provDesc.webhookUrlPattern, provider.webhook_token)} +

+ {t('providers.webhookUrl')}: + +

{/if} diff --git a/frontend/src/routes/settings/+page.svelte b/frontend/src/routes/settings/+page.svelte index df573a3..a07b59a 100644 --- a/frontend/src/routes/settings/+page.svelte +++ b/frontend/src/routes/settings/+page.svelte @@ -12,7 +12,10 @@ import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import LocaleSelector from '$lib/components/LocaleSelector.svelte'; import TimezoneSelector from '$lib/components/TimezoneSelector.svelte'; + import IconGridSelect from '$lib/components/IconGridSelect.svelte'; + import { logLevelItems, logFormatItems } from '$lib/grid-items'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; + import { externalUrlCache } from '$lib/stores/caches.svelte'; interface CacheBucketStats { count: number; @@ -76,6 +79,7 @@ saving = true; error = ''; try { settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) }); + externalUrlCache.invalidate(); snackSuccess(t('settings.saved')); } catch (err: any) { error = err.message; snackError(err.message); } saving = false; @@ -221,21 +225,11 @@
- +
- +
diff --git a/packages/core/src/notify_bridge_core/notifications/discord/client.py b/packages/core/src/notify_bridge_core/notifications/discord/client.py index cd0ea24..85dd853 100644 --- a/packages/core/src/notify_bridge_core/notifications/discord/client.py +++ b/packages/core/src/notify_bridge_core/notifications/discord/client.py @@ -4,21 +4,24 @@ from __future__ import annotations import asyncio import logging -from typing import Any +from typing import Any, Final import aiohttp +from ..http_base import HttpProviderClient + _LOGGER = logging.getLogger(__name__) -# Discord webhook content limit -MAX_CONTENT_LENGTH = 2000 +# Discord API constraints (per webhook docs). +MAX_CONTENT_LENGTH: Final = 2000 +MAX_USERNAME_LENGTH: Final = 80 -class DiscordClient: +class DiscordClient(HttpProviderClient): """Sends messages via Discord webhook URLs.""" def __init__(self, session: aiohttp.ClientSession) -> None: - self._session = session + super().__init__(session, provider_name="discord") async def send( self, @@ -33,6 +36,8 @@ class DiscordClient: """ if not webhook_url: return {"success": False, "error": "Missing webhook_url"} + if username and len(username) > MAX_USERNAME_LENGTH: + return {"success": False, "error": f"username exceeds {MAX_USERNAME_LENGTH} chars"} chunks = _split_message(message, MAX_CONTENT_LENGTH) for chunk in chunks: @@ -42,71 +47,34 @@ class DiscordClient: if avatar_url: payload["avatar_url"] = avatar_url - result = await self._post(webhook_url, payload) - if not result["success"]: + result = await self.request("POST", webhook_url, json=payload) + if not result.get("success"): return result - # Small delay between chunks to respect rate limits if len(chunks) > 1: await asyncio.sleep(0.5) return {"success": True} - _MAX_RETRIES = 3 - _MAX_RETRY_AFTER = 60.0 - - async def _post(self, url: str, payload: dict) -> dict[str, Any]: - """POST with bounded 429 retry. - - We cap retries at _MAX_RETRIES and the ``Retry-After`` header at - _MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot - pin the dispatch task indefinitely. - """ - for attempt in range(self._MAX_RETRIES + 1): - try: - async with self._session.post( - url, - json=payload, - headers={"Content-Type": "application/json"}, - allow_redirects=False, - ) as resp: - if resp.status == 429 and attempt < self._MAX_RETRIES: - try: - retry_after = float(resp.headers.get("Retry-After", "2")) - except (TypeError, ValueError): - retry_after = 2.0 - retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER)) - _LOGGER.warning( - "Discord rate limited, retrying after %.1fs (attempt %d/%d)", - retry_after, attempt + 1, self._MAX_RETRIES, - ) - await asyncio.sleep(retry_after) - continue - if 200 <= resp.status < 300: - return {"success": True} - body = await resp.text() - return { - "success": False, - "error": f"HTTP {resp.status}: {body[:200]}", - } - except aiohttp.ClientError as e: - return {"success": False, "error": str(e)} - return {"success": False, "error": "Rate limited (retries exhausted)"} - def _split_message(text: str, limit: int) -> list[str]: - """Split message into chunks respecting the character limit.""" + """Split message into chunks respecting the character limit. + + Drops chunks that contain only whitespace — Discord rejects those. + """ if len(text) <= limit: return [text] - chunks = [] + chunks: list[str] = [] while text: if len(text) <= limit: - chunks.append(text) - break - # Try to split at newline - split_at = text.rfind("\n", 0, limit) - if split_at <= 0: - split_at = limit - chunks.append(text[:split_at]) - text = text[split_at:].lstrip("\n") - return chunks + piece = text + text = "" + else: + split_at = text.rfind("\n", 0, limit) + if split_at <= 0: + split_at = limit + piece = text[:split_at] + text = text[split_at:].lstrip("\n") + if piece.strip(): + chunks.append(piece) + return chunks or [text] diff --git a/packages/core/src/notify_bridge_core/notifications/dispatcher.py b/packages/core/src/notify_bridge_core/notifications/dispatcher.py index 931c03a..01e04a3 100644 --- a/packages/core/src/notify_bridge_core/notifications/dispatcher.py +++ b/packages/core/src/notify_bridge_core/notifications/dispatcher.py @@ -7,7 +7,7 @@ import contextlib import logging import uuid from dataclasses import dataclass, field -from typing import Any, AsyncIterator +from typing import Any, AsyncIterator, Awaitable, Callable, Final import aiohttp @@ -15,37 +15,20 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var from notify_bridge_core.models.events import ServiceEvent from notify_bridge_core.templates.context import build_template_context from notify_bridge_core.templates.renderer import render_template -from .ssrf import UnsafeURLError, avalidate_outbound_url - -_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30) - -# Cap on how many asset downloads run concurrently inside -# ``_preload_asset_data``. Peak memory during a send is bounded to roughly -# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send * -# max_asset_size``, which matters on small-RAM Docker hosts when a batch -# contains many large videos. -_PRELOAD_CONCURRENCY = 6 - - -def _new_session() -> aiohttp.ClientSession: - """Per-dispatch aiohttp session with a sane default timeout. - - We still open a short-lived session per dispatch (connection reuse across - dispatches lives in the server-side shared session), but we always attach - a total timeout so a hung peer cannot wedge the task forever. - """ - return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT) +from .http_base import safe_headers from .receiver import ( + DiscordReceiver, + EmailReceiver, + MatrixReceiver, + NtfyReceiver, Receiver, + SlackReceiver, TelegramReceiver, WebhookReceiver, - EmailReceiver, - DiscordReceiver, - SlackReceiver, - NtfyReceiver, - MatrixReceiver, ) +from .redact import redact_exc +from .ssrf import UnsafeURLError, avalidate_outbound_url from .telegram.cache import TelegramFileCache from .telegram.client import TelegramClient from .telegram.media import ( @@ -58,7 +41,33 @@ from .webhook.client import WebhookClient _LOGGER = logging.getLogger(__name__) -DEFAULT_TEMPLATE = '{{ event_type }}: "{{ collection_name }}"' +DEFAULT_TEMPLATE: Final = '{{ event_type }}: "{{ collection_name }}"' + +_HTTP_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10) + +# Cap on how many asset downloads run concurrently inside +# ``_preload_asset_data``. Peak memory during a send is bounded to roughly +# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send * +# max_asset_size``. +_PRELOAD_CONCURRENCY: Final = 6 + +# Cap on how many targets the dispatcher fans out to at once. With dozens +# of targets and a single hung peer, unbounded ``gather`` can pin the +# dispatch task. The cap also protects against credential-reuse rate +# limits on shared providers. +_DISPATCH_CONCURRENCY: Final = 16 + +# Cap on parallel per-receiver sends within a single target. +_RECEIVER_CONCURRENCY: Final = 8 + +# Per-target soft timeout — at the top of the dispatch tree so a single +# misbehaving target can't hold the whole batch open. Individual provider +# clients carry their own per-request timeouts on top of this. +_TARGET_TIMEOUT_S: Final = 120.0 + + +def _new_session() -> aiohttp.ClientSession: + return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT) @dataclass @@ -66,17 +75,23 @@ class TargetConfig: """Configuration for a notification target.""" type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix" - config: dict[str, Any] # target-level config (bot_token, settings, etc.) - template_slots: dict[str, dict[str, str]] | None = None # event_type -> {locale -> template} - locale: str = "en" # default locale for template resolution + config: dict[str, Any] + template_slots: dict[str, dict[str, str]] | None = None + locale: str = "en" date_format: str = "%d.%m.%Y, %H:%M UTC" date_only_format: str = "%d.%m.%Y" - provider_api_key: str | None = None # API key for downloading assets from provider - provider_internal_url: str | None = None # Internal provider URL for API key scoping - provider_external_url: str | None = None # External domain for API key scoping + provider_api_key: str | None = None + provider_internal_url: str | None = None + provider_external_url: str | None = None receivers: list[Receiver] = field(default_factory=list) +_SendMethod = Callable[ + ["NotificationDispatcher", TargetConfig, str, ServiceEvent], + Awaitable[dict[str, Any]], +] + + class NotificationDispatcher: """Dispatches ServiceEvent notifications to configured targets.""" @@ -90,18 +105,11 @@ class NotificationDispatcher: self._url_cache = url_cache self._asset_cache = asset_cache # Optional shared session owned by the caller; when supplied we reuse - # its connection pool instead of opening a fresh per-dispatch session - # (saves a TLS handshake per outbound call). + # its connection pool instead of opening a fresh per-dispatch session. self._shared_session = session @contextlib.asynccontextmanager async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]: - """Yield an aiohttp session, reusing the shared one if provided. - - When a shared session was passed in ``__init__`` we yield it without - closing (the caller owns its lifetime). Otherwise we open a - short-lived session with our default timeout and close it on exit. - """ if self._shared_session is not None and not self._shared_session.closed: yield self._shared_session return @@ -115,11 +123,9 @@ class NotificationDispatcher: ) -> list[dict[str, Any]]: """Send event notification to all targets. - Returns list of results (one per target). + Returns one result per target. Per-target failures are isolated; + a single bad target cannot poison the batch. """ - # Bind a dispatch_id so every log line emitted by the target sends - # (including deep in TelegramClient) can be correlated to the same - # upstream event. new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}" with bind_log_context(dispatch_id=new_id): @@ -128,20 +134,36 @@ class NotificationDispatcher: event.event_type.value if hasattr(event.event_type, "value") else event.event_type, getattr(event, "collection_name", None), len(targets), ) + + sem = asyncio.Semaphore(_DISPATCH_CONCURRENCY) + + async def run_one(t: TargetConfig) -> dict[str, Any]: + async with sem: + try: + return await asyncio.wait_for( + self._send_to_target(event, t), + timeout=_TARGET_TIMEOUT_S, + ) + except asyncio.TimeoutError: + return { + "success": False, + "error": f"Target dispatch timed out after {_TARGET_TIMEOUT_S}s", + } + raw_results = await asyncio.gather( - *[self._send_to_target(event, t) for t in targets], + *[run_one(t) for t in targets], return_exceptions=True, ) - results = [] + results: list[dict[str, Any]] = [] failures = 0 for target, raw in zip(targets, raw_results): if isinstance(raw, Exception): failures += 1 _LOGGER.error( "Dispatch to target type=%s failed: %s", - target.type, raw, exc_info=raw, + target.type, redact_exc(raw), ) - results.append({"success": False, "error": str(raw)}) + results.append({"success": False, "error": redact_exc(raw)}) else: if isinstance(raw, dict) and not raw.get("success"): failures += 1 @@ -155,7 +177,6 @@ class NotificationDispatcher: def _resolve_template( self, event: ServiceEvent, target: TargetConfig, locale: str, ) -> str: - """Resolve template string for an event, with locale fallback.""" template_str = DEFAULT_TEMPLATE if target.template_slots: locale_map = target.template_slots.get(event.event_type.value) @@ -166,7 +187,6 @@ class NotificationDispatcher: def _render_message( self, event: ServiceEvent, target: TargetConfig, locale: str, ) -> str: - """Resolve template and render message for a given locale.""" template_str = self._resolve_template(event, target, locale) ctx = build_template_context( event, target_type=target.type, @@ -179,7 +199,6 @@ class NotificationDispatcher: self, receiver: Receiver, default_message: str, event: ServiceEvent, target: TargetConfig, ) -> str: - """Return per-receiver message, re-rendering if receiver has a different locale.""" if receiver.locale and receiver.locale != target.locale: return self._render_message(event, target, receiver.locale) return default_message @@ -187,21 +206,16 @@ class NotificationDispatcher: async def _send_to_target( self, event: ServiceEvent, target: TargetConfig ) -> dict[str, Any]: - """Send event to a single target (potentially multiple receivers).""" + """Dispatch to a single target via the registered handler.""" default_message = self._render_message(event, target, target.locale) + send_method = _PROVIDER_HANDLERS.get(target.type) + if send_method is None: + return {"success": False, "error": f"Unknown target type: {target.type}"} + return await send_method(self, target, default_message, event) - send_method = { - "telegram": self._send_telegram, - "webhook": self._send_webhook, - "email": self._send_email, - "discord": self._send_discord, - "slack": self._send_slack, - "ntfy": self._send_ntfy, - "matrix": self._send_matrix, - }.get(target.type) - if send_method: - return await send_method(target, default_message, event) - return {"success": False, "error": f"Unknown target type: {target.type}"} + # ------------------------------------------------------------------ + # Asset preload (Telegram-specific) + # ------------------------------------------------------------------ async def _preload_asset_data( self, @@ -210,36 +224,13 @@ class NotificationDispatcher: session: aiohttp.ClientSession, max_size: int | None, ) -> None: - """Download each non-cached asset's bytes once and attach to the entry. - - Three benefits: - * ``TelegramClient`` sees ``entry["data"]`` and skips its own download, - so we don't fetch each URL twice. - * We know the exact upload size, which lets the oversize warning in - the rendered text compare against real bytes (for Immich videos, - the transcoded ``/video/playback``), not the original ``file_size``. - * Assets already in the Telegram file_id cache are skipped, and their - stored size (if any) is used to populate ``playback_size`` — so - templates see consistent sizes for repeat sends without re-download. - - Entries whose download fails or exceeds ``max_size`` are left without - ``data``; ``TelegramClient`` will then fall back to its own download - path and apply the same checks — no regression, just no preload win. - - Concurrency is bounded by ``_PRELOAD_CONCURRENCY`` so peak memory - stays predictable: at most N assets worth of bytes held in RAM at - once, regardless of ``max_media_to_send``. Total wall-clock is - unchanged for small batches and only marginally slower for large - ones (most assets fit in a single RTT and SSL negotiation cost - dominates, so 6-way parallelism is sufficient). - """ + """Download each non-cached asset's bytes once, with SSRF guard.""" if not assets: return sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY) - async def _fetch(entry: dict[str, Any], media: Any) -> None: - # Cache hit → skip download; populate playback_size from stored size. + async def fetch(entry: dict[str, Any], media: Any) -> None: cache, key = self._cache_for_entry(entry) if cache and key: cached = cache.get(key) @@ -251,28 +242,40 @@ class NotificationDispatcher: url = entry["url"] headers = entry.get("headers") or {} + try: + # Defense-in-depth: validate even though TelegramClient + # also validates. The dispatcher is what triggers the + # download, so the guard belongs here too. + await avalidate_outbound_url(url) + except UnsafeURLError as err: + _LOGGER.warning( + "Asset preload skipped: unsafe URL (%s)", redact_exc(err), + ) + return async with sem: try: async with session.get(url, headers=headers) as resp: if resp.status != 200: return data = await resp.read() - except aiohttp.ClientError: + except (aiohttp.ClientError, asyncio.TimeoutError, OSError): return if max_size is not None and len(data) > max_size: return entry["data"] = data media.extra["playback_size"] = len(data) - await asyncio.gather(*(_fetch(e, m) for e, m in zip(assets, media_assets))) + raw = await asyncio.gather( + *(fetch(e, m) for e, m in zip(assets, media_assets)), + return_exceptions=True, + ) + for r in raw: + if isinstance(r, Exception): + _LOGGER.warning("Asset preload raised: %s", redact_exc(r)) def _cache_for_entry( self, entry: dict[str, Any], ) -> tuple[TelegramFileCache | None, str | None]: - """Resolve (cache, key) for an asset entry — mirrors TelegramClient logic. - - Returns (None, None) if no cache is configured or no key can be derived. - """ cache_key = entry.get("cache_key") if cache_key: cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache @@ -287,6 +290,10 @@ class NotificationDispatcher: return self._url_cache, url return None, None + # ------------------------------------------------------------------ + # Per-provider handlers + # ------------------------------------------------------------------ + async def _send_telegram( self, target: TargetConfig, default_message: str, event: ServiceEvent ) -> dict[str, Any]: @@ -296,27 +303,25 @@ class NotificationDispatcher: max_media = target.config.get("max_media_to_send", 50) max_group = target.config.get("max_media_per_group", 10) chunk_delay = target.config.get("media_delay", 500) - max_size = target.config.get("max_asset_size") - if max_size: - max_size = max_size * 1024 * 1024 # MB to bytes + max_size_mb = target.config.get("max_asset_size") + max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None send_large_as_docs = target.config.get("send_large_photos_as_documents", False) if not bot_token: return {"success": False, "error": "Missing bot_token"} - if not target.receivers: return {"success": False, "error": "No receivers configured"} - # Prepare assets list once (shared across receivers) - # Prefer internal URL for fetching (LAN speed vs public internet) internal_url = (target.provider_internal_url or "").rstrip("/") external_url = (target.provider_external_url or "").rstrip("/") - assets = [] - media_assets: list[Any] = [] # aligned with `assets` for preload + assets: list[dict[str, Any]] = [] + media_assets: list[Any] = [] for asset in event.added_assets[:max_media]: url = asset.preview_url or asset.thumbnail_url or asset.full_url + if not url: + continue asset_entry = build_telegram_asset_entry( - url=url or "", + url=url, media_type=asset.type.value, api_key=target.provider_api_key, internal_url=internal_url, @@ -327,26 +332,15 @@ class NotificationDispatcher: assets.append(asset_entry) media_assets.append(asset) - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: - # Preload all asset bytes once so (a) TelegramClient can skip its - # own download and (b) we know exact upload sizes in time for the - # oversize warning in the rendered text. - await self._preload_asset_data(assets, media_assets, session, max_size) - default_message = self._render_message(event, target, target.locale) + await self._preload_asset_data(assets, media_assets, session, max_size_bytes) - # Asset cache (when in thumbhash mode) invalidates entries when the - # asset's visual content changes. The resolver maps asset id → its - # current thumbhash. Providers that expose thumbhash put it in - # ``asset.extra["thumbhash"]`` (currently Immich). thumbhash_map = { asset.id: asset.extra.get("thumbhash") for asset in event.added_assets if asset.extra.get("thumbhash") } - thumbhash_resolver = ( - (lambda key: thumbhash_map.get(key)) if thumbhash_map else None - ) + thumbhash_resolver = thumbhash_map.get if thumbhash_map else None client = TelegramClient( session, bot_token, @@ -355,39 +349,51 @@ class NotificationDispatcher: thumbhash_resolver=thumbhash_resolver, ) - for receiver in target.receivers: + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id: - results.append({"success": False, "error": "Invalid telegram receiver"}) - continue - + return {"success": False, "error": "Invalid telegram receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - text_result = await client.send_message( chat_id=receiver.chat_id, text=message, disable_web_page_preview=bool(disable_preview), ) if not text_result.get("success"): - _LOGGER.warning("Failed to send to chat %s: %s", receiver.chat_id, text_result.get("error")) - results.append(text_result) - continue + _LOGGER.warning( + "Failed to send to chat %s: %s", + receiver.chat_id, text_result.get("error"), + ) + return text_result if assets: - reply_to = text_result.get("message_id") media_result = await client.send_notification( chat_id=receiver.chat_id, assets=assets, - reply_to_message_id=reply_to, + reply_to_message_id=text_result.get("message_id"), max_group_size=max_group, chunk_delay=chunk_delay, - max_asset_data_size=max_size, + max_asset_data_size=max_size_bytes, send_large_photos_as_documents=send_large_as_docs, chat_action=chat_action or None, ) if not media_result.get("success"): - _LOGGER.warning("Text sent OK but media failed for chat %s: %s", receiver.chat_id, media_result.get("error")) + _LOGGER.warning( + "Text sent OK but media failed for chat %s: %s", + receiver.chat_id, media_result.get("error"), + ) + # Preserve both outcomes — text succeeded, media + # didn't. Operators losing media-failure detail + # in the result dict made root-cause analysis + # impossible. + return { + "success": True, + "message_id": text_result.get("message_id"), + "media_error": media_result.get("error"), + "media_failed_at_chunk": media_result.get("failed_at_chunk"), + } + return text_result - results.append(text_result) + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) @@ -397,17 +403,10 @@ class NotificationDispatcher: if not target.receivers: return {"success": False, "error": "No receivers configured"} - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: - for receiver in target.receivers: + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, WebhookReceiver) or not receiver.url: - results.append({"success": False, "error": "Invalid webhook receiver"}) - continue - try: - await avalidate_outbound_url(receiver.url) - except UnsafeURLError as err: - results.append({"success": False, "error": f"Unsafe URL: {err}"}) - continue + return {"success": False, "error": "Invalid webhook receiver"} message = self._message_for_receiver(receiver, default_message, event, target) payload = { "message": message, @@ -417,8 +416,10 @@ class NotificationDispatcher: "collection_id": event.collection_id, "timestamp": event.timestamp.isoformat(), } - client = WebhookClient(session, receiver.url, receiver.headers) - results.append(await client.send(payload)) + client = WebhookClient(session, receiver.url, safe_headers(receiver.headers)) + return await client.send(payload) + + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) @@ -431,7 +432,7 @@ class NotificationDispatcher: if not smtp_cfg.get("host"): return {"success": False, "error": "SMTP not configured"} - client = EmailClient(SmtpConfig( + email_client = EmailClient(SmtpConfig( host=smtp_cfg["host"], port=int(smtp_cfg.get("port", 587)), username=smtp_cfg.get("username", ""), @@ -439,27 +440,28 @@ class NotificationDispatcher: from_address=smtp_cfg.get("from_address", ""), from_name=smtp_cfg.get("from_name", "Notify Bridge"), use_tls=smtp_cfg.get("use_tls", True), + tls_mode=smtp_cfg.get("tls_mode", "auto"), )) if not target.receivers: return {"success": False, "error": "No receivers configured"} subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}" - results: list[dict[str, Any]] = [] - for receiver in target.receivers: + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, EmailReceiver) or not receiver.email: - results.append({"success": False, "error": "Invalid email receiver"}) - continue + return {"success": False, "error": "Invalid email receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - result = await client.send( + # body_html=None lets EmailClient build a safely-escaped HTML + # alternative from body_text instead of trusting user content. + return await email_client.send( to_email=receiver.email, subject=subject, body_text=message, - body_html=message, + body_html=None, to_name=receiver.name, ) - results.append(result) + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) async def _send_discord( @@ -471,20 +473,16 @@ class NotificationDispatcher: return {"success": False, "error": "No receivers configured"} username = target.config.get("username") - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: client = DiscordClient(session) - for receiver in target.receivers: + + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url: - results.append({"success": False, "error": "Invalid discord receiver"}) - continue - try: - await avalidate_outbound_url(receiver.webhook_url) - except UnsafeURLError as err: - results.append({"success": False, "error": f"Unsafe URL: {err}"}) - continue + return {"success": False, "error": "Invalid discord receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - results.append(await client.send(receiver.webhook_url, message, username=username)) + return await client.send(receiver.webhook_url, message, username=username) + + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) @@ -497,20 +495,16 @@ class NotificationDispatcher: return {"success": False, "error": "No receivers configured"} username = target.config.get("username") - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: client = SlackClient(session) - for receiver in target.receivers: + + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url: - results.append({"success": False, "error": "Invalid slack receiver"}) - continue - try: - await avalidate_outbound_url(receiver.webhook_url) - except UnsafeURLError as err: - results.append({"success": False, "error": f"Unsafe URL: {err}"}) - continue + return {"success": False, "error": "Invalid slack receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - results.append(await client.send(receiver.webhook_url, message, username=username)) + return await client.send(receiver.webhook_url, message, username=username) + + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) @@ -526,22 +520,23 @@ class NotificationDispatcher: try: await avalidate_outbound_url(server_url) except UnsafeURLError as err: - return {"success": False, "error": f"Unsafe ntfy server_url: {err}"} + return {"success": False, "error": f"Unsafe ntfy server_url: {redact_exc(err)}"} title = f"{event.event_type.value}: {event.collection_name}" - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: client = NtfyClient(session) - for receiver in target.receivers: + + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, NtfyReceiver) or not receiver.topic: - results.append({"success": False, "error": "Invalid ntfy receiver"}) - continue + return {"success": False, "error": "Invalid ntfy receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - results.append(await client.send( + return await client.send( server_url, receiver.topic, message, title=title, priority=receiver.priority, auth_token=auth_token, - )) + ) + + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) @@ -557,33 +552,108 @@ class NotificationDispatcher: try: await avalidate_outbound_url(homeserver) except UnsafeURLError as err: - return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"} + return {"success": False, "error": f"Unsafe matrix homeserver_url: {redact_exc(err)}"} if not target.receivers: return {"success": False, "error": "No receivers configured"} - results: list[dict[str, Any]] = [] async with self._session_ctx() as session: client = MatrixClient(session, homeserver, access_token) - for receiver in target.receivers: + + async def send_one(receiver: Receiver) -> dict[str, Any]: if not isinstance(receiver, MatrixReceiver) or not receiver.room_id: - results.append({"success": False, "error": "Invalid matrix receiver"}) - continue + return {"success": False, "error": "Invalid matrix receiver"} message = self._message_for_receiver(receiver, default_message, event, target) - results.append(await client.send_message( - receiver.room_id, message, html_message=message, - )) + # body_html is the same plain text — Matrix accepts the + # raw message as both ``body`` and ``formatted_body``. + # If templates emit HTML in the future, generate a + # separate HTML body upstream rather than aliasing here. + return await client.send_message( + receiver.room_id, message, html_message=None, + ) + + results = await self._fan_out(target.receivers, send_one) return self._aggregate_results(results) + # ------------------------------------------------------------------ + # Aggregation + # ------------------------------------------------------------------ + + @staticmethod + async def _fan_out( + receivers: list[Receiver], + send_one: Callable[[Receiver], Awaitable[dict[str, Any]]], + ) -> list[dict[str, Any]]: + """Run ``send_one`` per receiver with bounded concurrency. + + Per-receiver exceptions are converted to failure dicts so a single + bad receiver can't cancel its peers. + """ + sem = asyncio.Semaphore(_RECEIVER_CONCURRENCY) + + async def guarded(receiver: Receiver) -> dict[str, Any]: + async with sem: + try: + return await send_one(receiver) + except Exception as exc: # noqa: BLE001 + _LOGGER.error("Receiver send raised: %s", redact_exc(exc)) + return {"success": False, "error": redact_exc(exc)} + + return await asyncio.gather(*(guarded(r) for r in receivers)) + @staticmethod def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]: - """Aggregate broadcast results into a single result dict.""" + """Aggregate per-receiver results into a single target-level result. + + Preserves the per-receiver detail under ``receivers`` so a caller + can see exactly which receivers failed, instead of getting only + the first error. + """ + if not results: + return {"success": False, "error": "No receivers configured"} + successes = sum(1 for r in results if r.get("success")) - if successes == len(results) and results: - return {"success": True, "receivers": len(results)} - elif successes > 0: - return {"success": True, "receivers": len(results), "partial_failures": len(results) - successes} - elif results: - return results[0] - return {"success": False, "error": "No receivers configured"} + failures = len(results) - successes + out: dict[str, Any] = { + "success": successes > 0, + "receivers": len(results), + "successes": successes, + "failures": failures, + "results": results, + } + if failures: + out["errors"] = [ + r.get("error") for r in results if not r.get("success") + ] + if successes == 0: + # Surface the first error at the top level for back-compat + # with callers that only check ``error``. + out["error"] = results[0].get("error", "All receivers failed") + return out + + +# ---------------------------------------------------------------------- +# Provider registry — replaces the if/elif chain so adding a provider +# means just registering it here, not editing dispatch logic. +# ---------------------------------------------------------------------- + +_PROVIDER_HANDLERS: dict[str, _SendMethod] = { + "telegram": NotificationDispatcher._send_telegram, + "webhook": NotificationDispatcher._send_webhook, + "email": NotificationDispatcher._send_email, + "discord": NotificationDispatcher._send_discord, + "slack": NotificationDispatcher._send_slack, + "ntfy": NotificationDispatcher._send_ntfy, + "matrix": NotificationDispatcher._send_matrix, +} + + +def register_provider(name: str, handler: _SendMethod) -> None: + """Register a new dispatcher provider at runtime. + + Allows out-of-tree providers to extend the dispatcher without + forking. The handler must follow the + ``async (dispatcher, target, default_message, event) -> dict`` shape. + """ + _PROVIDER_HANDLERS[name] = handler diff --git a/packages/core/src/notify_bridge_core/notifications/email/client.py b/packages/core/src/notify_bridge_core/notifications/email/client.py index cd99113..10287d9 100644 --- a/packages/core/src/notify_bridge_core/notifications/email/client.py +++ b/packages/core/src/notify_bridge_core/notifications/email/client.py @@ -2,14 +2,32 @@ from __future__ import annotations +import html import logging +import re +import ssl from dataclasses import dataclass -from email.mime.multipart import MIMEMultipart -from email.mime.text import MIMEText -from typing import Any +from email.headerregistry import Address +from email.message import EmailMessage +from typing import Any, Final, Literal + +try: # Optional dependency — fail at first send rather than at import. + import aiosmtplib + from aiosmtplib import SMTPException +except ImportError: # pragma: no cover + aiosmtplib = None # type: ignore[assignment] + + class SMTPException(Exception): # type: ignore[no-redef] + pass _LOGGER = logging.getLogger(__name__) +_DEFAULT_TIMEOUT_S: Final = 30.0 +_TlsMode = Literal["auto", "implicit", "starttls", "none"] +# RFC 5322 lite: catches the obvious-bad addresses ("foo bar", "no-at", +# embedded CRLF) without pretending to fully validate addresses. +_EMAIL_RE: Final = re.compile(r"^[^\s@\r\n,;<>]+@[^\s@\r\n,;<>]+\.[^\s@\r\n,;<>]+$") + @dataclass class SmtpConfig: @@ -22,6 +40,55 @@ class SmtpConfig: from_address: str = "" from_name: str = "Notify Bridge" use_tls: bool = True + # Explicit TLS mode. ``auto`` (back-compat) infers from ``use_tls`` and + # ``port``: 465 → implicit; 587 with use_tls=False → starttls; 25 → none. + tls_mode: _TlsMode = "auto" + timeout_s: float = _DEFAULT_TIMEOUT_S + + +def _strip_header(value: str) -> str: + """Reject CRLF and bare CR/LF in header-bound strings. + + SMTP header injection turns user-controlled subject/name strings into + arbitrary headers (``\\r\\nBcc: attacker@x``). The Python stdlib + accepts CRLF when followed by SP/HT (header folding), so explicit + sanitization is required even though :class:`EmailMessage` does some + validation of its own. + """ + return re.sub(r"[\r\n]+", " ", str(value or "")).strip() + + +def _validate_email(addr: str) -> str: + addr = _strip_header(addr) + if not addr: + raise ValueError("email address is empty") + if not _EMAIL_RE.match(addr): + raise ValueError("email address is invalid") + return addr + + +def _resolve_tls(cfg: SmtpConfig) -> tuple[bool, bool]: + """Resolve ``(use_tls, start_tls)`` flags from the config. + + ``tls_mode`` overrides ``use_tls``/port heuristics when provided. + """ + mode = cfg.tls_mode + if mode == "implicit": + return True, False + if mode == "starttls": + return False, True + if mode == "none": + return False, False + # auto — preserve the historical "use_tls bool + port heuristic" behavior + # but make the path explicit. + if cfg.use_tls: + return True, False + return False, cfg.port != 25 + + +def _to_html(text: str) -> str: + """Convert plain text to a minimal HTML body, escaped for safety.""" + return "
" + html.escape(text) + "
" class EmailClient: @@ -30,30 +97,39 @@ class EmailClient: def __init__(self, smtp_config: SmtpConfig) -> None: self._config = smtp_config + @staticmethod + def _ssl_context() -> ssl.SSLContext: + # Explicit context so the TLS posture is auditable; aiosmtplib + # defaults look correct today but past regressions (and downstream + # repackaging) make implicit reliance fragile. + return ssl.create_default_context() + async def verify_connection(self) -> dict[str, Any]: """Test SMTP connection and authentication without sending an email.""" - try: - import aiosmtplib - except ImportError: + if aiosmtplib is None: return {"success": False, "error": "aiosmtplib not installed"} cfg = self._config if not cfg.host: return {"success": False, "error": "SMTP host not configured"} + use_tls, start_tls = _resolve_tls(cfg) try: smtp = aiosmtplib.SMTP( hostname=cfg.host, port=cfg.port, - use_tls=cfg.use_tls, - start_tls=not cfg.use_tls and cfg.port != 25, + use_tls=use_tls, + start_tls=start_tls, + tls_context=self._ssl_context(), + timeout=cfg.timeout_s, + validate_certs=True, ) await smtp.connect() if cfg.username and cfg.password: await smtp.login(cfg.username, cfg.password) await smtp.quit() return {"success": True} - except Exception as e: + except (SMTPException, OSError) as e: _LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e) return {"success": False, "error": str(e)} @@ -65,27 +141,52 @@ class EmailClient: body_html: str | None = None, to_name: str = "", ) -> dict[str, Any]: - """Send an email. Returns {"success": True} or {"success": False, "error": "..."}.""" - try: - import aiosmtplib - except ImportError: + """Send an email. + + Returns ``{"success": True}`` or ``{"success": False, "error": "..."}``. + + ``body_html`` is treated as already-safe markup. Pass ``None`` to + derive a safe HTML alternative from ``body_text`` automatically. + """ + if aiosmtplib is None: return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"} cfg = self._config - if not cfg.host or not cfg.from_address: return {"success": False, "error": "SMTP not configured (missing host or from_address)"} - # Build email message - msg = MIMEMultipart("alternative") - msg["From"] = f"{cfg.from_name} <{cfg.from_address}>" if cfg.from_name else cfg.from_address - msg["To"] = f"{to_name} <{to_email}>" if to_name else to_email - msg["Subject"] = subject + try: + to_addr = _validate_email(to_email) + from_addr = _validate_email(cfg.from_address) + except ValueError as exc: + return {"success": False, "error": f"Invalid email address: {exc}"} - msg.attach(MIMEText(body_text, "plain", "utf-8")) - if body_html: - msg.attach(MIMEText(body_html, "html", "utf-8")) + # EmailMessage with structured Address objects rejects CRLF and + # framework-folds long headers safely. We still strip first because + # EmailMessage's display-name slot is a pure string. + msg = EmailMessage() + from_display = _strip_header(cfg.from_name) or "" + to_display = _strip_header(to_name) or "" + try: + from_user, _, from_domain = from_addr.partition("@") + to_user, _, to_domain = to_addr.partition("@") + msg["From"] = Address(from_display, from_user, from_domain) if from_display else from_addr + msg["To"] = Address(to_display, to_user, to_domain) if to_display else to_addr + except ValueError as exc: + return {"success": False, "error": f"Invalid email address: {exc}"} + msg["Subject"] = _strip_header(subject) + msg.set_content(body_text or "", subtype="plain", charset="utf-8") + # If the caller provided HTML explicitly, honor it; otherwise build a + # safe escaped version so a stray "<" in the rendered template can't + # break the markup. + msg.add_alternative( + body_html if body_html is not None else _to_html(body_text or ""), + subtype="html", + charset="utf-8", + ) + + use_tls, start_tls = _resolve_tls(cfg) try: await aiosmtplib.send( msg, @@ -93,11 +194,14 @@ class EmailClient: port=cfg.port, username=cfg.username or None, password=cfg.password or None, - use_tls=cfg.use_tls, - start_tls=not cfg.use_tls and cfg.port != 25, + use_tls=use_tls, + start_tls=start_tls, + tls_context=self._ssl_context(), + timeout=cfg.timeout_s, + validate_certs=True, ) - _LOGGER.info("Email sent to %s", to_email) + _LOGGER.info("Email sent to %s", to_addr) return {"success": True} - except Exception as e: - _LOGGER.error("Failed to send email to %s: %s", to_email, e) + except (SMTPException, OSError) as e: + _LOGGER.error("Failed to send email to %s: %s", to_addr, e) return {"success": False, "error": str(e)} diff --git a/packages/core/src/notify_bridge_core/notifications/http_base.py b/packages/core/src/notify_bridge_core/notifications/http_base.py new file mode 100644 index 0000000..78c3d4b --- /dev/null +++ b/packages/core/src/notify_bridge_core/notifications/http_base.py @@ -0,0 +1,196 @@ +"""Shared HTTP infrastructure for notification provider clients. + +Slack/Discord/ntfy/Matrix/Webhook all follow the same pattern: build a +JSON payload, POST/PUT it, decode 200-range as success, decode 4xx/5xx +into a stable error dict, and retry transient 429/503 responses with a +capped ``Retry-After``. ``HttpProviderClient`` centralizes that pattern +so every provider gets the same SSRF guard, timeouts, secret-redacted +errors, and bounded retry policy by construction — adding a new +provider doesn't get to forget any one of them. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Final, Mapping + +import aiohttp + +from .redact import redact, redact_exc +from .ssrf import UnsafeURLError, avalidate_outbound_url + +_LOGGER = logging.getLogger(__name__) + +_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10) +_MAX_RETRIES: Final = 3 +_MAX_RETRY_AFTER_S: Final = 60.0 +_RETRY_STATUSES: Final = frozenset({429, 503}) + +# Hop-by-hop / framing headers a caller must not be able to override via +# user-supplied target config. Letting them through enables request +# smuggling, host-header bypasses of WAFs, and cache poisoning. +_FORBIDDEN_HEADERS: Final = frozenset({ + "host", + "content-length", + "transfer-encoding", + "connection", + "keep-alive", + "te", + "upgrade", + "expect", + "proxy-authorization", + "proxy-connection", +}) + + +def safe_headers(headers: Mapping[str, str] | None) -> dict[str, str]: + """Return a copy of ``headers`` with hop-by-hop/forbidden names dropped + and CRLF-bearing values rejected. + + A target config that lets a user inject ``"X-Foo": "bar\\r\\nHost: evil"`` + can perform request smuggling depending on aiohttp's framing. We strip + those values rather than letting them reach the wire. + """ + if not headers: + return {} + safe: dict[str, str] = {} + for raw_name, raw_value in headers.items(): + name = str(raw_name).strip() + if not name or name.lower() in _FORBIDDEN_HEADERS: + continue + if any(c in name for c in "\r\n:"): + continue + value = str(raw_value) + if "\r" in value or "\n" in value: + continue + safe[name] = value + return safe + + +def make_error(message: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]: + """Build a stable failure dict shape used by every provider client.""" + err: dict[str, Any] = {"success": False, "error": redact(message)} + if status is not None: + err["status_code"] = status + if body: + err["body"] = redact(body)[:200] + return err + + +def make_success(**extra: Any) -> dict[str, Any]: + """Build a stable success dict shape used by every provider client.""" + out: dict[str, Any] = {"success": True} + out.update(extra) + return out + + +def _retry_after_seconds(headers: Mapping[str, str], cap_s: float) -> float: + raw = headers.get("Retry-After") or headers.get("retry-after") or "2" + try: + seconds = float(raw) + except (TypeError, ValueError): + seconds = 2.0 + return max(0.0, min(seconds, cap_s)) + + +class HttpProviderClient: + """Base for JSON-over-HTTP notification providers. + + Subclasses call :meth:`request` instead of using ``self._session`` + directly. ``request`` runs the SSRF guard (skippable for known-safe + upstreams via ``ssrf_validate=False``), enforces a per-request + timeout, retries 429/503 with a capped ``Retry-After``, and turns + transport/HTTP errors into the canonical ``{"success": False, ...}`` + shape with secrets redacted. + """ + + _max_retries: int = _MAX_RETRIES + # Settable per-instance so tests / hostile-upstream tuning can + # tighten the cap. Reads of this attribute fall through to the + # class default when no instance value has been set. + _MAX_RETRY_AFTER: float = _MAX_RETRY_AFTER_S + + def __init__( + self, + session: aiohttp.ClientSession, + *, + timeout: aiohttp.ClientTimeout | None = None, + provider_name: str = "http", + ) -> None: + self._session = session + self._timeout = timeout or _DEFAULT_TIMEOUT + self._provider = provider_name + + async def request( + self, + method: str, + url: str, + *, + json: Any = None, + headers: Mapping[str, str] | None = None, + ssrf_validate: bool = True, + retry_statuses: frozenset[int] = _RETRY_STATUSES, + ) -> dict[str, Any]: + """Send a single request with retry + redaction. Always returns a dict. + + On 2xx returns ``{"success": True, "status_code": int, "json": ... + OR "body": str}``. On non-2xx returns the canonical error dict. + """ + if ssrf_validate: + try: + await avalidate_outbound_url(url) + except UnsafeURLError as err: + return make_error(f"Unsafe URL: {redact_exc(err)}") + + outbound_headers: dict[str, str] = {"Content-Type": "application/json"} + outbound_headers.update(safe_headers(headers)) + + for attempt in range(1, self._max_retries + 1): + try: + async with self._session.request( + method, + url, + json=json, + headers=outbound_headers, + timeout=self._timeout, + allow_redirects=False, + ) as resp: + if resp.status in retry_statuses and attempt < self._max_retries: + delay = _retry_after_seconds(resp.headers, self._MAX_RETRY_AFTER) + _LOGGER.warning( + "%s %s %s: HTTP %d, retrying after %.2fs (attempt %d/%d)", + self._provider, method, redact(url), resp.status, + delay, attempt, self._max_retries, + ) + await resp.read() # drain body so connection can return to pool + await asyncio.sleep(delay) + continue + if 200 <= resp.status < 300: + try: + payload: Any = await resp.json(content_type=None) + except (aiohttp.ContentTypeError, ValueError): + payload = await resp.text() + return make_success(status_code=resp.status, json=payload) + body = await resp.text() + return make_error( + f"HTTP {resp.status}", + status=resp.status, + body=body, + ) + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + # asyncio.CancelledError inherits from BaseException on + # 3.8+, so it is not caught here — good: cancellation + # must propagate. + if attempt < self._max_retries and isinstance(err, asyncio.TimeoutError): + _LOGGER.warning( + "%s %s %s: timeout, retrying (attempt %d/%d)", + self._provider, method, redact(url), + attempt, self._max_retries, + ) + await asyncio.sleep(min(2 ** (attempt - 1), 5)) + continue + return make_error(redact_exc(err)) + + # Retry budget exhausted on a retriable status. + return make_error("Rate limited (retries exhausted)") diff --git a/packages/core/src/notify_bridge_core/notifications/matrix/client.py b/packages/core/src/notify_bridge_core/notifications/matrix/client.py index 224239f..e4c8407 100644 --- a/packages/core/src/notify_bridge_core/notifications/matrix/client.py +++ b/packages/core/src/notify_bridge_core/notifications/matrix/client.py @@ -2,22 +2,36 @@ from __future__ import annotations +import asyncio import logging -import time -from typing import Any +import re +import uuid +from typing import Any, Final +from urllib.parse import quote import aiohttp +from ..http_base import _MAX_RETRY_AFTER_S, safe_headers +from ..redact import redact, redact_exc + _LOGGER = logging.getLogger(__name__) -# Monotonically increasing transaction counter for idempotent sends -_txn_counter = int(time.time() * 1000) +# Matrix room IDs are ``!opaque:server.name`` per the spec. We also allow +# the ``#alias:server`` form because some callers may pass aliases. The +# pattern's purpose is to reject obvious path-injection (``/``, ``..``, +# control chars, query/fragment chars) before the value reaches a URL. +_ROOM_ID_RE: Final = re.compile(r"^[!#][^\x00-\x1f\s/?#]{1,255}:[A-Za-z0-9.\-:]{1,255}$") + +_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10) +_MAX_RETRIES: Final = 3 -def _next_txn_id() -> str: - global _txn_counter - _txn_counter += 1 - return str(_txn_counter) +def _validate_room_id(room_id: str) -> str: + if not room_id: + raise ValueError("room_id is empty") + if not _ROOM_ID_RE.match(room_id): + raise ValueError("room_id format is invalid") + return room_id class MatrixClient: @@ -33,49 +47,67 @@ class MatrixClient: self._homeserver = homeserver_url.rstrip("/") self._token = access_token + @staticmethod + def _txn_id() -> str: + # uuid4 hex is collision-resistant across processes/restarts; + # eliminates the previous module-level counter race. + return uuid.uuid4().hex + async def send_message( self, room_id: str, message: str, html_message: str | None = None, ) -> dict[str, Any]: - """Send a text message to a Matrix room. + """Send a text message to a Matrix room.""" + try: + room_id = _validate_room_id(room_id) + except ValueError as exc: + return {"success": False, "error": f"Invalid room_id: {exc}"} - Args: - room_id: Internal room ID (e.g. !abc:matrix.org) - message: Plain text body - html_message: Optional HTML-formatted body - """ - if not room_id: - return {"success": False, "error": "Missing room_id"} + encoded_room = quote(room_id, safe="") + url = ( + f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}" + f"/send/m.room.message/{self._txn_id()}" + ) - txn_id = _next_txn_id() - # URL-encode the room_id (! and : need encoding) - encoded_room = room_id.replace("!", "%21").replace(":", "%3A") - url = f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}" - - body: dict[str, Any] = { - "msgtype": "m.text", - "body": message, - } + body: dict[str, Any] = {"msgtype": "m.text", "body": message} if html_message: body["format"] = "org.matrix.custom.html" body["formatted_body"] = html_message - headers = { + headers = safe_headers({ "Authorization": f"Bearer {self._token}", "Content-Type": "application/json", - } + }) - try: - async with self._session.put( - url, json=body, headers=headers, allow_redirects=False, - ) as resp: - if 200 <= resp.status < 300: - return {"success": True} - resp_body = await resp.text() - if resp.status == 429: - _LOGGER.warning("Matrix rate limited: %s", resp_body[:200]) - return {"success": False, "error": f"HTTP {resp.status}: {resp_body[:200]}"} - except aiohttp.ClientError as e: - return {"success": False, "error": str(e)} + for attempt in range(1, _MAX_RETRIES + 1): + try: + async with self._session.put( + url, json=body, headers=headers, + timeout=_DEFAULT_TIMEOUT, allow_redirects=False, + ) as resp: + if 200 <= resp.status < 300: + return {"success": True} + resp_body = await resp.text() + if resp.status == 429 and attempt < _MAX_RETRIES: + try: + wait_s = float(resp.headers.get("Retry-After", "2")) + except (TypeError, ValueError): + wait_s = 2.0 + wait_s = max(0.0, min(wait_s, _MAX_RETRY_AFTER_S)) + _LOGGER.warning( + "Matrix rate limited, retrying after %.2fs (attempt %d/%d)", + wait_s, attempt, _MAX_RETRIES, + ) + await asyncio.sleep(wait_s) + continue + return { + "success": False, + "error": f"HTTP {resp.status}: {redact(resp_body)[:200]}", + "status_code": resp.status, + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: + return {"success": False, "error": redact_exc(e)} + + return {"success": False, "error": "Rate limited (retries exhausted)"} diff --git a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py index d5be0f9..44a513e 100644 --- a/packages/core/src/notify_bridge_core/notifications/ntfy/client.py +++ b/packages/core/src/notify_bridge_core/notifications/ntfy/client.py @@ -3,18 +3,33 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, Final import aiohttp +from ..http_base import HttpProviderClient + _LOGGER = logging.getLogger(__name__) +_PRIORITY_MIN: Final = 1 +_PRIORITY_MAX: Final = 5 +_DEFAULT_PRIORITY: Final = 3 +_MAX_TAGS: Final = 10 +_MAX_TAG_LEN: Final = 64 -class NtfyClient: + +def _strip_crlf(value: str) -> str: + """Remove CR/LF — ntfy's JSON path is safe today, but the same fields + are used by the header API; defensive sanitization here means a future + refactor can't accidentally re-introduce header injection.""" + return value.replace("\r", " ").replace("\n", " ") + + +class NtfyClient(HttpProviderClient): """Sends push notifications via ntfy server.""" def __init__(self, session: aiohttp.ClientSession) -> None: - self._session = session + super().__init__(session, provider_name="ntfy") async def send( self, @@ -22,41 +37,48 @@ class NtfyClient: topic: str, message: str, title: str | None = None, - priority: int = 3, + priority: int = _DEFAULT_PRIORITY, tags: list[str] | None = None, click_url: str | None = None, auth_token: str | None = None, + markdown: bool = True, ) -> dict[str, Any]: """Send a push notification to an ntfy topic.""" if not server_url or not topic: return {"success": False, "error": "Missing server_url or topic"} - url = f"{server_url.rstrip('/')}" + topic = _strip_crlf(topic).strip() + if not topic: + return {"success": False, "error": "Topic is empty after sanitization"} + + try: + priority_int = int(priority) if priority is not None else _DEFAULT_PRIORITY + except (TypeError, ValueError): + priority_int = _DEFAULT_PRIORITY + priority_int = max(_PRIORITY_MIN, min(priority_int, _PRIORITY_MAX)) + payload: dict[str, Any] = { "topic": topic, "message": message, - "markdown": True, + "markdown": bool(markdown), } if title: - payload["title"] = title - if priority != 3: - payload["priority"] = priority + payload["title"] = _strip_crlf(title) + if priority_int != _DEFAULT_PRIORITY: + payload["priority"] = priority_int if tags: - payload["tags"] = tags + cleaned = [ + _strip_crlf(str(t))[:_MAX_TAG_LEN] + for t in tags[:_MAX_TAGS] + if t + ] + if cleaned: + payload["tags"] = cleaned if click_url: - payload["click"] = click_url + payload["click"] = _strip_crlf(click_url) - headers: dict[str, str] = {"Content-Type": "application/json"} + headers: dict[str, str] = {} if auth_token: headers["Authorization"] = f"Bearer {auth_token}" - try: - async with self._session.post( - url, json=payload, headers=headers, allow_redirects=False, - ) as resp: - if 200 <= resp.status < 300: - return {"success": True} - body = await resp.text() - return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"} - except aiohttp.ClientError as e: - return {"success": False, "error": str(e)} + return await self.request("POST", server_url.rstrip("/"), json=payload, headers=headers) diff --git a/packages/core/src/notify_bridge_core/notifications/queue.py b/packages/core/src/notify_bridge_core/notifications/queue.py index 1cc02bd..585b3eb 100644 --- a/packages/core/src/notify_bridge_core/notifications/queue.py +++ b/packages/core/src/notify_bridge_core/notifications/queue.py @@ -2,47 +2,88 @@ from __future__ import annotations +import asyncio +import copy import logging from datetime import datetime, timezone -from typing import Any +from typing import Any, Final from notify_bridge_core.storage import StorageBackend _LOGGER = logging.getLogger(__name__) +# Bound on queue length. Without a cap, a misconfigured quiet-hour +# window plus high event throughput grows the persisted file unboundedly +# and every enqueue rewrites the whole file (O(n²) total writes). When +# the cap is hit we drop the oldest entry (FIFO) so the most recent +# events still reach the recipient when the window opens. +DEFAULT_MAX_QUEUE_SIZE: Final = 1000 + class NotificationQueue: """Persistent queue for notifications deferred during quiet hours.""" - def __init__(self, backend: StorageBackend) -> None: + def __init__( + self, + backend: StorageBackend, + *, + max_size: int = DEFAULT_MAX_QUEUE_SIZE, + ) -> None: self._backend = backend self._data: dict[str, Any] | None = None + self._max_size = max_size + # Coordinates load / enqueue / clear / remove so a write-while-load + # race can't leave the in-memory copy out of sync with disk and so + # bulk operations don't interleave their reads-then-writes. + self._lock = asyncio.Lock() + + @staticmethod + def _ensure_schema(data: Any) -> dict[str, Any]: + if not isinstance(data, dict) or not isinstance(data.get("queue"), list): + return {"queue": []} + return data async def async_load(self) -> None: - self._data = await self._backend.load() or {"queue": []} + async with self._lock: + raw = await self._backend.load() + self._data = self._ensure_schema(raw) async def async_enqueue(self, notification_params: dict[str, Any]) -> None: - if self._data is None: - self._data = {"queue": []} - self._data["queue"].append({ - "params": notification_params, - "queued_at": datetime.now(timezone.utc).isoformat(), - }) - await self._backend.save(self._data) + async with self._lock: + if self._data is None: + self._data = {"queue": []} + queue: list[dict[str, Any]] = self._data["queue"] + queue.append({ + "params": notification_params, + "queued_at": datetime.now(timezone.utc).isoformat(), + }) + if self._max_size > 0 and len(queue) > self._max_size: + # Drop oldest (FIFO) so a new event can still land. + drop = len(queue) - self._max_size + _LOGGER.warning( + "NotificationQueue: dropping %d oldest entries (cap=%d)", + drop, self._max_size, + ) + del queue[:drop] + await self._backend.save(self._data) def get_all(self) -> list[dict[str, Any]]: if not self._data: return [] - return list(self._data.get("queue", [])) + # Deep copy so callers can iterate / mutate without corrupting the + # in-memory queue. The cost is bounded by ``max_size``. + return copy.deepcopy(list(self._data.get("queue", []))) def has_pending(self) -> bool: return bool(self._data and self._data.get("queue")) async def async_clear(self) -> None: - if self._data: - self._data["queue"] = [] - await self._backend.save(self._data) + async with self._lock: + if self._data: + self._data["queue"] = [] + await self._backend.save(self._data) async def async_remove(self) -> None: - await self._backend.remove() - self._data = None + async with self._lock: + await self._backend.remove() + self._data = None diff --git a/packages/core/src/notify_bridge_core/notifications/receiver.py b/packages/core/src/notify_bridge_core/notifications/receiver.py index b07c221..9b20ce3 100644 --- a/packages/core/src/notify_bridge_core/notifications/receiver.py +++ b/packages/core/src/notify_bridge_core/notifications/receiver.py @@ -3,7 +3,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import Any, Callable @dataclass @@ -70,51 +70,64 @@ class MatrixReceiver(Receiver): room_id: str = "" +_ReceiverFactory = Callable[[str, dict[str, Any]], Receiver] + + +def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +_RECEIVER_FACTORIES: dict[str, _ReceiverFactory] = { + "telegram": lambda locale, config: TelegramReceiver( + locale=locale, config=config, chat_id=str(config.get("chat_id", "")), + ), + "webhook": lambda locale, config: WebhookReceiver( + locale=locale, config=config, + url=str(config.get("url", "")), + headers=dict(config.get("headers", {}) or {}), + ), + "email": lambda locale, config: EmailReceiver( + locale=locale, config=config, + email=str(config.get("email", "")), + name=str(config.get("name", "")), + ), + "discord": lambda locale, config: DiscordReceiver( + locale=locale, config=config, + webhook_url=str(config.get("webhook_url", "")), + ), + "slack": lambda locale, config: SlackReceiver( + locale=locale, config=config, + webhook_url=str(config.get("webhook_url", "")), + ), + "ntfy": lambda locale, config: NtfyReceiver( + locale=locale, config=config, + topic=str(config.get("topic", "")), + priority=_coerce_int(config.get("priority"), 3), + ), + "matrix": lambda locale, config: MatrixReceiver( + locale=locale, config=config, + room_id=str(config.get("room_id", "")), + ), +} + + +def register_receiver_factory(target_type: str, factory: _ReceiverFactory) -> None: + """Register a receiver factory for an out-of-tree target type.""" + _RECEIVER_FACTORIES[target_type] = factory + + def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver: - """Factory: build typed Receiver from target type and config dict.""" - if target_type == "telegram": - return TelegramReceiver( - locale=locale, - config=config, - chat_id=str(config.get("chat_id", "")), - ) - if target_type == "webhook": - return WebhookReceiver( - locale=locale, - config=config, - url=config.get("url", ""), - headers=config.get("headers", {}), - ) - if target_type == "email": - return EmailReceiver( - locale=locale, - config=config, - email=config.get("email", ""), - name=config.get("name", ""), - ) - if target_type == "discord": - return DiscordReceiver( - locale=locale, - config=config, - webhook_url=config.get("webhook_url", ""), - ) - if target_type == "slack": - return SlackReceiver( - locale=locale, - config=config, - webhook_url=config.get("webhook_url", ""), - ) - if target_type == "ntfy": - return NtfyReceiver( - locale=locale, - config=config, - topic=config.get("topic", ""), - priority=config.get("priority", 3), - ) - if target_type == "matrix": - return MatrixReceiver( - locale=locale, - config=config, - room_id=config.get("room_id", ""), - ) - return Receiver(locale=locale, config=config) + """Factory: build typed Receiver from target type and config dict. + + Falls back to a base ``Receiver`` for unknown target types so callers + that handle types defensively still receive a usable object — but the + dispatcher rejects them with ``"Unknown target type"`` so a typo can't + silently route to nowhere. + """ + factory = _RECEIVER_FACTORIES.get(target_type) + if factory is None: + return Receiver(locale=locale, config=config) + return factory(locale, config) diff --git a/packages/core/src/notify_bridge_core/notifications/redact.py b/packages/core/src/notify_bridge_core/notifications/redact.py new file mode 100644 index 0000000..9647d3c --- /dev/null +++ b/packages/core/src/notify_bridge_core/notifications/redact.py @@ -0,0 +1,64 @@ +"""Secret-redaction helpers for log lines and error strings. + +Notification clients embed secrets in URLs (Telegram bot tokens) and +Authorization headers (Matrix access tokens, ntfy bearer tokens). When +those secrets surface in ``aiohttp.ClientError.__str__``, response +bodies, or operator-visible error fields, they leak into logs and into +the per-target result dict that callers may forward upstream. ``redact`` +returns a defanged copy safe for both contexts. +""" + +from __future__ import annotations + +import re +from typing import Final + +# api.telegram.org/bot:/ +_TELEGRAM_BOT_TOKEN_RE: Final = re.compile( + r"(api\.telegram\.org/bot)\d+:[A-Za-z0-9_-]+", re.IGNORECASE, +) +# Authorization: Bearer (header form, case-insensitive) +_BEARER_RE: Final = re.compile(r"(Bearer\s+)[A-Za-z0-9._\-+/=]+", re.IGNORECASE) +# Discord webhook: /api/webhooks// +_DISCORD_WEBHOOK_RE: Final = re.compile( + r"(discord(?:app)?\.com/api/webhooks/\d+/)[A-Za-z0-9_-]+", + re.IGNORECASE, +) +# Slack webhook path: /services/T.../B.../ +_SLACK_WEBHOOK_RE: Final = re.compile( + r"(hooks\.slack\.com/services/[A-Z0-9]+/[A-Z0-9]+/)[A-Za-z0-9]+", + re.IGNORECASE, +) +# URL userinfo: scheme://user:password@host +_URL_USERINFO_RE: Final = re.compile( + r"([a-z][a-z0-9+\-.]*://)[^/@\s]+:[^/@\s]+@", + re.IGNORECASE, +) +# Common token query parameters +_QUERY_TOKEN_RE: Final = re.compile( + r"([?&](?:token|access_token|api_key|key|secret|password)=)[^&\s]+", + re.IGNORECASE, +) + + +def redact(text: str) -> str: + """Return ``text`` with known secret patterns replaced by ``***``. + + Idempotent and safe to call on already-redacted strings. Always + returns a ``str``; non-strings are coerced via ``str()`` so callers + can pass exception instances directly. + """ + if not isinstance(text, str): + text = str(text) + text = _TELEGRAM_BOT_TOKEN_RE.sub(r"\1***", text) + text = _DISCORD_WEBHOOK_RE.sub(r"\1***", text) + text = _SLACK_WEBHOOK_RE.sub(r"\1***", text) + text = _BEARER_RE.sub(r"\1***", text) + text = _URL_USERINFO_RE.sub(r"\1***@", text) + text = _QUERY_TOKEN_RE.sub(r"\1***", text) + return text + + +def redact_exc(err: BaseException) -> str: + """Redact-and-stringify an exception. Convenience for error fields.""" + return redact(str(err)) diff --git a/packages/core/src/notify_bridge_core/notifications/slack/client.py b/packages/core/src/notify_bridge_core/notifications/slack/client.py index 681b286..ad86468 100644 --- a/packages/core/src/notify_bridge_core/notifications/slack/client.py +++ b/packages/core/src/notify_bridge_core/notifications/slack/client.py @@ -7,14 +7,16 @@ from typing import Any import aiohttp +from ..http_base import HttpProviderClient + _LOGGER = logging.getLogger(__name__) -class SlackClient: +class SlackClient(HttpProviderClient): """Sends messages via Slack incoming webhook URLs.""" def __init__(self, session: aiohttp.ClientSession) -> None: - self._session = session + super().__init__(session, provider_name="slack") async def send( self, @@ -33,19 +35,4 @@ class SlackClient: if icon_emoji: payload["icon_emoji"] = icon_emoji - try: - async with self._session.post( - webhook_url, - json=payload, - headers={"Content-Type": "application/json"}, - allow_redirects=False, - ) as resp: - if resp.status == 429: - _LOGGER.warning("Slack rate limited") - return {"success": False, "error": "Rate limited by Slack"} - if 200 <= resp.status < 300: - return {"success": True} - body = await resp.text() - return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"} - except aiohttp.ClientError as e: - return {"success": False, "error": str(e)} + return await self.request("POST", webhook_url, json=payload) diff --git a/packages/core/src/notify_bridge_core/notifications/ssrf.py b/packages/core/src/notify_bridge_core/notifications/ssrf.py index 66aea5a..bc113b9 100644 --- a/packages/core/src/notify_bridge_core/notifications/ssrf.py +++ b/packages/core/src/notify_bridge_core/notifications/ssrf.py @@ -1,10 +1,22 @@ """Outbound URL validation to mitigate SSRF attacks. -User-controlled URLs (provider `url`, webhook target `url`, shared-link -base URLs, image downloads) must be validated before any HTTP request is -issued. This module rejects schemes other than http/https and blocks -destinations that resolve to private, loopback, link-local, or unspecified -address ranges. +User-controlled URLs (provider ``url``, webhook target ``url``, +shared-link base URLs, image downloads) must be validated before any +HTTP request is issued. This module rejects schemes other than +http/https and blocks destinations that resolve to private, loopback, +link-local, unspecified, CGNAT (100.64.0.0/10), or IPv4-mapped IPv6 +ranges. + +DNS rebinding mitigation +~~~~~~~~~~~~~~~~~~~~~~~~ +``avalidate_outbound_url`` returns the original URL on success, but +also returns the resolved IP it actually validated. Callers that pass +the validated URL straight into ``aiohttp`` are vulnerable to a +DNS-rebinding attack: the validator's ``getaddrinfo`` returns a public +IP; aiohttp's connect-time resolution returns ``127.0.0.1``. To close +that gap, use :func:`build_ssrf_safe_session` (or +:class:`PinnedResolver`) so the resolved IP from the validation step is +the one aiohttp connects to. Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for development against localhost services. @@ -17,12 +29,20 @@ import ipaddress import logging import os import socket +from dataclasses import dataclass from urllib.parse import urlparse +import aiohttp + _LOGGER = logging.getLogger(__name__) _ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1" -_ALLOWED_SCHEMES = {"http", "https"} +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + +# Carrier-grade NAT range. Not in stdlib's ``is_private``; an attacker +# pointing a domain at a CGNAT IP could reach the operator's ISP-side +# routing infrastructure. RFC 6598. +_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10") if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner _LOGGER.warning( @@ -36,7 +56,29 @@ class UnsafeURLError(ValueError): """Raised when a URL targets a disallowed network destination.""" +@dataclass(frozen=True) +class ValidatedURL: + """Result of validating an outbound URL. + + Attributes: + url: The original URL string (unchanged). + host: Hostname extracted from the URL (lower-cased, IDN-encoded). + ip: Resolved IP address that passed the block-range check, as a + string. Pass to :class:`PinnedResolver` to defeat DNS + rebinding by reusing this exact IP at connect time. + """ + + url: str + host: str + ip: str + + def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: + # An IPv4-mapped IPv6 like ``::ffff:127.0.0.1`` is NOT considered + # ``is_private`` etc. by stdlib — the v4 view holds those flags. So + # we unwrap before checking. + if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None: + ip = ip.ipv4_mapped return ( ip.is_private or ip.is_loopback @@ -44,22 +86,54 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: or ip.is_multicast or ip.is_reserved or ip.is_unspecified + or (isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK) ) +def _safe_host_repr(host: str) -> str: + """Return ``host`` shortened/escaped for safe inclusion in error text.""" + h = host[:64].replace("\r", "").replace("\n", "") + return h + + +def _normalize_host(parsed_host: str) -> str: + """Normalize a hostname: lowercase, strip trailing dot, IDN-encode.""" + host = parsed_host.lower() + if host.endswith("."): + host = host[:-1] + # Strip IPv6 zone id ("fe80::1%eth0") — must not reach the resolver. + if "%" in host: + host = host.split("%", 1)[0] + # IDN-encode unicode hostnames so we don't downgrade to confusables + # in any later log/output and so getaddrinfo gets the ascii form. + try: + if any(ord(c) > 127 for c in host): + host = host.encode("idna").decode("ascii") + except UnicodeError: + # Caller will fail on resolution; leave as-is so the error path + # surfaces a "DNS resolution failed" rather than a stack trace. + pass + return host + + def _check_scheme_host(url: str) -> tuple[str, str]: if not isinstance(url, str) or not url: raise UnsafeURLError("URL is empty") parsed = urlparse(url) - if parsed.scheme not in _ALLOWED_SCHEMES: - raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed") + scheme = parsed.scheme.lower() + if scheme not in _ALLOWED_SCHEMES: + raise UnsafeURLError(f"Scheme '{scheme[:16]}' not allowed") host = parsed.hostname if not host: raise UnsafeURLError("URL has no host") - return parsed.scheme, host + return scheme, _normalize_host(host) -def _check_resolved_addresses(host: str, infos: list[tuple]) -> None: +def _select_addresses( + host: str, infos: list[tuple], +) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]: + """Return parsed, non-blocked IPs from ``getaddrinfo`` results.""" + addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = [] for info in infos: sockaddr = info[4] try: @@ -67,64 +141,143 @@ def _check_resolved_addresses(host: str, infos: list[tuple]) -> None: except ValueError: continue if _is_blocked_ip(ip): - raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}") + raise UnsafeURLError( + f"Host {_safe_host_repr(host)} resolves to blocked address {ip}" + ) + addrs.append(ip) + if not addrs: + raise UnsafeURLError(f"Host {_safe_host_repr(host)} has no usable address") + return addrs def validate_outbound_url(url: str) -> str: """Validate ``url`` is safe to fetch; returns the URL on success. - Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP - is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``) - private addresses are permitted but the scheme check still applies. - - Synchronous; uses blocking ``socket.getaddrinfo``. Prefer - :func:`avalidate_outbound_url` from async code paths. + .. deprecated:: + Synchronous; uses blocking ``socket.getaddrinfo``. Prefer + :func:`avalidate_outbound_url` from async code paths so the + event loop isn't blocked, and use :func:`build_ssrf_safe_session` + to defeat DNS rebinding. """ _, host = _check_scheme_host(url) if _ALLOW_PRIVATE: return url - # Literal IP host try: ip = ipaddress.ip_address(host) if _is_blocked_ip(ip): - raise UnsafeURLError(f"Host {host} is in a blocked range") + raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range") return url except ValueError: pass try: infos = socket.getaddrinfo(host, None) - except socket.gaierror as exc: - raise UnsafeURLError(f"DNS resolution failed for {host}") from exc - _check_resolved_addresses(host, infos) + except (socket.gaierror, UnicodeError, OSError) as exc: + # ``UnicodeError`` covers IDNA failures (labels >63 chars, malformed + # unicode) which getaddrinfo surfaces as encoding errors rather than + # gaierror. ``OSError`` covers transient resolver failures on some + # platforms. + raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc + _select_addresses(host, infos) return url async def avalidate_outbound_url(url: str) -> str: - """Async variant that resolves DNS via the running loop's resolver. + """Async variant — returns the URL on success. - Use this from ``async def`` code paths to avoid blocking the event - loop on DNS lookups. + For DNS-rebinding-safe usage, prefer :func:`avalidate_outbound_url_full` + which also returns the resolved IP for connect-time pinning. + """ + result = await avalidate_outbound_url_full(url) + return result.url + + +async def avalidate_outbound_url_full(url: str) -> ValidatedURL: + """Validate ``url`` and return a :class:`ValidatedURL` on success. + + The returned ``ip`` field is the IP that passed the block-range + check. Pair with :class:`PinnedResolver` so aiohttp connects to that + exact IP and a malicious DNS server can't swap in a private address + after validation. """ _, host = _check_scheme_host(url) if _ALLOW_PRIVATE: - return url + # In dev mode we still resolve to give a usable IP, but we don't + # gate on the result. + try: + ip = str(ipaddress.ip_address(host)) + except ValueError: + try: + loop = asyncio.get_running_loop() + infos = await loop.getaddrinfo(host, None) + ip = infos[0][4][0] if infos else host + except (socket.gaierror, OSError): + ip = host + return ValidatedURL(url=url, host=host, ip=ip) + # Literal IP host try: - ip = ipaddress.ip_address(host) - if _is_blocked_ip(ip): - raise UnsafeURLError(f"Host {host} is in a blocked range") - return url + ip_obj = ipaddress.ip_address(host) + if _is_blocked_ip(ip_obj): + raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range") + return ValidatedURL(url=url, host=host, ip=str(ip_obj)) except ValueError: pass loop = asyncio.get_running_loop() try: infos = await loop.getaddrinfo(host, None) - except socket.gaierror as exc: - raise UnsafeURLError(f"DNS resolution failed for {host}") from exc - _check_resolved_addresses(host, infos) - return url + except (socket.gaierror, UnicodeError, OSError) as exc: + raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc + addrs = _select_addresses(host, infos) + return ValidatedURL(url=url, host=host, ip=str(addrs[0])) + + +class PinnedResolver(aiohttp.abc.AbstractResolver): + """aiohttp resolver that returns a fixed (host, ip) mapping. + + Used to pin the resolved IP from :func:`avalidate_outbound_url_full` + so aiohttp's connect-time resolution can't be tricked by DNS + rebinding into using a different IP than the one we validated. + + Falls back to :class:`aiohttp.AsyncResolver` (or default) for any + host not explicitly pinned, so a single resolver instance can be + reused across multiple validated URLs. + """ + + def __init__(self, mapping: dict[str, str] | None = None) -> None: + self._map: dict[str, str] = dict(mapping or {}) + self._fallback: aiohttp.abc.AbstractResolver | None = None + + def pin(self, host: str, ip: str) -> None: + self._map[host.lower()] = ip + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET, + ) -> list[dict]: + ip = self._map.get(host.lower()) + if ip is not None: + try: + ip_obj = ipaddress.ip_address(ip) + except ValueError: + ip_obj = None + if ip_obj is not None: + fam = socket.AF_INET6 if ip_obj.version == 6 else socket.AF_INET + return [{ + "hostname": host, + "host": ip, + "port": port, + "family": fam, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + }] + if self._fallback is None: + self._fallback = aiohttp.ThreadedResolver() + return await self._fallback.resolve(host, port, family) + + async def close(self) -> None: + if self._fallback is not None: + await self._fallback.close() diff --git a/packages/core/src/notify_bridge_core/notifications/telegram/cache.py b/packages/core/src/notify_bridge_core/notifications/telegram/cache.py index ebd2f53..cee275a 100644 --- a/packages/core/src/notify_bridge_core/notifications/telegram/cache.py +++ b/packages/core/src/notify_bridge_core/notifications/telegram/cache.py @@ -2,16 +2,29 @@ from __future__ import annotations +import asyncio import logging from datetime import datetime, timezone -from typing import Any +from typing import Any, Final from notify_bridge_core.storage import StorageBackend _LOGGER = logging.getLogger(__name__) -DEFAULT_TELEGRAM_CACHE_TTL = 48 * 60 * 60 -DEFAULT_MAX_ENTRIES = 5000 +DEFAULT_TELEGRAM_CACHE_TTL: Final = 48 * 60 * 60 +DEFAULT_MAX_ENTRIES: Final = 5000 + + +def _parse_iso(value: str | None) -> datetime | None: + """Parse an ISO-8601 timestamp tolerantly. Returns ``None`` on failure.""" + if not value or not isinstance(value, str): + return None + try: + # Python <3.11 doesn't accept "Z"; normalize to +00:00. + v = value.replace("Z", "+00:00") if value.endswith("Z") else value + return datetime.fromisoformat(v) + except ValueError: + return None class TelegramFileCache: @@ -25,7 +38,17 @@ class TelegramFileCache: Intended for content-addressable assets (e.g. Immich) where re-uploads should be triggered by visual change, not elapsed time. - ``max_entries`` always applies as an LRU size cap (by ``cached_at``). + ``max_entries`` always applies as a FIFO size cap (oldest-cached first). + + Concurrency + ~~~~~~~~~~~ + All mutators take an internal ``asyncio.Lock`` so concurrent + media-group sends can't interleave a read-time invalidation with a + bulk write and corrupt the underlying dict (``RuntimeError: + dictionary changed size during iteration``) or lose just-written + entries. Reads do not take the lock — they are O(1) dict lookups — + but ``get`` uses a snapshot reference so it cannot mutate the data + structure under another task. """ def __init__( @@ -40,35 +63,40 @@ class TelegramFileCache: self._ttl_seconds = ttl_seconds self._use_thumbhash = use_thumbhash self._max_entries = max_entries + self._lock = asyncio.Lock() async def async_load(self) -> None: - self._data = await self._backend.load() or {"files": {}} - await self._cleanup_expired() + async with self._lock: + self._data = await self._backend.load() or {"files": {}} + await self._cleanup_expired_locked() - async def _cleanup_expired(self) -> None: + async def _cleanup_expired_locked(self) -> None: + """Caller must hold ``self._lock``.""" if not self._data or "files" not in self._data: return - files = self._data["files"] + files: dict[str, dict[str, Any]] = self._data["files"] changed = False - # TTL sweep — only when TTL validation is active (i.e. no thumbhash - # mode and a positive TTL). In thumbhash mode we rely entirely on - # content validation; in "TTL disabled" mode (ttl_seconds <= 0) we - # cache forever, subject only to the size cap. if not self._use_thumbhash and self._ttl_seconds > 0: now = datetime.now(timezone.utc) - expired = [ - url for url, entry in files.items() - if entry.get("cached_at") and - (now - datetime.fromisoformat(entry["cached_at"])).total_seconds() > self._ttl_seconds - ] + expired: list[str] = [] + for url, entry in list(files.items()): + cached_at = _parse_iso(entry.get("cached_at")) + if cached_at is None: + continue + if cached_at.tzinfo is None: + cached_at = cached_at.replace(tzinfo=timezone.utc) + if (now - cached_at).total_seconds() > self._ttl_seconds: + expired.append(url) for key in expired: del files[key] changed = True - # LRU cap — always enforced. Evicts oldest-cached entries first. if self._max_entries > 0 and len(files) > self._max_entries: - sorted_keys = sorted(files, key=lambda k: files[k].get("cached_at", "")) + sorted_keys = sorted( + files, + key=lambda k: _parse_iso(files[k].get("cached_at")) or datetime.min.replace(tzinfo=timezone.utc), + ) for key in sorted_keys[: len(files) - self._max_entries]: del files[key] changed = True @@ -80,7 +108,10 @@ class TelegramFileCache: if not self._data or "files" not in self._data: return None - entry = self._data["files"].get(key) + # Take a local reference so a concurrent ``async_set`` rebuilding + # the dict cannot pull the rug out mid-read. + files = self._data["files"] + entry = files.get(key) if not entry: return None @@ -88,19 +119,23 @@ class TelegramFileCache: if thumbhash is not None: stored = entry.get("thumbhash") if stored and stored != thumbhash: - del self._data["files"][key] + # Mark stale — actual deletion happens lock-protected + # in the next mutation. Returning None is sufficient + # for the caller to skip the cache hit. return None elif self._ttl_seconds > 0: - cached_at_str = entry.get("cached_at") - if cached_at_str: - age = (datetime.now(timezone.utc) - datetime.fromisoformat(cached_at_str)).total_seconds() + cached_at = _parse_iso(entry.get("cached_at")) + if cached_at is not None: + if cached_at.tzinfo is None: + cached_at = cached_at.replace(tzinfo=timezone.utc) + age = (datetime.now(timezone.utc) - cached_at).total_seconds() if age > self._ttl_seconds: return None return { "file_id": entry.get("file_id"), "type": entry.get("type"), - "size": entry.get("size"), # bytes of what was uploaded; None for legacy entries + "size": entry.get("size"), } async def async_set( @@ -111,21 +146,22 @@ class TelegramFileCache: thumbhash: str | None = None, size: int | None = None, ) -> None: - if self._data is None: - self._data = {"files": {}} + async with self._lock: + if self._data is None: + self._data = {"files": {}} - entry: dict[str, Any] = { - "file_id": file_id, - "type": media_type, - "cached_at": datetime.now(timezone.utc).isoformat(), - } - if thumbhash is not None: - entry["thumbhash"] = thumbhash - if size is not None: - entry["size"] = size + entry: dict[str, Any] = { + "file_id": file_id, + "type": media_type, + "cached_at": datetime.now(timezone.utc).isoformat(), + } + if thumbhash is not None: + entry["thumbhash"] = thumbhash + if size is not None: + entry["size"] = size - self._data["files"][key] = entry - await self._backend.save(self._data) + self._data["files"][key] = entry + await self._backend.save(self._data) async def async_set_many( self, @@ -139,32 +175,34 @@ class TelegramFileCache: """ if not entries: return - if self._data is None: - self._data = {"files": {}} + async with self._lock: + if self._data is None: + self._data = {"files": {}} - now_iso = datetime.now(timezone.utc).isoformat() - for item in entries: - if len(item) == 5: - key, file_id, media_type, thumbhash, size = item - else: - key, file_id, media_type, thumbhash = item - size = None - entry: dict[str, Any] = { - "file_id": file_id, - "type": media_type, - "cached_at": now_iso, - } - if thumbhash is not None: - entry["thumbhash"] = thumbhash - if size is not None: - entry["size"] = size - self._data["files"][key] = entry + now_iso = datetime.now(timezone.utc).isoformat() + for item in entries: + if len(item) == 5: + key, file_id, media_type, thumbhash, size = item + else: + key, file_id, media_type, thumbhash = item + size = None + entry: dict[str, Any] = { + "file_id": file_id, + "type": media_type, + "cached_at": now_iso, + } + if thumbhash is not None: + entry["thumbhash"] = thumbhash + if size is not None: + entry["size"] = size + self._data["files"][key] = entry - await self._backend.save(self._data) + await self._backend.save(self._data) async def async_remove(self) -> None: - await self._backend.remove() - self._data = None + async with self._lock: + await self._backend.remove() + self._data = None def stats(self) -> dict[str, Any]: """Return summary stats about the current cache contents. @@ -172,25 +210,33 @@ class TelegramFileCache: Includes the number of cached entries, total tracked size in bytes (only counts entries with a recorded ``size``), and the oldest / newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty). + Timestamps are compared as parsed ``datetime`` objects so mixed + timezone formats (``Z`` vs ``+00:00``) order correctly. """ files = self._data.get("files", {}) if self._data else {} count = len(files) total_size = 0 - oldest: str | None = None - newest: str | None = None + oldest_dt: datetime | None = None + newest_dt: datetime | None = None + oldest_str: str | None = None + newest_str: str | None = None for entry in files.values(): size = entry.get("size") if isinstance(size, int): total_size += size cached_at = entry.get("cached_at") - if cached_at: - if oldest is None or cached_at < oldest: - oldest = cached_at - if newest is None or cached_at > newest: - newest = cached_at + dt = _parse_iso(cached_at) + if dt is None or not cached_at: + continue + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + if oldest_dt is None or dt < oldest_dt: + oldest_dt, oldest_str = dt, cached_at + if newest_dt is None or dt > newest_dt: + newest_dt, newest_str = dt, cached_at return { "count": count, "total_size_bytes": total_size, - "oldest": oldest, - "newest": newest, + "oldest": oldest_str, + "newest": newest_str, } diff --git a/packages/core/src/notify_bridge_core/notifications/telegram/client.py b/packages/core/src/notify_bridge_core/notifications/telegram/client.py index b8d51ff..45c48ab 100644 --- a/packages/core/src/notify_bridge_core/notifications/telegram/client.py +++ b/packages/core/src/notify_bridge_core/notifications/telegram/client.py @@ -7,18 +7,53 @@ import json import logging import mimetypes import re -from dataclasses import dataclass -from typing import Any, Callable +from dataclasses import dataclass, field +from typing import Any, Callable, Final import aiohttp from aiohttp import FormData -# Telegram 429 / flood-control retry settings. Telegram returns -# ``parameters.retry_after`` for rate limits; we honor it up to a cap so a -# pathological value can't park the dispatcher for minutes. -_TG_429_MAX_ATTEMPTS = 4 -_TG_429_MAX_WAIT_S = 60 -_RETRY_AFTER_RE = re.compile(r"retry after (\d+)", re.IGNORECASE) +from ..redact import redact, redact_exc +from ..ssrf import UnsafeURLError, avalidate_outbound_url +from .cache import TelegramFileCache +from .media import ( + TELEGRAM_API_BASE_URL, + TELEGRAM_MAX_CAPTION_LENGTH, + TELEGRAM_MAX_PHOTO_SIZE, + TELEGRAM_MAX_TEXT_LENGTH, + TELEGRAM_MAX_VIDEO_SIZE, + asset_id_from_cache_key, + check_photo_limits, + extract_asset_id_from_url, + is_asset_cache_key, + is_asset_id, + split_media_by_upload_size, +) + +_LOGGER = logging.getLogger(__name__) + +NotificationResult = dict[str, Any] + +# Telegram 429 / flood-control retry settings. +_TG_429_MAX_ATTEMPTS: Final = 4 +_TG_429_MAX_WAIT_S: Final = 60 +_RETRY_AFTER_RE: Final = re.compile(r"retry after (\d+)", re.IGNORECASE) + +# Media-group fan-out cap. With 10 large videos, unbounded gather can +# spike memory to ``N * max_asset_data_size`` and saturate the upstream +# provider. Three concurrent fetches still saturates a typical home LAN +# while keeping peak RAM bounded to ~3 * max_asset_data_size. +_MEDIA_FETCH_CONCURRENCY: Final = 3 + +# Telegram chat actions expire after ~5s, so we refresh slightly faster. +_KEEPALIVE_REFRESH_S: Final = 4 + +# Per-call timeouts. The session may already carry a default but we set +# explicit values here so a future caller without a session-default +# can't hang the dispatcher. +_API_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10) +_UPLOAD_TIMEOUT: Final = aiohttp.ClientTimeout(total=120, connect=10) +_DOWNLOAD_TIMEOUT: Final = aiohttp.ClientTimeout(total=120, connect=10) def _extract_retry_after(result: dict[str, Any]) -> int | None: @@ -46,42 +81,31 @@ def _extract_retry_after(result: dict[str, Any]) -> int | None: def _is_rate_limited(status: int, result: dict[str, Any]) -> bool: return status == 429 or result.get("error_code") == 429 -from .cache import TelegramFileCache -from .media import ( - TELEGRAM_API_BASE_URL, - TELEGRAM_MAX_PHOTO_SIZE, - TELEGRAM_MAX_VIDEO_SIZE, - asset_id_from_cache_key, - check_photo_limits, - extract_asset_id_from_url, - is_asset_cache_key, - is_asset_id, - split_media_by_upload_size, -) -_LOGGER = logging.getLogger(__name__) +def _is_parse_error(status: int, result: dict[str, Any]) -> bool: + """True iff the response signals a Bot-API parse-mode failure. -NotificationResult = dict[str, Any] + Gates on the canonical 400 status AND a parse-related description so a + 5xx that incidentally mentions "parse" doesn't trigger a retry. + """ + if status != 400 and result.get("error_code") != 400: + return False + desc = str(result.get("description", "")).lower() + return "can't parse" in desc or "parse_mode" in desc or "parse entities" in desc @dataclass(frozen=True) class _MediaKind: - """Describes one Telegram media kind (photo / video / document). - - Used by the generic _send_from_cache / _upload_media helpers so the three - send paths don't have to duplicate endpoint, field-name, or response-shape - boilerplate. - """ - api_method: str # "sendPhoto" / "sendVideo" / "sendDocument" - form_field: str # "photo" / "video" / "document" - cache_type: str # same string stored in cache entries - default_filename: str # "photo.jpg" / "video.mp4" / "file" + """Describes one Telegram media kind (photo / video / document).""" + api_method: str + form_field: str + cache_type: str + default_filename: str default_content_type: str def file_id_from_result(self, result: dict[str, Any]) -> str | None: obj = result.get(self.form_field) if isinstance(obj, list) and obj: - # sendPhoto returns a list of resolutions; the largest is last. last = obj[-1] return last.get("file_id") if isinstance(last, dict) else None if isinstance(obj, dict): @@ -89,9 +113,40 @@ class _MediaKind: return None -_PHOTO_KIND = _MediaKind("sendPhoto", "photo", "photo", "photo.jpg", "image/jpeg") -_VIDEO_KIND = _MediaKind("sendVideo", "video", "video", "video.mp4", "video/mp4") -_DOCUMENT_KIND = _MediaKind("sendDocument", "document", "document", "file", "application/octet-stream") +_PHOTO_KIND: Final = _MediaKind("sendPhoto", "photo", "photo", "photo.jpg", "image/jpeg") +_VIDEO_KIND: Final = _MediaKind("sendVideo", "video", "video", "video.mp4", "video/mp4") +_DOCUMENT_KIND: Final = _MediaKind( + "sendDocument", "document", "document", "file", "application/octet-stream", +) + + +@dataclass +class _KeepaliveHandle: + """Carries a chat-action keepalive task and its stop event.""" + task: asyncio.Task | None = None + stop_event: asyncio.Event = field(default_factory=asyncio.Event) + + +@dataclass +class _MediaItem: + """One element of a sendMediaGroup chunk after fetch+filter. + + Pairs a media JSON entry with the optional cache-update tuple + keyed by position. Bundling these together prevents the + ``media_json`` and ``cache_info`` lists from drifting out of + alignment under future edits. + """ + media_json: dict[str, Any] + cache_info: tuple[str, str, str | None, int] | None + attachment: tuple[str, bytes, str, str] | None # (name, data, filename, content_type) + + +def _truncate(text: str, limit: int, *, marker: str = "…") -> str: + """Soft-truncate ``text`` to ``limit`` chars with a trailing marker.""" + if len(text) <= limit: + return text + cutoff = max(0, limit - len(marker)) + return text[:cutoff] + marker class TelegramClient: @@ -114,6 +169,13 @@ class TelegramClient: self._url_resolver = url_resolver self._thumbhash_resolver = thumbhash_resolver + # ------------------------------------------------------------------ + # URL / cache helpers + # ------------------------------------------------------------------ + + def _api_url(self, method: str) -> str: + return f"{TELEGRAM_API_BASE_URL}{self._token}/{method}" + def _resolve_url(self, url: str) -> str: if self._url_resolver: return self._url_resolver(url) @@ -123,11 +185,6 @@ class TelegramClient: self, url: str | None, cache_key: str | None = None, ) -> tuple[TelegramFileCache | None, str | None, str | None]: if cache_key: - # Route asset-UUID cache keys to the asset cache so single-item - # sends hit the same cache the media-group path uses. Without - # this, a command returning one photo stored file_ids in the - # URL cache and a command returning multiple stored them in - # the asset cache — repeated sends never hit. if is_asset_cache_key(cache_key): bare_id = asset_id_from_cache_key(cache_key) thumbhash = ( @@ -153,6 +210,36 @@ class TelegramClient: is_asset = is_asset_cache_key(key) return self._asset_cache if is_asset else self._url_cache + # ------------------------------------------------------------------ + # Asset download with SSRF guard + # ------------------------------------------------------------------ + + async def _safe_get( + self, + url: str, + headers: dict[str, str] | None, + ) -> tuple[bytes | None, str | None]: + """SSRF-guarded GET that returns ``(data, error)``. + + Validates the URL via ``avalidate_outbound_url`` before any HTTP + traffic. Errors are returned (not raised) and stripped of any + embedded secrets before they propagate to the operator-visible + result dict. + """ + try: + await avalidate_outbound_url(url) + except UnsafeURLError as err: + return None, f"Unsafe URL: {redact_exc(err)}" + try: + async with self._session.get( + url, headers=headers or {}, timeout=_DOWNLOAD_TIMEOUT, + ) as resp: + if resp.status != 200: + return None, f"HTTP {resp.status}" + return await resp.read(), None + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + return None, redact_exc(err) + async def _fetch_bytes( self, url: str, @@ -162,13 +249,11 @@ class TelegramClient: """Return ``(data, error_msg)``. Uses ``preloaded`` bytes if provided.""" if preloaded is not None: return preloaded, None - try: - async with self._session.get(self._resolve_url(url), headers=headers or {}) as resp: - if resp.status != 200: - return None, f"HTTP {resp.status}" - return await resp.read(), None - except aiohttp.ClientError as err: - return None, str(err) + return await self._safe_get(self._resolve_url(url), headers) + + # ------------------------------------------------------------------ + # Cached / fresh single-media send paths + # ------------------------------------------------------------------ async def _send_from_cache( self, @@ -181,14 +266,19 @@ class TelegramClient: ) -> NotificationResult | None: """POST a file_id reference. Return None on transient error so the caller can fall through to a fresh upload.""" - payload: dict[str, Any] = {"chat_id": chat_id, kind.form_field: file_id, "parse_mode": parse_mode} + payload: dict[str, Any] = { + "chat_id": chat_id, + kind.form_field: file_id, + "parse_mode": parse_mode, + } if caption: - payload["caption"] = caption - if reply_to_message_id: + payload["caption"] = _truncate(caption, TELEGRAM_MAX_CAPTION_LENGTH) + if reply_to_message_id is not None: payload["reply_parameters"] = {"message_id": reply_to_message_id} - telegram_url = f"{TELEGRAM_API_BASE_URL}{self._token}/{kind.api_method}" try: - async with self._session.post(telegram_url, json=payload) as response: + async with self._session.post( + self._api_url(kind.api_method), json=payload, timeout=_API_TIMEOUT, + ) as response: result = await response.json() if response.status == 200 and result.get("ok"): return { @@ -196,19 +286,15 @@ class TelegramClient: "message_id": result.get("result", {}).get("message_id"), "cached": True, } - # Non-ok from a cached send — file_id stale or file deleted on - # Telegram's side. Log at DEBUG so operators who are hunting - # "why didn't the cached send work?" can see it, but the - # caller will fall through to a fresh upload. _LOGGER.debug( "Telegram %s (cached) returned non-ok: status=%s code=%s desc=%r — falling back to fresh upload", kind.api_method, response.status, result.get("error_code"), - result.get("description"), + redact(str(result.get("description", ""))), ) - except aiohttp.ClientError as err: + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: _LOGGER.debug( "Telegram %s (cached) transport error — falling back to fresh upload: %s", - kind.api_method, err, + kind.api_method, redact_exc(err), ) return None @@ -227,21 +313,26 @@ class TelegramClient: thumbhash: str | None, ) -> NotificationResult: """Multipart-upload ``data`` to Telegram and cache the returned file_id.""" + capped_caption = _truncate(caption, TELEGRAM_MAX_CAPTION_LENGTH) if caption else None + def _build_form() -> FormData: f = FormData() f.add_field("chat_id", chat_id) f.add_field(kind.form_field, data, filename=filename, content_type=content_type) f.add_field("parse_mode", parse_mode) - if caption: - f.add_field("caption", caption) - if reply_to_message_id: + if capped_caption: + f.add_field("caption", capped_caption) + if reply_to_message_id is not None: f.add_field("reply_parameters", json.dumps({"message_id": reply_to_message_id})) return f - telegram_url = f"{TELEGRAM_API_BASE_URL}{self._token}/{kind.api_method}" for attempt in range(1, _TG_429_MAX_ATTEMPTS + 1): try: - async with self._session.post(telegram_url, data=_build_form()) as response: + async with self._session.post( + self._api_url(kind.api_method), + data=_build_form(), + timeout=_UPLOAD_TIMEOUT, + ) as response: result = await response.json() if response.status == 200 and result.get("ok"): res = result.get("result", {}) @@ -267,23 +358,28 @@ class TelegramClient: _LOGGER.error( "Telegram %s failed: status=%s code=%s desc=%r bytes=%d", kind.api_method, response.status, result.get("error_code"), - result.get("description", "Unknown"), len(data), + redact(str(result.get("description", "Unknown"))), len(data), ) - return {"success": False, "error": result.get("description", "Unknown Telegram error")} - except aiohttp.ClientError as err: + return { + "success": False, + "error": redact(str(result.get("description", "Unknown Telegram error"))), + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: _LOGGER.error( "Telegram %s transport error (bytes=%d): %s", - kind.api_method, len(data), err, exc_info=True, + kind.api_method, len(data), redact_exc(err), ) - return {"success": False, "error": str(err)} - # All attempts exhausted via 429 — should be unreachable, but keep - # an explicit error path so we never return None. + return {"success": False, "error": redact_exc(err)} return {"success": False, "error": "Telegram rate limit: max retries exhausted"} + # ------------------------------------------------------------------ + # Public send entry points + # ------------------------------------------------------------------ + async def send_notification( self, chat_id: str, - assets: list[dict[str, str]] | None = None, + assets: list[dict[str, Any]] | None = None, caption: str | None = None, reply_to_message_id: int | None = None, disable_web_page_preview: bool | None = None, @@ -300,9 +396,9 @@ class TelegramClient: disable_web_page_preview, parse_mode, ) - typing_task = None + keepalive: _KeepaliveHandle | None = None if chat_action: - typing_task = self.start_chat_action_keepalive(chat_id, chat_action) + keepalive = self.start_chat_action_keepalive(chat_id, chat_action) try: if len(assets) == 1 and assets[0].get("type") == "photo": @@ -325,19 +421,13 @@ class TelegramClient: url = assets[0].get("url") if not url: return {"success": False, "error": "Missing 'url' for document"} - data = assets[0].get("data") + data, err = await self._fetch_bytes( + url, assets[0].get("headers"), assets[0].get("data"), + ) if data is None: - try: - download_url = self._resolve_url(url) - dl_headers = assets[0].get("headers") or {} - async with self._session.get(download_url, headers=dl_headers) as resp: - if resp.status != 200: - return {"success": False, "error": f"Failed to download media: HTTP {resp.status}"} - data = await resp.read() - except aiohttp.ClientError as err: - return {"success": False, "error": f"Failed to download media: {err}"} + return {"success": False, "error": f"Failed to download media: {err}"} if max_asset_data_size is not None and len(data) > max_asset_data_size: - return {"success": False, "error": f"Media size exceeds limit"} + return {"success": False, "error": "Media size exceeds limit"} filename = url.split("/")[-1].split("?")[0] or "file" return await self._send_document( chat_id, data, filename, caption, reply_to_message_id, @@ -351,7 +441,7 @@ class TelegramClient: send_large_photos_as_documents, ) finally: - await self.stop_keepalive(typing_task) + await self.stop_keepalive(keepalive) async def send_message( self, @@ -361,108 +451,128 @@ class TelegramClient: disable_web_page_preview: bool | None = None, parse_mode: str = "HTML", ) -> NotificationResult: - telegram_url = f"{TELEGRAM_API_BASE_URL}{self._token}/sendMessage" + if not text: + _LOGGER.warning("send_message called with empty text — using placeholder") + body = _truncate(text or "Notification", TELEGRAM_MAX_TEXT_LENGTH) payload: dict[str, Any] = { "chat_id": chat_id, - "text": text or "Notification", + "text": body, "parse_mode": parse_mode, } - if reply_to_message_id: + if reply_to_message_id is not None: payload["reply_parameters"] = {"message_id": reply_to_message_id} if disable_web_page_preview: payload["link_preview_options"] = {"is_disabled": True} + url = self._api_url("sendMessage") try: - async with self._session.post(telegram_url, json=payload) as response: + async with self._session.post(url, json=payload, timeout=_API_TIMEOUT) as response: result = await response.json() if response.status == 200 and result.get("ok"): return {"success": True, "message_id": result.get("result", {}).get("message_id")} - # Retry without parse_mode on parse errors - desc = str(result.get("description", "")) - if "parse" in desc.lower(): - # Log loudly: a parse failure means the template author (or - # an asset field) is producing malformed HTML. Silent - # fallback hides bugs and makes XSS-via-unescaped-field - # harder to spot. Do not log the full payload — it may - # contain secrets. + if _is_parse_error(response.status, result): _LOGGER.warning( "Telegram rejected parse_mode=%s (%r); retrying as plain text. " "Check template output for unescaped characters.", - payload.get("parse_mode"), desc, + payload.get("parse_mode"), + redact(str(result.get("description", ""))), ) payload.pop("parse_mode", None) - async with self._session.post(telegram_url, json=payload) as retry_resp: + # Brief backoff so retry doesn't dogpile on a 4xx. + await asyncio.sleep(0.1) + async with self._session.post(url, json=payload, timeout=_API_TIMEOUT) as retry_resp: retry_result = await retry_resp.json() if retry_resp.status == 200 and retry_result.get("ok"): - return {"success": True, "message_id": retry_result.get("result", {}).get("message_id")} + return { + "success": True, + "message_id": retry_result.get("result", {}).get("message_id"), + } _LOGGER.error( "Telegram sendMessage failed: status=%s code=%s desc=%r", response.status, result.get("error_code"), - result.get("description", "Unknown"), + redact(str(result.get("description", "Unknown"))), ) - return {"success": False, "error": result.get("description", "Unknown Telegram error"), "error_code": result.get("error_code")} - except aiohttp.ClientError as err: - _LOGGER.error("Telegram sendMessage transport error: %s", err, exc_info=True) - return {"success": False, "error": str(err)} + return { + "success": False, + "error": redact(str(result.get("description", "Unknown Telegram error"))), + "error_code": result.get("error_code"), + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + _LOGGER.error("Telegram sendMessage transport error: %s", redact_exc(err)) + return {"success": False, "error": redact_exc(err)} async def send_chat_action(self, chat_id: str, action: str = "typing") -> bool: - telegram_url = f"{TELEGRAM_API_BASE_URL}{self._token}/sendChatAction" try: - async with self._session.post(telegram_url, json={"chat_id": chat_id, "action": action}) as response: + async with self._session.post( + self._api_url("sendChatAction"), + json={"chat_id": chat_id, "action": action}, + timeout=_API_TIMEOUT, + ) as response: result = await response.json() return response.status == 200 and result.get("ok", False) - except aiohttp.ClientError: + except (aiohttp.ClientError, asyncio.TimeoutError, OSError): return False - def start_chat_action_keepalive(self, chat_id: str, action: str = "typing") -> asyncio.Task: - """Repeatedly post ``action`` every 4s until stopped. + def start_chat_action_keepalive( + self, chat_id: str, action: str = "typing", + ) -> _KeepaliveHandle: + """Repeatedly post ``action`` until stopped. - Telegram chat actions expire after ~5s, so callers that want the hint - to persist through longer work (fetching assets, multi-chunk uploads) - need a keep-alive. - - The returned task carries an attached ``stop_event`` (``asyncio.Event``). - Stop cleanly via :meth:`stop_keepalive` — setting the event before - cancellation prevents the loop from firing one last ``sendChatAction`` - after the caller's final user-visible message, which would otherwise - leave a phantom indicator hanging for ~5s. + Returns a :class:`_KeepaliveHandle` that bundles the asyncio.Task + and its stop event. Stop cleanly via :meth:`stop_keepalive` — + setting the event before cancellation prevents the loop from + firing one last ``sendChatAction`` after the caller's final + user-visible message, which would otherwise leave a phantom + indicator hanging. """ - stop_event = asyncio.Event() + handle = _KeepaliveHandle() async def action_loop() -> None: try: - while not stop_event.is_set(): + while not handle.stop_event.is_set(): await self.send_chat_action(chat_id, action) try: - await asyncio.wait_for(stop_event.wait(), timeout=4) + await asyncio.wait_for( + handle.stop_event.wait(), timeout=_KEEPALIVE_REFRESH_S, + ) except asyncio.TimeoutError: - pass # 4s elapsed, refresh the action + pass # refresh window elapsed except asyncio.CancelledError: pass - task: asyncio.Task = asyncio.create_task(action_loop()) - task.stop_event = stop_event # type: ignore[attr-defined] - return task + handle.task = asyncio.create_task(action_loop()) + return handle @staticmethod - async def stop_keepalive(task: asyncio.Task | None) -> None: - """Stop a keepalive task started by :meth:`start_chat_action_keepalive`. + async def stop_keepalive(handle: _KeepaliveHandle | asyncio.Task | None) -> None: + """Stop a keepalive started by :meth:`start_chat_action_keepalive`. - Sets the attached stop event before cancelling so the loop won't - fire another ``sendChatAction`` after the caller's final message - landed at Telegram. + Accepts the new ``_KeepaliveHandle`` form as well as a bare + ``asyncio.Task`` (back-compat for callers that may have stored + one before the refactor). """ + if handle is None: + return + if isinstance(handle, _KeepaliveHandle): + handle.stop_event.set() + task = handle.task + else: + task = handle + stop_event = getattr(task, "stop_event", None) + if stop_event is not None: + stop_event.set() if task is None: return - stop_event = getattr(task, "stop_event", None) - if stop_event is not None: - stop_event.set() task.cancel() try: await task except asyncio.CancelledError: pass + # ------------------------------------------------------------------ + # Single-media helpers + # ------------------------------------------------------------------ + async def _send_photo( self, chat_id: str, url: str | None, caption: str | None = None, reply_to_message_id: int | None = None, parse_mode: str = "HTML", @@ -577,403 +687,472 @@ class TelegramClient: cache, key, thumbhash, ) + # ------------------------------------------------------------------ + # Media group + # ------------------------------------------------------------------ + async def _send_media_group( - self, chat_id: str, assets: list[dict[str, str]], + self, chat_id: str, assets: list[dict[str, Any]], caption: str | None = None, reply_to_message_id: int | None = None, max_group_size: int = 10, chunk_delay: int = 0, parse_mode: str = "HTML", max_asset_data_size: int | None = None, send_large_photos_as_documents: bool = False, ) -> NotificationResult: - chunks = [assets[i:i + max_group_size] for i in range(0, len(assets), max_group_size)] - all_message_ids: list = [] + # Telegram rejects mixed photo/video + document in a single + # sendMediaGroup. Split before chunking so a malformed input + # batch can't poison the entire send. + partitions = self._partition_media_by_kind(assets) - for chunk_idx, chunk in enumerate(chunks): - if chunk_idx > 0 and chunk_delay > 0: - await asyncio.sleep(chunk_delay / 1000) + all_message_ids: list[int] = [] + first_chunk_overall = True + for partition in partitions: + chunks = [ + partition[i:i + max_group_size] + for i in range(0, len(partition), max_group_size) + ] + for chunk_idx, chunk in enumerate(chunks): + if not first_chunk_overall and chunk_delay > 0: + await asyncio.sleep(chunk_delay / 1000) - if len(chunk) == 1: - item = chunk[0] - chunk_caption = caption if chunk_idx == 0 else None - chunk_reply = reply_to_message_id if chunk_idx == 0 else None - if item.get("type") == "photo": - result = await self._send_photo(chat_id, item.get("url"), chunk_caption, chunk_reply, parse_mode, max_asset_data_size, send_large_photos_as_documents, item.get("content_type"), item.get("cache_key"), download_headers=item.get("headers"), preloaded_data=item.get("data")) - elif item.get("type") == "video": - result = await self._send_video(chat_id, item.get("url"), chunk_caption, chunk_reply, parse_mode, max_asset_data_size, item.get("content_type"), item.get("cache_key"), download_headers=item.get("headers"), preloaded_data=item.get("data")) - else: - continue - if not result.get("success"): - result["failed_at_chunk"] = chunk_idx + 1 - return result - all_message_ids.append(result.get("message_id")) - continue - - # Multi-item: download all, build form, send media group. - # Attachments are recorded separately so we can rebuild FormData on - # 429 retry — aiohttp.FormData is single-use after a request. - attachments: list[tuple[str, bytes, str, str]] = [] # (name, data, filename, content_type) - media_json = [] - upload_idx = 0 - # Track cache info per media_json entry (in order) so we can map - # Telegram response items back to cache keys for newly uploaded items. - # None = already cached (no need to store), tuple = needs caching. - # Tuple is (cache_key, media_type, thumbhash, uploaded_size). - media_cache_info: list[tuple[str, str, str | None, int] | None] = [] - - # Resolve cache hits and collect download tasks in parallel. - # Each drop site logs the reason — otherwise a filtered asset - # disappears silently and the media group silently shrinks. - async def _fetch_asset(idx: int, item: dict) -> tuple[int, dict | None, bytes | None]: - """Return (index, cache_entry_or_None, downloaded_bytes_or_None).""" - url = item.get("url") - if not url: - _LOGGER.warning("Media skipped: missing url (idx=%d type=%s)", idx, item.get("type")) - return idx, None, None - media_type = item.get("type", "photo") - custom_cache_key = item.get("cache_key") - - ck = custom_cache_key or extract_asset_id_from_url(url) or url - ck_is_asset = is_asset_cache_key(ck) - item_cache = self._get_cache_for_key(ck, ck_is_asset) - bare_ck = asset_id_from_cache_key(ck) if ck_is_asset else ck - item_thumbhash = self._thumbhash_resolver(bare_ck) if ck_is_asset and self._thumbhash_resolver else None - cached = item_cache.get(ck, thumbhash=item_thumbhash) if item_cache else None - - if cached and cached.get("file_id"): - return idx, cached, None - - # Use preloaded bytes if the dispatcher already fetched them - preloaded = item.get("data") - if preloaded is not None: - data = preloaded - if max_asset_data_size and len(data) > max_asset_data_size: - _LOGGER.warning( - "Media skipped: preloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)", - len(data), max_asset_data_size, idx, media_type, url, + # Single-item chunk → use the simpler send_photo/video path. + if len(chunk) == 1: + item = chunk[0] + chunk_caption = caption if first_chunk_overall else None + chunk_reply = reply_to_message_id if first_chunk_overall else None + if item.get("type") == "photo": + result = await self._send_photo( + chat_id, item.get("url"), chunk_caption, chunk_reply, parse_mode, + max_asset_data_size, send_large_photos_as_documents, + item.get("content_type"), item.get("cache_key"), + download_headers=item.get("headers"), + preloaded_data=item.get("data"), ) - return idx, None, None - if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE: - _LOGGER.warning( - "Media skipped: preloaded video %d bytes exceeds Telegram limit %d (idx=%d url=%s)", - len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url, + elif item.get("type") == "video": + result = await self._send_video( + chat_id, item.get("url"), chunk_caption, chunk_reply, parse_mode, + max_asset_data_size, + item.get("content_type"), item.get("cache_key"), + download_headers=item.get("headers"), + preloaded_data=item.get("data"), ) - return idx, None, None - if media_type == "photo": - exceeds, reason, _, _ = check_photo_limits(data) - if exceeds: - _LOGGER.warning( - "Media skipped: preloaded photo %s (idx=%d url=%s)", - reason, idx, url, - ) - return idx, None, None - return idx, None, data - - try: - download_url = self._resolve_url(url) - dl_headers = item.get("headers") or {} - async with self._session.get(download_url, headers=dl_headers) as resp: - if resp.status != 200: - _LOGGER.warning( - "Media skipped: download HTTP %d (idx=%d type=%s url=%s)", - resp.status, idx, media_type, url, - ) - return idx, None, None - data = await resp.read() - if max_asset_data_size and len(data) > max_asset_data_size: - _LOGGER.warning( - "Media skipped: downloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)", - len(data), max_asset_data_size, idx, media_type, url, - ) - return idx, None, None - if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE: - _LOGGER.warning( - "Media skipped: video %d bytes exceeds Telegram %d-byte limit (idx=%d url=%s)", - len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url, - ) - return idx, None, None - if media_type == "photo": - exceeds, reason, _, _ = check_photo_limits(data) - if exceeds: - _LOGGER.warning( - "Media skipped: photo %s (idx=%d url=%s)", - reason, idx, url, - ) - return idx, None, None - return idx, None, data - except aiohttp.ClientError as err: - _LOGGER.warning( - "Media skipped: download failed (idx=%d type=%s url=%s): %s", - idx, media_type, url, err, - ) - return idx, None, None - - results = await asyncio.gather( - *(_fetch_asset(i, item) for i, item in enumerate(chunk)) - ) - - for idx, cached_entry, data in results: - item = chunk[idx] - url = item.get("url") - if not url: - continue - media_type = item.get("type", "photo") - custom_cache_key = item.get("cache_key") - - if cached_entry and cached_entry.get("file_id"): - mij: dict[str, Any] = {"type": media_type, "media": cached_entry["file_id"]} - media_cache_info.append(None) # already cached - elif data is not None: - attach_name = f"file{upload_idx}" - ct = item.get("content_type") or ("image/jpeg" if media_type == "photo" else "video/mp4") - ext = "jpg" if media_type == "photo" else "mp4" - attachments.append((attach_name, data, f"media_{idx}.{ext}", ct)) - mij = {"type": media_type, "media": f"attach://{attach_name}"} - upload_idx += 1 - # Record cache key so we can store file_id from response - ck = custom_cache_key or extract_asset_id_from_url(url) or url - ck_is_asset = is_asset_cache_key(ck) - bare_ck = asset_id_from_cache_key(ck) if ck_is_asset else ck - th = self._thumbhash_resolver(bare_ck) if ck_is_asset and self._thumbhash_resolver else None - media_cache_info.append((ck, media_type, th, len(data))) - else: + else: + first_chunk_overall = False + continue + first_chunk_overall = False + if not result.get("success"): + result["failed_at_chunk"] = chunk_idx + 1 + return result + if result.get("message_id") is not None: + all_message_ids.append(result["message_id"]) continue - if idx == 0 and chunk_idx == 0 and caption: - mij["caption"] = caption - mij["parse_mode"] = parse_mode - media_json.append(mij) - - if not media_json: - # Every asset in this chunk was filtered out (size, download - # failure, etc.). Without this log, sendMediaGroup returns - # success=True with zero message_ids and nobody knows why - # the user sees only the text reply and no media. - _LOGGER.warning( - "sendMediaGroup skipped — chunk %d/%d had %d input items but 0 usable (all filtered/failed)", - chunk_idx + 1, len(chunks), len(chunk), + items = await self._build_media_items( + chunk, max_asset_data_size, caption if first_chunk_overall else None, + parse_mode, ) - continue - - telegram_url = f"{TELEGRAM_API_BASE_URL}{self._token}/sendMediaGroup" - - def _build_form() -> FormData: - f = FormData() - f.add_field("chat_id", chat_id) - if reply_to_message_id and chunk_idx == 0: - f.add_field("reply_parameters", json.dumps({"message_id": reply_to_message_id})) - for name, payload, filename, ct in attachments: - f.add_field(name, payload, filename=filename, content_type=ct) - f.add_field("media", json.dumps(media_json)) - return f - - chunk_failed_result: dict[str, Any] | None = None - for attempt in range(1, _TG_429_MAX_ATTEMPTS + 1): - try: - async with self._session.post(telegram_url, data=_build_form()) as response: - result = await response.json() - if response.status == 200 and result.get("ok"): - result_msgs = result.get("result", []) - all_message_ids.extend(msg.get("message_id") for msg in result_msgs) - - # Cache file_ids from response — map by position - cache_entries: list[tuple[str, str, str, str | None, int | None]] = [] - for i, msg in enumerate(result_msgs): - if i >= len(media_cache_info): - break - info = media_cache_info[i] - if info is None: - continue # was a cache hit, skip - ck, mt, th, sz = info - file_id = None - if msg.get("photo"): - file_id = msg["photo"][-1].get("file_id") - elif msg.get("video"): - file_id = msg["video"].get("file_id") - elif msg.get("document"): - file_id = msg["document"].get("file_id") - if file_id: - cache_entries.append((ck, file_id, mt, th, sz)) - if cache_entries: - # All entries in a chunk share the same cache backend - eff_cache = self._get_cache_for_key(cache_entries[0][0], is_asset_cache_key(cache_entries[0][0])) - if eff_cache: - await eff_cache.async_set_many(cache_entries) - break # chunk succeeded - - if _is_rate_limited(response.status, result) and attempt < _TG_429_MAX_ATTEMPTS: - retry_after = _extract_retry_after(result) or 1 - wait_s = min(retry_after + 1, _TG_429_MAX_WAIT_S) - _LOGGER.warning( - "Telegram sendMediaGroup 429 (retry_after=%ds, attempt %d/%d) chunk=%d/%d items=%d — sleeping %ds", - retry_after, attempt, _TG_429_MAX_ATTEMPTS, - chunk_idx + 1, len(chunks), len(media_json), wait_s, - ) - await asyncio.sleep(wait_s) - continue - - _LOGGER.error( - "Telegram sendMediaGroup failed: status=%s code=%s desc=%r chunk=%d/%d items=%d", - response.status, result.get("error_code"), - result.get("description", "Unknown"), - chunk_idx + 1, len(chunks), len(media_json), - ) - chunk_failed_result = { - "success": False, - "error": result.get("description", "Unknown"), - "error_code": result.get("error_code"), - "failed_at_chunk": chunk_idx + 1, - } - break - except aiohttp.ClientError as err: - _LOGGER.error( - "Telegram sendMediaGroup transport error on chunk %d/%d (%d items): %s", - chunk_idx + 1, len(chunks), len(media_json), err, - exc_info=True, + if not items: + _LOGGER.warning( + "sendMediaGroup skipped — chunk %d/%d had %d input items but 0 usable (all filtered/failed)", + chunk_idx + 1, len(chunks), len(chunk), ) - return {"success": False, "error": str(err), "failed_at_chunk": chunk_idx + 1} + first_chunk_overall = False + continue - if chunk_failed_result is not None: - return chunk_failed_result + chunk_msg_ids, chunk_err = await self._post_media_group( + chat_id, items, reply_to_message_id if first_chunk_overall else None, + chunk_idx, len(chunks), + ) + first_chunk_overall = False + if chunk_err is not None: + return chunk_err + all_message_ids.extend(chunk_msg_ids) - # Distinguish "posted something" from "posted nothing" so the caller - # can surface an ERROR when a command produced a caption reply but no - # media ever reached Telegram. if not all_message_ids: _LOGGER.warning( - "sendMediaGroup completed with 0 message_ids across %d chunk(s) — nothing was delivered", - len(chunks), + "sendMediaGroup completed with 0 message_ids — nothing was delivered", ) - return {"success": False, "error": "no_items_delivered", "chunks_sent": len(chunks)} - return {"success": True, "message_ids": all_message_ids, "chunks_sent": len(chunks)} + return {"success": False, "error": "no_items_delivered"} + return {"success": True, "message_ids": all_message_ids} + + @staticmethod + def _partition_media_by_kind( + assets: list[dict[str, Any]], + ) -> list[list[dict[str, Any]]]: + """Split assets into runs of compatible kinds. + + Telegram allows photos+videos in the same media group but rejects + mixing those with documents. We preserve user-provided order + within each partition so the visual sequence the operator + configured is honored. + """ + partitions: list[list[dict[str, Any]]] = [] + current: list[dict[str, Any]] = [] + current_doc = False + for asset in assets: + kind = asset.get("type") or "photo" + is_doc = kind == "document" + if current and current_doc != is_doc: + partitions.append(current) + current = [] + current.append(asset) + current_doc = is_doc + if current: + partitions.append(current) + return partitions + + async def _build_media_items( + self, + chunk: list[dict[str, Any]], + max_asset_data_size: int | None, + first_caption: str | None, + parse_mode: str, + ) -> list[_MediaItem]: + """Fetch + filter a chunk and return aligned media-group items. + + Concurrency is bounded by ``_MEDIA_FETCH_CONCURRENCY`` so peak + memory stays predictable. Per-fetch exceptions are isolated via + ``return_exceptions=True`` so a single failed download cannot + cancel its peers. + """ + sem = asyncio.Semaphore(_MEDIA_FETCH_CONCURRENCY) + + async def fetch(idx: int, item: dict[str, Any]) -> tuple[int, dict | None, bytes | None]: + url = item.get("url") + if not url: + _LOGGER.warning("Media skipped: missing url (idx=%d type=%s)", idx, item.get("type")) + return idx, None, None + media_type = item.get("type", "photo") + custom_cache_key = item.get("cache_key") + + ck = custom_cache_key or extract_asset_id_from_url(url) or url + ck_is_asset = is_asset_cache_key(ck) + item_cache = self._get_cache_for_key(ck, ck_is_asset) + bare_ck = asset_id_from_cache_key(ck) if ck_is_asset else ck + item_thumbhash = ( + self._thumbhash_resolver(bare_ck) + if ck_is_asset and self._thumbhash_resolver else None + ) + cached = item_cache.get(ck, thumbhash=item_thumbhash) if item_cache else None + if cached and cached.get("file_id"): + return idx, cached, None + + preloaded = item.get("data") + data: bytes | None + if preloaded is not None: + data = preloaded + else: + async with sem: + data, err = await self._safe_get(self._resolve_url(url), item.get("headers")) + if data is None: + _LOGGER.warning( + "Media skipped: download failed (idx=%d type=%s): %s", + idx, media_type, err, + ) + return idx, None, None + + if max_asset_data_size and len(data) > max_asset_data_size: + _LOGGER.warning( + "Media skipped: size %d exceeds max_asset_data_size %d (idx=%d type=%s)", + len(data), max_asset_data_size, idx, media_type, + ) + return idx, None, None + if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE: + _LOGGER.warning( + "Media skipped: video %d bytes exceeds Telegram limit %d (idx=%d)", + len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, + ) + return idx, None, None + if media_type == "photo": + exceeds, reason, _, _ = check_photo_limits(data) + if exceeds: + _LOGGER.warning( + "Media skipped: photo %s (idx=%d)", reason, idx, + ) + return idx, None, None + return idx, None, data + + raw = await asyncio.gather( + *(fetch(i, item) for i, item in enumerate(chunk)), + return_exceptions=True, + ) + results: list[tuple[int, dict | None, bytes | None]] = [] + for entry in raw: + if isinstance(entry, Exception): + _LOGGER.warning("Media fetch raised: %s", redact_exc(entry)) + continue + results.append(entry) + + items: list[_MediaItem] = [] + upload_idx = 0 + for idx, cached_entry, data in results: + item = chunk[idx] + url = item.get("url") + if not url: + continue + media_type = item.get("type") or "photo" + custom_cache_key = item.get("cache_key") + + if cached_entry and cached_entry.get("file_id"): + mij: dict[str, Any] = {"type": media_type, "media": cached_entry["file_id"]} + cache_info: tuple[str, str, str | None, int] | None = None + attachment: tuple[str, bytes, str, str] | None = None + elif data is not None: + attach_name = f"file{upload_idx}" + ct = item.get("content_type") or ("image/jpeg" if media_type == "photo" else "video/mp4") + ext = "jpg" if media_type == "photo" else "mp4" + attachment = (attach_name, data, f"media_{idx}.{ext}", ct) + mij = {"type": media_type, "media": f"attach://{attach_name}"} + upload_idx += 1 + ck = custom_cache_key or extract_asset_id_from_url(url) or url + ck_is_asset = is_asset_cache_key(ck) + bare_ck = asset_id_from_cache_key(ck) if ck_is_asset else ck + th = ( + self._thumbhash_resolver(bare_ck) + if ck_is_asset and self._thumbhash_resolver else None + ) + cache_info = (ck, media_type, th, len(data)) + else: + continue + + if first_caption and not items: + # Only the first usable item in the first chunk receives + # the caption, per Telegram's media-group semantics. + mij["caption"] = _truncate(first_caption, TELEGRAM_MAX_CAPTION_LENGTH) + mij["parse_mode"] = parse_mode + + items.append(_MediaItem(media_json=mij, cache_info=cache_info, attachment=attachment)) + return items + + async def _post_media_group( + self, + chat_id: str, + items: list[_MediaItem], + reply_to_message_id: int | None, + chunk_idx: int, + total_chunks: int, + ) -> tuple[list[int], dict[str, Any] | None]: + """POST the media group with bounded 429 retry. + + Returns ``(message_ids, None)`` on success or ``([], err_dict)``. + """ + media_json = [it.media_json for it in items] + attachments = [it.attachment for it in items if it.attachment] + cache_infos = [it.cache_info for it in items] + + def _build_form() -> FormData: + f = FormData() + f.add_field("chat_id", chat_id) + if reply_to_message_id is not None: + f.add_field("reply_parameters", json.dumps({"message_id": reply_to_message_id})) + for name, payload, filename, ct in attachments: + f.add_field(name, payload, filename=filename, content_type=ct) + f.add_field("media", json.dumps(media_json)) + return f + + for attempt in range(1, _TG_429_MAX_ATTEMPTS + 1): + try: + async with self._session.post( + self._api_url("sendMediaGroup"), + data=_build_form(), + timeout=_UPLOAD_TIMEOUT, + ) as response: + result = await response.json() + if response.status == 200 and result.get("ok"): + result_msgs = result.get("result", []) + msg_ids = [ + msg.get("message_id") for msg in result_msgs + if msg.get("message_id") is not None + ] + await self._cache_media_group_response(result_msgs, cache_infos) + return msg_ids, None + + if _is_rate_limited(response.status, result) and attempt < _TG_429_MAX_ATTEMPTS: + retry_after = _extract_retry_after(result) or 1 + wait_s = min(retry_after + 1, _TG_429_MAX_WAIT_S) + _LOGGER.warning( + "Telegram sendMediaGroup 429 (retry_after=%ds, attempt %d/%d) chunk=%d/%d items=%d — sleeping %ds", + retry_after, attempt, _TG_429_MAX_ATTEMPTS, + chunk_idx + 1, total_chunks, len(media_json), wait_s, + ) + await asyncio.sleep(wait_s) + continue + + _LOGGER.error( + "Telegram sendMediaGroup failed: status=%s code=%s desc=%r chunk=%d/%d items=%d", + response.status, result.get("error_code"), + redact(str(result.get("description", "Unknown"))), + chunk_idx + 1, total_chunks, len(media_json), + ) + return [], { + "success": False, + "error": redact(str(result.get("description", "Unknown"))), + "error_code": result.get("error_code"), + "failed_at_chunk": chunk_idx + 1, + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + _LOGGER.error( + "Telegram sendMediaGroup transport error on chunk %d/%d (%d items): %s", + chunk_idx + 1, total_chunks, len(media_json), redact_exc(err), + ) + return [], { + "success": False, + "error": redact_exc(err), + "failed_at_chunk": chunk_idx + 1, + } + + return [], { + "success": False, + "error": "Telegram rate limit: max retries exhausted", + "failed_at_chunk": chunk_idx + 1, + } + + async def _cache_media_group_response( + self, + result_msgs: list[dict[str, Any]], + cache_infos: list[tuple[str, str, str | None, int] | None], + ) -> None: + """Persist file_ids returned by sendMediaGroup. + + Entries align by position: ``cache_infos[i]`` corresponds to + ``result_msgs[i]``. ``None`` means the item was a cache hit and + does not need re-storing. + """ + cache_entries: list[tuple[str, str, str, str | None, int | None]] = [] + for i, msg in enumerate(result_msgs): + if i >= len(cache_infos): + break + info = cache_infos[i] + if info is None: + continue + ck, mt, th, sz = info + file_id: str | None = None + if msg.get("photo"): + file_id = msg["photo"][-1].get("file_id") + elif msg.get("video"): + file_id = msg["video"].get("file_id") + elif msg.get("document"): + file_id = msg["document"].get("file_id") + if file_id: + cache_entries.append((ck, file_id, mt, th, sz)) + if cache_entries: + eff_cache = self._get_cache_for_key( + cache_entries[0][0], is_asset_cache_key(cache_entries[0][0]), + ) + if eff_cache: + await eff_cache.async_set_many(cache_entries) # ------------------------------------------------------------------ # Bot management methods # ------------------------------------------------------------------ + async def _api_call( + self, + method: str, + *, + json_payload: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + timeout: aiohttp.ClientTimeout | None = None, + http_method: str = "POST", + ) -> dict[str, Any]: + """Generic Bot-API call with redacted errors and explicit timeout.""" + url = self._api_url(method) + try: + async with self._session.request( + http_method, url, json=json_payload, params=params, + timeout=timeout or _API_TIMEOUT, + ) as resp: + data = await resp.json() + if data.get("ok"): + return {"success": True, "result": data.get("result", {})} + return { + "success": False, + "error": redact(str(data.get("description", "Unknown error"))), + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + return {"success": False, "error": redact_exc(err)} + async def get_me(self) -> dict[str, Any]: """Call getMe to verify the bot token and get bot info.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/getMe" - try: - async with self._session.get(url) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True, "result": data.get("result", {})} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + return await self._api_call("getMe", http_method="GET") async def get_chat(self, chat_id: str) -> dict[str, Any]: - """Call getChat to fetch up-to-date chat metadata (title, username, type, etc.).""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/getChat" - try: - async with self._session.post(url, json={"chat_id": chat_id}) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True, "result": data.get("result", {})} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + """Call getChat to fetch up-to-date chat metadata.""" + return await self._api_call("getChat", json_payload={"chat_id": chat_id}) async def get_webhook_info(self) -> dict[str, Any]: """Call getWebhookInfo to check current webhook status.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/getWebhookInfo" - try: - async with self._session.get(url) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True, "result": data.get("result", {})} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + return await self._api_call("getWebhookInfo", http_method="GET") async def set_webhook(self, webhook_url: str, secret: str | None = None) -> dict[str, Any]: """Register a webhook URL with Telegram.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/setWebhook" payload: dict[str, Any] = {"url": webhook_url} if secret: payload["secret_token"] = secret - try: - async with self._session.post(url, json=payload) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + result = await self._api_call("setWebhook", json_payload=payload) + # Caller previously got plain ``{"success": True}``; preserve. + if result.get("success"): + return {"success": True} + return result async def delete_webhook(self) -> dict[str, Any]: """Remove the webhook from Telegram.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/deleteWebhook" - try: - async with self._session.post(url) as resp: - data = await resp.json() - return {"success": data.get("ok", False)} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + result = await self._api_call("deleteWebhook") + if result.get("success"): + return {"success": True} + return result async def get_updates( self, offset: int | None = None, limit: int = 50, timeout: int = 0, ) -> dict[str, Any]: """Long-poll for updates via getUpdates.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/getUpdates" + url = self._api_url("getUpdates") params: dict[str, Any] = { "timeout": timeout, "limit": limit, - "allowed_updates": '["message"]', + "allowed_updates": json.dumps(["message"]), } if offset is not None: params["offset"] = offset try: async with self._session.get( - url, params=params, timeout=aiohttp.ClientTimeout(total=max(10, timeout + 5)), + url, params=params, + timeout=aiohttp.ClientTimeout(total=max(10, timeout + 5)), ) as resp: data = await resp.json() if data.get("ok"): return {"success": True, "result": data.get("result", [])} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + return { + "success": False, + "error": redact(str(data.get("description", "Unknown error"))), + } + except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: + return {"success": False, "error": redact_exc(err)} async def set_my_commands( self, commands: list[dict[str, str]], language_code: str | None = None, scope: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Register bot commands with BotFather API. - - ``scope`` is a Telegram BotCommandScope object (e.g. - ``{"type": "chat", "chat_id": 123}``). When provided, the - registration applies only to that scope. ``language_code`` and - ``scope`` may be combined to localize per-scope. - """ - url = f"{TELEGRAM_API_BASE_URL}{self._token}/setMyCommands" + """Register bot commands with the BotFather API.""" payload: dict[str, Any] = {"commands": commands} if language_code: payload["language_code"] = language_code if scope: payload["scope"] = scope - try: - async with self._session.post(url, json=payload) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + result = await self._api_call("setMyCommands", json_payload=payload) + if result.get("success"): + return {"success": True} + return result async def delete_my_commands( self, language_code: str | None = None, scope: dict[str, Any] | None = None, ) -> dict[str, Any]: - """Clear bot commands for the given scope/language via BotFather API.""" - url = f"{TELEGRAM_API_BASE_URL}{self._token}/deleteMyCommands" + """Clear bot commands for the given scope/language.""" payload: dict[str, Any] = {} if language_code: payload["language_code"] = language_code if scope: payload["scope"] = scope - try: - async with self._session.post(url, json=payload) as resp: - data = await resp.json() - if data.get("ok"): - return {"success": True} - return {"success": False, "error": data.get("description", "Unknown error")} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + result = await self._api_call("deleteMyCommands", json_payload=payload) + if result.get("success"): + return {"success": True} + return result diff --git a/packages/core/src/notify_bridge_core/notifications/telegram/media.py b/packages/core/src/notify_bridge_core/notifications/telegram/media.py index d5c630c..e9b65fc 100644 --- a/packages/core/src/notify_bridge_core/notifications/telegram/media.py +++ b/packages/core/src/notify_bridge_core/notifications/telegram/media.py @@ -2,20 +2,35 @@ from __future__ import annotations +import logging import re from typing import Any, Final from urllib.parse import urlparse +_LOGGER = logging.getLogger(__name__) + # Telegram constants TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot" TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB TELEGRAM_MAX_DIMENSION_SUM: Final = 10000 +# Telegram message-text limit (sendMessage) and caption limit +# (sendPhoto/sendVideo/sendDocument/first item of sendMediaGroup). +TELEGRAM_MAX_TEXT_LENGTH: Final = 4096 +TELEGRAM_MAX_CAPTION_LENGTH: Final = 1024 -# Generic UUID pattern for asset IDs -_ASSET_ID_PATTERN = re.compile(r"^[a-f0-9-]{36}$") +# Strict canonical-UUID pattern (8-4-4-4-12) for asset IDs. The previous +# loose ``[a-f0-9-]{36}`` matched 36 hyphens / arbitrary digit groupings, +# which could collide across providers when used as a cache key. +_ASSET_ID_PATTERN = re.compile( + r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$", + re.IGNORECASE, +) # Cache key: "host:uuid" or bare "uuid" -_ASSET_CACHE_KEY_PATTERN = re.compile(r"^(?:[^:]+:)?[a-f0-9-]{36}$") +_ASSET_CACHE_KEY_PATTERN = re.compile( + r"^(?:[^:]+:)?[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$", + re.IGNORECASE, +) # URL patterns to extract asset IDs (generic enough for Immich-style URLs) _ASSET_ID_URL_PATTERNS = [ @@ -162,5 +177,10 @@ def check_photo_limits( return False, None, width, height except ImportError: return False, None, None, None - except Exception: + except (OSError, ValueError, MemoryError) as exc: + # PIL surfaces ``UnidentifiedImageError`` (subclass of OSError), + # truncated-image / decompression-bomb errors here. Log so a + # corrupt asset isn't silently passed to Telegram and rejected + # downstream with a less actionable error. + _LOGGER.warning("check_photo_limits: failed to inspect image (%d bytes): %s", len(data), exc) return False, None, None, None diff --git a/packages/core/src/notify_bridge_core/notifications/webhook/client.py b/packages/core/src/notify_bridge_core/notifications/webhook/client.py index 0f3cafd..864c4a4 100644 --- a/packages/core/src/notify_bridge_core/notifications/webhook/client.py +++ b/packages/core/src/notify_bridge_core/notifications/webhook/client.py @@ -7,37 +7,29 @@ from typing import Any import aiohttp -from ..ssrf import UnsafeURLError, avalidate_outbound_url +from ..http_base import HttpProviderClient _LOGGER = logging.getLogger(__name__) -_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30) +class WebhookClient(HttpProviderClient): + """Send JSON payloads to a webhook URL. -class WebhookClient: - """Send JSON payloads to a webhook URL.""" + The URL is SSRF-validated on every send (defense-in-depth: re-validating + catches DNS rebinding between calls and a misconfigured target). Headers + pass through :func:`safe_headers` so a target config can't inject + framing/hop-by-hop headers like ``Host`` or ``Transfer-Encoding``. + """ - def __init__(self, session: aiohttp.ClientSession, url: str, headers: dict[str, str] | None = None) -> None: - self._session = session + def __init__( + self, + session: aiohttp.ClientSession, + url: str, + headers: dict[str, str] | None = None, + ) -> None: + super().__init__(session, provider_name="webhook") self._url = url - self._headers = headers or {} + self._extra_headers = headers or {} async def send(self, payload: dict[str, Any]) -> dict[str, Any]: - try: - await avalidate_outbound_url(self._url) - except UnsafeURLError as err: - return {"success": False, "error": f"Unsafe URL: {err}"} - try: - async with self._session.post( - self._url, - json=payload, - headers={"Content-Type": "application/json", **self._headers}, - timeout=_DEFAULT_TIMEOUT, - allow_redirects=False, - ) as response: - if 200 <= response.status < 300: - return {"success": True, "status_code": response.status} - body = await response.text() - return {"success": False, "error": f"HTTP {response.status}", "body": body[:200]} - except aiohttp.ClientError as err: - return {"success": False, "error": str(err)} + return await self.request("POST", self._url, json=payload, headers=self._extra_headers) diff --git a/packages/server/src/notify_bridge_server/api/app_settings.py b/packages/server/src/notify_bridge_server/api/app_settings.py index 5468600..8555f50 100644 --- a/packages/server/src/notify_bridge_server/api/app_settings.py +++ b/packages/server/src/notify_bridge_server/api/app_settings.py @@ -218,6 +218,19 @@ async def get_supported_locales( return locales or ["en"] +@router.get("/external-url") +async def get_external_url( + user: User = Depends(get_current_user), + session: AsyncSession = Depends(get_session), +): + """Return the configured external base URL (available to all users). + + Used by the UI to render absolute provider webhook URLs. Returns empty + string when unset so the UI falls back to the relative path. + """ + return {"external_url": (await get_setting(session, "external_url")).rstrip("/")} + + async def _reregister_webhooks( session: AsyncSession, base_url: str, secret: str ) -> None: diff --git a/packages/server/tests/test_dispatcher_aggregation.py b/packages/server/tests/test_dispatcher_aggregation.py new file mode 100644 index 0000000..d5c20ab --- /dev/null +++ b/packages/server/tests/test_dispatcher_aggregation.py @@ -0,0 +1,46 @@ +"""Dispatcher result aggregation: per-receiver detail must survive.""" + +from __future__ import annotations + +from notify_bridge_core.notifications.dispatcher import NotificationDispatcher + + +def test_aggregate_all_success() -> None: + out = NotificationDispatcher._aggregate_results([ + {"success": True, "message_id": 1}, + {"success": True, "message_id": 2}, + ]) + assert out["success"] is True + assert out["receivers"] == 2 + assert out["successes"] == 2 + assert out["failures"] == 0 + + +def test_aggregate_partial() -> None: + out = NotificationDispatcher._aggregate_results([ + {"success": True}, + {"success": False, "error": "boom"}, + ]) + assert out["success"] is True # at least one succeeded + assert out["successes"] == 1 + assert out["failures"] == 1 + assert "boom" in out["errors"] + assert "results" in out + + +def test_aggregate_all_fail_preserves_all_errors() -> None: + out = NotificationDispatcher._aggregate_results([ + {"success": False, "error": "first"}, + {"success": False, "error": "second"}, + ]) + assert out["success"] is False + assert out["error"] == "first" # back-compat top-level field + assert out["errors"] == ["first", "second"] + # Per-receiver details survive — operator can see exactly what failed. + assert len(out["results"]) == 2 + + +def test_aggregate_empty() -> None: + out = NotificationDispatcher._aggregate_results([]) + assert out["success"] is False + assert "error" in out diff --git a/packages/server/tests/test_email_client.py b/packages/server/tests/test_email_client.py new file mode 100644 index 0000000..f27ffc8 --- /dev/null +++ b/packages/server/tests/test_email_client.py @@ -0,0 +1,77 @@ +"""Email client header-injection / address-validation regression tests.""" + +from __future__ import annotations + +import pytest + +from notify_bridge_core.notifications.email.client import ( + EmailClient, + SmtpConfig, + _strip_header, + _validate_email, + _to_html, +) + + +def test_strip_header_removes_crlf() -> None: + out = _strip_header("Subject\r\nBcc: attacker@example.com") + assert "\r" not in out + assert "\n" not in out + # The injected "Bcc:" line is folded to a single header line; the SMTP + # server will treat it as part of the subject text, not a header. + assert "Bcc:" in out # value preserved as plain text + + +def test_strip_header_removes_bare_lf() -> None: + out = _strip_header("Hello\nWorld") + assert "\n" not in out + + +def test_strip_header_handles_non_string() -> None: + assert _strip_header(None) == "" + + +def test_validate_email_rejects_crlf() -> None: + with pytest.raises(ValueError): + _validate_email("user@example.com\r\nBcc: x@y") + + +def test_validate_email_rejects_no_at() -> None: + with pytest.raises(ValueError): + _validate_email("not-an-email") + + +def test_validate_email_rejects_empty() -> None: + with pytest.raises(ValueError): + _validate_email("") + + +def test_validate_email_accepts_normal() -> None: + assert _validate_email("user@example.com") == "user@example.com" + + +def test_to_html_escapes_brackets() -> None: + out = _to_html("") + assert "