feat: harden notification stack and switch logging selectors to icon grid

Notifications:
- Add shared http_base, redact, and SSRF hardening modules
- Refactor dispatcher, queue, receiver and per-provider clients
  (telegram, discord, email, matrix, ntfy, slack, webhook) to use
  the shared base, with bounded queue and redacted error logs
- Tests for ssrf, redact, http_base, queue bounds, dispatcher
  aggregation, telegram media partition, email and matrix clients

Frontend:
- Settings: log level / log format selectors now use IconGridSelect
  with per-option icons and i18n descriptions
- Minor providers page and entity-cache store updates

Tooling:
- Document code-review-graph MCP usage in CLAUDE.md
- Ignore .code-review-graph/, register .mcp.json
This commit is contained in:
2026-05-07 13:53:26 +03:00
parent 5bd63a2191
commit 0eb899afb9
33 changed files with 2623 additions and 1033 deletions
+2
View File
@@ -56,3 +56,5 @@ frontend/.svelte-kit/
# Logs
*.log
# Added by code-review-graph
.code-review-graph/
+12
View File
@@ -0,0 +1,12 @@
{
"mcpServers": {
"code-review-graph": {
"command": "uvx",
"args": [
"code-review-graph",
"serve"
],
"type": "stdio"
}
}
}
+39
View File
@@ -43,3 +43,42 @@ Detailed context is split into focused documents under `.claude/docs/`. Read the
- Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`)
- Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`)
- Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py`
<!-- code-review-graph MCP tools -->
## MCP Tools: code-review-graph
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
you structural context (callers, dependents, test coverage) that file
scanning cannot.
### When to use graph tools FIRST
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
- **Architecture questions**: `get_architecture_overview` + `list_communities`
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
### Key Tools
| Tool | Use when |
|------|----------|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
| `get_review_context` | Need source snippets for review — token-efficient |
| `get_impact_radius` | Understanding blast radius of a change |
| `get_affected_flows` | Finding which execution paths are impacted |
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
| `get_architecture_overview` | Understanding high-level codebase structure |
| `refactor_tool` | Planning renames, finding dead code |
### Workflow
1. The graph auto-updates on file changes (via hooks).
2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` with `pattern="tests_for"` to check coverage.
+16
View File
@@ -73,6 +73,22 @@ export const localeItems = (): GridItem[] => [
{ value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') },
];
// --- Log level ---
/**
 * Choices for the log-level picker, listed from most to least verbose.
 * Descriptions are looked up through `t` each time the function is called.
 */
export const logLevelItems = (): GridItem[] => {
  const levels: GridItem[] = [
    {
      value: 'DEBUG',
      icon: 'mdiBugOutline',
      label: 'DEBUG',
      desc: t('gridDesc.logLevelDebug'),
    },
    {
      value: 'INFO',
      icon: 'mdiInformationOutline',
      label: 'INFO',
      desc: t('gridDesc.logLevelInfo'),
    },
    {
      value: 'WARNING',
      icon: 'mdiAlertOutline',
      label: 'WARNING',
      desc: t('gridDesc.logLevelWarning'),
    },
    {
      value: 'ERROR',
      icon: 'mdiAlertOctagonOutline',
      label: 'ERROR',
      desc: t('gridDesc.logLevelError'),
    },
  ];
  return levels;
};
// --- Log format ---
/**
 * Choices for the log-format picker: plain text or JSON.
 * Descriptions are looked up through `t` each time the function is called.
 */
export const logFormatItems = (): GridItem[] => {
  const formats: GridItem[] = [
    {
      value: 'text',
      icon: 'mdiFormatText',
      label: 'text',
      desc: t('gridDesc.logFormatText'),
    },
    {
      value: 'json',
      icon: 'mdiCodeJson',
      label: 'json',
      desc: t('gridDesc.logFormatJson'),
    },
  ];
  return formats;
};
// --- Response mode ---
export const responseModeItems = (tFn: typeof t): GridItem[] => [
+8 -1
View File
@@ -192,7 +192,8 @@
"apiToken": "API Token",
"apiTokenHint": "Optional. Needed for connection testing and repository listing.",
"webhookUrl": "Webhook URL",
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).",
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings. The full URL is shown when an external base URL is configured in Settings; otherwise it is relative to your bridge host.",
"webhookUrlCopyTitle": "Click to copy",
"nutHost": "NUT Server Host",
"nutHostPlaceholder": "192.168.1.100 or ups.local",
"nutPort": "NUT Server Port",
@@ -1131,6 +1132,12 @@
"memorySourceNative": "Use Immich native memories API",
"localeEn": "English interface",
"localeRu": "Russian interface",
"logLevelDebug": "Verbose — show every step",
"logLevelInfo": "Default — high-level events",
"logLevelWarning": "Warnings and errors only",
"logLevelError": "Errors only — quietest",
"logFormatText": "Human-readable plain text",
"logFormatJson": "One JSON object per line",
"modeMedia": "Send actual photo/video files",
"modeText": "Send file names and links only",
"allEvents": "Show all event types",
+8 -1
View File
@@ -192,7 +192,8 @@
"apiToken": "API токен",
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
"webhookUrl": "URL вебхука",
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).",
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea. Полный URL показывается, если в настройках задан внешний адрес; иначе путь указан относительно хоста bridge.",
"webhookUrlCopyTitle": "Нажмите, чтобы скопировать",
"nutHost": "Хост NUT-сервера",
"nutHostPlaceholder": "192.168.1.100 или ups.local",
"nutPort": "Порт NUT-сервера",
@@ -1131,6 +1132,12 @@
"memorySourceNative": "Использовать API воспоминаний Immich",
"localeEn": "Английский интерфейс",
"localeRu": "Русский интерфейс",
"logLevelDebug": "Подробный — каждый шаг",
"logLevelInfo": "По умолчанию — ключевые события",
"logLevelWarning": "Только предупреждения и ошибки",
"logLevelError": "Только ошибки — самый тихий",
"logFormatText": "Читаемый человеком текст",
"logFormatJson": "Один JSON-объект на строку",
"modeMedia": "Отправка файлов фото/видео",
"modeText": "Только имена файлов и ссылки",
"allEvents": "Показать все типы событий",
+28
View File
@@ -112,6 +112,34 @@ export const capabilitiesCache = (() => {
};
})();
/**
 * Cache for the configured external base URL, used to render absolute
 * webhook URLs. Available to all authenticated users.
 *
 * - `value` is the last fetched URL with trailing slashes removed, or an
 *   empty string when nothing is configured or nothing was fetched yet.
 * - `fetch()` reuses the cached value for five minutes and shares a single
 *   in-flight request between concurrent callers.
 * - `invalidate()` marks the cached value stale so the next `fetch()` hits
 *   the API again.
 */
export const externalUrlCache = (() => {
  const TTL_MS = 300_000;

  let baseUrl = $state<string>('');
  let lastFetchedAt = $state(0);
  let pending: Promise<string> | null = null;

  const isFresh = (): boolean =>
    lastFetchedAt > 0 && Date.now() - lastFetchedAt < TTL_MS;

  // Performs the request and records the result; always clears the
  // in-flight marker, whether the request succeeded or failed.
  const load = async (): Promise<string> => {
    try {
      const res = await api<{ external_url: string }>('/settings/external-url');
      baseUrl = (res?.external_url || '').replace(/\/+$/, '');
      lastFetchedAt = Date.now();
      return baseUrl;
    } finally {
      pending = null;
    }
  };

  return {
    get value() {
      return baseUrl;
    },
    invalidate() {
      lastFetchedAt = 0;
    },
    async fetch(force = false): Promise<string> {
      if (!force && isFresh()) {
        return baseUrl;
      }
      if (pending) {
        return pending;
      }
      pending = load();
      return pending;
    },
  };
})();
/** Supported template locales — fetched from app settings. */
export const supportedLocalesCache = (() => {
let data = $state<string[]>(['en', 'ru']);
+42 -4
View File
@@ -3,7 +3,7 @@
import { slide } from 'svelte/transition';
import { api, getBlockedBy, type BlockedByDetail } from '$lib/api';
import { t } from '$lib/i18n';
import { providersCache } from '$lib/stores/caches.svelte';
import { providersCache, externalUrlCache } from '$lib/stores/caches.svelte';
import PageHeader from '$lib/components/PageHeader.svelte';
import Card from '$lib/components/Card.svelte';
import Loading from '$lib/components/Loading.svelte';
@@ -21,7 +21,7 @@
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { topbarAction } from '$lib/stores/topbar-action.svelte';
import { onDestroy } from 'svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
@@ -45,6 +45,30 @@
let confirmDelete = $state<ServiceProvider | null>(null);
let descriptor = $derived(getDescriptor(form.type));
let externalUrl = $derived(externalUrlCache.value);
/**
 * Build the webhook URL displayed for a provider.
 *
 * Substitutes `token` for the first `{token}` placeholder in `pattern`. When
 * an external base URL is available it is prepended, giving an absolute URL;
 * otherwise the bare path is returned.
 *
 * @param pattern - URL path template containing a `{token}` placeholder.
 * @param token - Provider webhook token; a nullish value becomes an empty string.
 * @returns The absolute URL when `externalUrl` is non-empty, else the path.
 */
function buildWebhookUrl(pattern: string, token: string): string {
  // A replacer function inserts the token verbatim. A plain string replacement
  // would expand `$&`, `$1`, `$$`, ... if they happened to occur in the token.
  const path = pattern.replace('{token}', () => token ?? '');
  return externalUrl ? `${externalUrl}${path}` : path;
}
/**
 * Click handler that copies a webhook URL to the clipboard.
 *
 * Uses the async Clipboard API when present and falls back to a hidden
 * textarea plus `document.execCommand('copy')` when the API is missing or
 * its promise rejects. The "copied" snackbar is shown only when a copy
 * actually succeeded; a total failure is logged instead of being reported
 * as success.
 *
 * @param e - Triggering event; default action and propagation are suppressed.
 * @param url - Text to place on the clipboard.
 */
function copyWebhookUrl(e: Event, url: string) {
  e.preventDefault();
  e.stopPropagation();

  // Legacy copy path. Returns whether the browser reported success and
  // always removes the temporary textarea, even if select/execCommand throws.
  const legacyCopy = (): boolean => {
    const ta = document.createElement('textarea');
    ta.value = url;
    ta.style.position = 'fixed';
    ta.style.opacity = '0';
    document.body.appendChild(ta);
    try {
      ta.select();
      return document.execCommand('copy');
    } catch {
      return false;
    } finally {
      document.body.removeChild(ta);
    }
  };

  const report = (copied: boolean): void => {
    if (copied) {
      snackInfo(`${t('snack.copied')}: ${url}`);
    } else {
      console.warn('Failed to copy webhook URL to clipboard');
    }
  };

  if (navigator.clipboard?.writeText) {
    // writeText is asynchronous and rejects when permission is denied or the
    // document is not focused; handle both outcomes so no rejection is left
    // unhandled and success is only announced after it really happened.
    navigator.clipboard.writeText(url).then(
      () => report(true),
      () => report(legacyCopy()),
    );
  } else {
    report(legacyCopy());
  }
}
// Auto-update name when provider type changes (unless user manually edited)
$effect(() => {
@@ -76,6 +100,7 @@
onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); },
});
load();
externalUrlCache.fetch().catch(() => { /* fall back to relative URLs */ });
});
onDestroy(() => topbarAction.clear());
async function load() {
@@ -246,9 +271,15 @@
</div>
{/each}
{#if descriptor?.webhookUrlPattern && editing}
{@const editingWebhookUrl = buildWebhookUrl(descriptor.webhookUrlPattern, providers.find(p => p.id === editing)?.webhook_token ?? '')}
<div class="bg-[var(--color-muted)] rounded-md p-3">
<div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div>
<code class="text-xs select-all break-all">{descriptor.webhookUrlPattern.replace('{token}', providers.find(p => p.id === editing)?.webhook_token ?? '')}</code>
<button type="button"
onclick={(e) => copyWebhookUrl(e, editingWebhookUrl)}
title={t('providers.webhookUrlCopyTitle')}
class="text-xs break-all text-left hover:text-[var(--color-primary)] cursor-pointer font-mono w-full">
<code class="bg-transparent">{editingWebhookUrl}</code>
</button>
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
</div>
{/if}
@@ -295,7 +326,14 @@
<p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p>
{/if}
{#if provDesc?.webhookUrlPattern}
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">{t('providers.webhookUrl')}: <span class="select-all">{provDesc.webhookUrlPattern.replace('{token}', provider.webhook_token)}</span></p>
{@const webhookUrl = buildWebhookUrl(provDesc.webhookUrlPattern, provider.webhook_token)}
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">
{t('providers.webhookUrl')}:
<button type="button"
onclick={(e) => copyWebhookUrl(e, webhookUrl)}
title={t('providers.webhookUrlCopyTitle')}
class="hover:text-[var(--color-primary)] cursor-pointer break-all text-left">{webhookUrl}</button>
</p>
{/if}
</div>
</div>
+6 -12
View File
@@ -12,7 +12,10 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import LocaleSelector from '$lib/components/LocaleSelector.svelte';
import TimezoneSelector from '$lib/components/TimezoneSelector.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { logLevelItems, logFormatItems } from '$lib/grid-items';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { externalUrlCache } from '$lib/stores/caches.svelte';
interface CacheBucketStats {
count: number;
@@ -76,6 +79,7 @@
saving = true; error = '';
try {
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
externalUrlCache.invalidate();
snackSuccess(t('settings.saved'));
} catch (err: any) { error = err.message; snackError(err.message); }
saving = false;
@@ -221,21 +225,11 @@
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
<div>
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
<select bind:value={settings.log_level}
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
<option value="DEBUG">DEBUG</option>
<option value="INFO">INFO</option>
<option value="WARNING">WARNING</option>
<option value="ERROR">ERROR</option>
</select>
<IconGridSelect items={logLevelItems()} bind:value={settings.log_level} columns={2} />
</div>
<div>
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
<select bind:value={settings.log_format}
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
<option value="text">text</option>
<option value="json">json</option>
</select>
<IconGridSelect items={logFormatItems()} bind:value={settings.log_format} columns={2} />
</div>
<div class="sm:col-span-2">
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
@@ -4,21 +4,24 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any
from typing import Any, Final
import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__)
# Discord webhook content limit
MAX_CONTENT_LENGTH = 2000
# Discord API constraints (per webhook docs).
MAX_CONTENT_LENGTH: Final = 2000
MAX_USERNAME_LENGTH: Final = 80
class DiscordClient:
class DiscordClient(HttpProviderClient):
"""Sends messages via Discord webhook URLs."""
def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session
super().__init__(session, provider_name="discord")
async def send(
self,
@@ -33,6 +36,8 @@ class DiscordClient:
"""
if not webhook_url:
return {"success": False, "error": "Missing webhook_url"}
if username and len(username) > MAX_USERNAME_LENGTH:
return {"success": False, "error": f"username exceeds {MAX_USERNAME_LENGTH} chars"}
chunks = _split_message(message, MAX_CONTENT_LENGTH)
for chunk in chunks:
@@ -42,71 +47,34 @@ class DiscordClient:
if avatar_url:
payload["avatar_url"] = avatar_url
result = await self._post(webhook_url, payload)
if not result["success"]:
result = await self.request("POST", webhook_url, json=payload)
if not result.get("success"):
return result
# Small delay between chunks to respect rate limits
if len(chunks) > 1:
await asyncio.sleep(0.5)
return {"success": True}
_MAX_RETRIES = 3
_MAX_RETRY_AFTER = 60.0
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
"""POST with bounded 429 retry.
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
pin the dispatch task indefinitely.
"""
for attempt in range(self._MAX_RETRIES + 1):
try:
async with self._session.post(
url,
json=payload,
headers={"Content-Type": "application/json"},
allow_redirects=False,
) as resp:
if resp.status == 429 and attempt < self._MAX_RETRIES:
try:
retry_after = float(resp.headers.get("Retry-After", "2"))
except (TypeError, ValueError):
retry_after = 2.0
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
_LOGGER.warning(
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
retry_after, attempt + 1, self._MAX_RETRIES,
)
await asyncio.sleep(retry_after)
continue
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {
"success": False,
"error": f"HTTP {resp.status}: {body[:200]}",
}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
return {"success": False, "error": "Rate limited (retries exhausted)"}
def _split_message(text: str, limit: int) -> list[str]:
"""Split message into chunks respecting the character limit."""
"""Split message into chunks respecting the character limit.
Drops chunks that contain only whitespace — Discord rejects those.
"""
if len(text) <= limit:
return [text]
chunks = []
chunks: list[str] = []
while text:
if len(text) <= limit:
chunks.append(text)
break
# Try to split at newline
split_at = text.rfind("\n", 0, limit)
if split_at <= 0:
split_at = limit
chunks.append(text[:split_at])
text = text[split_at:].lstrip("\n")
return chunks
piece = text
text = ""
else:
split_at = text.rfind("\n", 0, limit)
if split_at <= 0:
split_at = limit
piece = text[:split_at]
text = text[split_at:].lstrip("\n")
if piece.strip():
chunks.append(piece)
return chunks or [text]
@@ -7,7 +7,7 @@ import contextlib
import logging
import uuid
from dataclasses import dataclass, field
from typing import Any, AsyncIterator
from typing import Any, AsyncIterator, Awaitable, Callable, Final
import aiohttp
@@ -15,37 +15,20 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.templates.context import build_template_context
from notify_bridge_core.templates.renderer import render_template
from .ssrf import UnsafeURLError, avalidate_outbound_url
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
# Cap on how many asset downloads run concurrently inside
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
# max_asset_size``, which matters on small-RAM Docker hosts when a batch
# contains many large videos.
_PRELOAD_CONCURRENCY = 6
def _new_session() -> aiohttp.ClientSession:
"""Per-dispatch aiohttp session with a sane default timeout.
We still open a short-lived session per dispatch (connection reuse across
dispatches lives in the server-side shared session), but we always attach
a total timeout so a hung peer cannot wedge the task forever.
"""
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
from .http_base import safe_headers
from .receiver import (
DiscordReceiver,
EmailReceiver,
MatrixReceiver,
NtfyReceiver,
Receiver,
SlackReceiver,
TelegramReceiver,
WebhookReceiver,
EmailReceiver,
DiscordReceiver,
SlackReceiver,
NtfyReceiver,
MatrixReceiver,
)
from .redact import redact_exc
from .ssrf import UnsafeURLError, avalidate_outbound_url
from .telegram.cache import TelegramFileCache
from .telegram.client import TelegramClient
from .telegram.media import (
@@ -58,7 +41,33 @@ from .webhook.client import WebhookClient
_LOGGER = logging.getLogger(__name__)
DEFAULT_TEMPLATE = '{{ event_type }}: "{{ collection_name }}"'
DEFAULT_TEMPLATE: Final = '{{ event_type }}: "{{ collection_name }}"'
_HTTP_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
# Cap on how many asset downloads run concurrently inside
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
# max_asset_size``.
_PRELOAD_CONCURRENCY: Final = 6
# Cap on how many targets the dispatcher fans out to at once. With dozens
# of targets and a single hung peer, unbounded ``gather`` can pin the
# dispatch task. The cap also protects against credential-reuse rate
# limits on shared providers.
_DISPATCH_CONCURRENCY: Final = 16
# Cap on parallel per-receiver sends within a single target.
_RECEIVER_CONCURRENCY: Final = 8
# Per-target soft timeout — at the top of the dispatch tree so a single
# misbehaving target can't hold the whole batch open. Individual provider
# clients carry their own per-request timeouts on top of this.
_TARGET_TIMEOUT_S: Final = 120.0
def _new_session() -> aiohttp.ClientSession:
    """Create a fresh aiohttp session that uses the module-wide ``_HTTP_TIMEOUT``."""
    session = aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
    return session
@dataclass
@@ -66,17 +75,23 @@ class TargetConfig:
"""Configuration for a notification target."""
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
config: dict[str, Any] # target-level config (bot_token, settings, etc.)
template_slots: dict[str, dict[str, str]] | None = None # event_type -> {locale -> template}
locale: str = "en" # default locale for template resolution
config: dict[str, Any]
template_slots: dict[str, dict[str, str]] | None = None
locale: str = "en"
date_format: str = "%d.%m.%Y, %H:%M UTC"
date_only_format: str = "%d.%m.%Y"
provider_api_key: str | None = None # API key for downloading assets from provider
provider_internal_url: str | None = None # Internal provider URL for API key scoping
provider_external_url: str | None = None # External domain for API key scoping
provider_api_key: str | None = None
provider_internal_url: str | None = None
provider_external_url: str | None = None
receivers: list[Receiver] = field(default_factory=list)
_SendMethod = Callable[
["NotificationDispatcher", TargetConfig, str, ServiceEvent],
Awaitable[dict[str, Any]],
]
class NotificationDispatcher:
"""Dispatches ServiceEvent notifications to configured targets."""
@@ -90,18 +105,11 @@ class NotificationDispatcher:
self._url_cache = url_cache
self._asset_cache = asset_cache
# Optional shared session owned by the caller; when supplied we reuse
# its connection pool instead of opening a fresh per-dispatch session
# (saves a TLS handshake per outbound call).
# its connection pool instead of opening a fresh per-dispatch session.
self._shared_session = session
@contextlib.asynccontextmanager
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
"""Yield an aiohttp session, reusing the shared one if provided.
When a shared session was passed in ``__init__`` we yield it without
closing (the caller owns its lifetime). Otherwise we open a
short-lived session with our default timeout and close it on exit.
"""
if self._shared_session is not None and not self._shared_session.closed:
yield self._shared_session
return
@@ -115,11 +123,9 @@ class NotificationDispatcher:
) -> list[dict[str, Any]]:
"""Send event notification to all targets.
Returns list of results (one per target).
Returns one result per target. Per-target failures are isolated;
a single bad target cannot poison the batch.
"""
# Bind a dispatch_id so every log line emitted by the target sends
# (including deep in TelegramClient) can be correlated to the same
# upstream event.
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
with bind_log_context(dispatch_id=new_id):
@@ -128,20 +134,36 @@ class NotificationDispatcher:
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
getattr(event, "collection_name", None), len(targets),
)
sem = asyncio.Semaphore(_DISPATCH_CONCURRENCY)
async def run_one(t: TargetConfig) -> dict[str, Any]:
async with sem:
try:
return await asyncio.wait_for(
self._send_to_target(event, t),
timeout=_TARGET_TIMEOUT_S,
)
except asyncio.TimeoutError:
return {
"success": False,
"error": f"Target dispatch timed out after {_TARGET_TIMEOUT_S}s",
}
raw_results = await asyncio.gather(
*[self._send_to_target(event, t) for t in targets],
*[run_one(t) for t in targets],
return_exceptions=True,
)
results = []
results: list[dict[str, Any]] = []
failures = 0
for target, raw in zip(targets, raw_results):
if isinstance(raw, Exception):
failures += 1
_LOGGER.error(
"Dispatch to target type=%s failed: %s",
target.type, raw, exc_info=raw,
target.type, redact_exc(raw),
)
results.append({"success": False, "error": str(raw)})
results.append({"success": False, "error": redact_exc(raw)})
else:
if isinstance(raw, dict) and not raw.get("success"):
failures += 1
@@ -155,7 +177,6 @@ class NotificationDispatcher:
def _resolve_template(
self, event: ServiceEvent, target: TargetConfig, locale: str,
) -> str:
"""Resolve template string for an event, with locale fallback."""
template_str = DEFAULT_TEMPLATE
if target.template_slots:
locale_map = target.template_slots.get(event.event_type.value)
@@ -166,7 +187,6 @@ class NotificationDispatcher:
def _render_message(
self, event: ServiceEvent, target: TargetConfig, locale: str,
) -> str:
"""Resolve template and render message for a given locale."""
template_str = self._resolve_template(event, target, locale)
ctx = build_template_context(
event, target_type=target.type,
@@ -179,7 +199,6 @@ class NotificationDispatcher:
self, receiver: Receiver, default_message: str,
event: ServiceEvent, target: TargetConfig,
) -> str:
"""Return per-receiver message, re-rendering if receiver has a different locale."""
if receiver.locale and receiver.locale != target.locale:
return self._render_message(event, target, receiver.locale)
return default_message
@@ -187,21 +206,16 @@ class NotificationDispatcher:
async def _send_to_target(
self, event: ServiceEvent, target: TargetConfig
) -> dict[str, Any]:
"""Send event to a single target (potentially multiple receivers)."""
"""Dispatch to a single target via the registered handler."""
default_message = self._render_message(event, target, target.locale)
send_method = _PROVIDER_HANDLERS.get(target.type)
if send_method is None:
return {"success": False, "error": f"Unknown target type: {target.type}"}
return await send_method(self, target, default_message, event)
send_method = {
"telegram": self._send_telegram,
"webhook": self._send_webhook,
"email": self._send_email,
"discord": self._send_discord,
"slack": self._send_slack,
"ntfy": self._send_ntfy,
"matrix": self._send_matrix,
}.get(target.type)
if send_method:
return await send_method(target, default_message, event)
return {"success": False, "error": f"Unknown target type: {target.type}"}
# ------------------------------------------------------------------
# Asset preload (Telegram-specific)
# ------------------------------------------------------------------
async def _preload_asset_data(
self,
@@ -210,36 +224,13 @@ class NotificationDispatcher:
session: aiohttp.ClientSession,
max_size: int | None,
) -> None:
"""Download each non-cached asset's bytes once and attach to the entry.
Three benefits:
* ``TelegramClient`` sees ``entry["data"]`` and skips its own download,
so we don't fetch each URL twice.
* We know the exact upload size, which lets the oversize warning in
the rendered text compare against real bytes (for Immich videos,
the transcoded ``/video/playback``), not the original ``file_size``.
* Assets already in the Telegram file_id cache are skipped, and their
stored size (if any) is used to populate ``playback_size`` — so
templates see consistent sizes for repeat sends without re-download.
Entries whose download fails or exceeds ``max_size`` are left without
``data``; ``TelegramClient`` will then fall back to its own download
path and apply the same checks — no regression, just no preload win.
Concurrency is bounded by ``_PRELOAD_CONCURRENCY`` so peak memory
stays predictable: at most N assets worth of bytes held in RAM at
once, regardless of ``max_media_to_send``. Total wall-clock is
unchanged for small batches and only marginally slower for large
ones (most assets fit in a single RTT and SSL negotiation cost
dominates, so 6-way parallelism is sufficient).
"""
"""Download each non-cached asset's bytes once, with SSRF guard."""
if not assets:
return
sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY)
async def _fetch(entry: dict[str, Any], media: Any) -> None:
# Cache hit → skip download; populate playback_size from stored size.
async def fetch(entry: dict[str, Any], media: Any) -> None:
cache, key = self._cache_for_entry(entry)
if cache and key:
cached = cache.get(key)
@@ -251,28 +242,40 @@ class NotificationDispatcher:
url = entry["url"]
headers = entry.get("headers") or {}
try:
# Defense-in-depth: validate even though TelegramClient
# also validates. The dispatcher is what triggers the
# download, so the guard belongs here too.
await avalidate_outbound_url(url)
except UnsafeURLError as err:
_LOGGER.warning(
"Asset preload skipped: unsafe URL (%s)", redact_exc(err),
)
return
async with sem:
try:
async with session.get(url, headers=headers) as resp:
if resp.status != 200:
return
data = await resp.read()
except aiohttp.ClientError:
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
return
if max_size is not None and len(data) > max_size:
return
entry["data"] = data
media.extra["playback_size"] = len(data)
await asyncio.gather(*(_fetch(e, m) for e, m in zip(assets, media_assets)))
raw = await asyncio.gather(
*(fetch(e, m) for e, m in zip(assets, media_assets)),
return_exceptions=True,
)
for r in raw:
if isinstance(r, Exception):
_LOGGER.warning("Asset preload raised: %s", redact_exc(r))
def _cache_for_entry(
self, entry: dict[str, Any],
) -> tuple[TelegramFileCache | None, str | None]:
"""Resolve (cache, key) for an asset entry — mirrors TelegramClient logic.
Returns (None, None) if no cache is configured or no key can be derived.
"""
cache_key = entry.get("cache_key")
if cache_key:
cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache
@@ -287,6 +290,10 @@ class NotificationDispatcher:
return self._url_cache, url
return None, None
# ------------------------------------------------------------------
# Per-provider handlers
# ------------------------------------------------------------------
async def _send_telegram(
self, target: TargetConfig, default_message: str, event: ServiceEvent
) -> dict[str, Any]:
@@ -296,27 +303,25 @@ class NotificationDispatcher:
max_media = target.config.get("max_media_to_send", 50)
max_group = target.config.get("max_media_per_group", 10)
chunk_delay = target.config.get("media_delay", 500)
max_size = target.config.get("max_asset_size")
if max_size:
max_size = max_size * 1024 * 1024 # MB to bytes
max_size_mb = target.config.get("max_asset_size")
max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None
send_large_as_docs = target.config.get("send_large_photos_as_documents", False)
if not bot_token:
return {"success": False, "error": "Missing bot_token"}
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
# Prepare assets list once (shared across receivers)
# Prefer internal URL for fetching (LAN speed vs public internet)
internal_url = (target.provider_internal_url or "").rstrip("/")
external_url = (target.provider_external_url or "").rstrip("/")
assets = []
media_assets: list[Any] = [] # aligned with `assets` for preload
assets: list[dict[str, Any]] = []
media_assets: list[Any] = []
for asset in event.added_assets[:max_media]:
url = asset.preview_url or asset.thumbnail_url or asset.full_url
if not url:
continue
asset_entry = build_telegram_asset_entry(
url=url or "",
url=url,
media_type=asset.type.value,
api_key=target.provider_api_key,
internal_url=internal_url,
@@ -327,26 +332,15 @@ class NotificationDispatcher:
assets.append(asset_entry)
media_assets.append(asset)
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
# Preload all asset bytes once so (a) TelegramClient can skip its
# own download and (b) we know exact upload sizes in time for the
# oversize warning in the rendered text.
await self._preload_asset_data(assets, media_assets, session, max_size)
default_message = self._render_message(event, target, target.locale)
await self._preload_asset_data(assets, media_assets, session, max_size_bytes)
# Asset cache (when in thumbhash mode) invalidates entries when the
# asset's visual content changes. The resolver maps asset id → its
# current thumbhash. Providers that expose thumbhash put it in
# ``asset.extra["thumbhash"]`` (currently Immich).
thumbhash_map = {
asset.id: asset.extra.get("thumbhash")
for asset in event.added_assets
if asset.extra.get("thumbhash")
}
thumbhash_resolver = (
(lambda key: thumbhash_map.get(key)) if thumbhash_map else None
)
thumbhash_resolver = thumbhash_map.get if thumbhash_map else None
client = TelegramClient(
session, bot_token,
@@ -355,39 +349,51 @@ class NotificationDispatcher:
thumbhash_resolver=thumbhash_resolver,
)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
results.append({"success": False, "error": "Invalid telegram receiver"})
continue
return {"success": False, "error": "Invalid telegram receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
text_result = await client.send_message(
chat_id=receiver.chat_id,
text=message,
disable_web_page_preview=bool(disable_preview),
)
if not text_result.get("success"):
_LOGGER.warning("Failed to send to chat %s: %s", receiver.chat_id, text_result.get("error"))
results.append(text_result)
continue
_LOGGER.warning(
"Failed to send to chat %s: %s",
receiver.chat_id, text_result.get("error"),
)
return text_result
if assets:
reply_to = text_result.get("message_id")
media_result = await client.send_notification(
chat_id=receiver.chat_id,
assets=assets,
reply_to_message_id=reply_to,
reply_to_message_id=text_result.get("message_id"),
max_group_size=max_group,
chunk_delay=chunk_delay,
max_asset_data_size=max_size,
max_asset_data_size=max_size_bytes,
send_large_photos_as_documents=send_large_as_docs,
chat_action=chat_action or None,
)
if not media_result.get("success"):
_LOGGER.warning("Text sent OK but media failed for chat %s: %s", receiver.chat_id, media_result.get("error"))
_LOGGER.warning(
"Text sent OK but media failed for chat %s: %s",
receiver.chat_id, media_result.get("error"),
)
# Preserve both outcomes — text succeeded, media
# didn't. Operators losing media-failure detail
# in the result dict made root-cause analysis
# impossible.
return {
"success": True,
"message_id": text_result.get("message_id"),
"media_error": media_result.get("error"),
"media_failed_at_chunk": media_result.get("failed_at_chunk"),
}
return text_result
results.append(text_result)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
@@ -397,17 +403,10 @@ class NotificationDispatcher:
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
results.append({"success": False, "error": "Invalid webhook receiver"})
continue
try:
await avalidate_outbound_url(receiver.url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
return {"success": False, "error": "Invalid webhook receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
payload = {
"message": message,
@@ -417,8 +416,10 @@ class NotificationDispatcher:
"collection_id": event.collection_id,
"timestamp": event.timestamp.isoformat(),
}
client = WebhookClient(session, receiver.url, receiver.headers)
results.append(await client.send(payload))
client = WebhookClient(session, receiver.url, safe_headers(receiver.headers))
return await client.send(payload)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
@@ -431,7 +432,7 @@ class NotificationDispatcher:
if not smtp_cfg.get("host"):
return {"success": False, "error": "SMTP not configured"}
client = EmailClient(SmtpConfig(
email_client = EmailClient(SmtpConfig(
host=smtp_cfg["host"],
port=int(smtp_cfg.get("port", 587)),
username=smtp_cfg.get("username", ""),
@@ -439,27 +440,28 @@ class NotificationDispatcher:
from_address=smtp_cfg.get("from_address", ""),
from_name=smtp_cfg.get("from_name", "Notify Bridge"),
use_tls=smtp_cfg.get("use_tls", True),
tls_mode=smtp_cfg.get("tls_mode", "auto"),
))
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = []
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, EmailReceiver) or not receiver.email:
results.append({"success": False, "error": "Invalid email receiver"})
continue
return {"success": False, "error": "Invalid email receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
result = await client.send(
# body_html=None lets EmailClient build a safely-escaped HTML
# alternative from body_text instead of trusting user content.
return await email_client.send(
to_email=receiver.email,
subject=subject,
body_text=message,
body_html=message,
body_html=None,
to_name=receiver.name,
)
results.append(result)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
async def _send_discord(
@@ -471,20 +473,16 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"}
username = target.config.get("username")
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
client = DiscordClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid discord receiver"})
continue
try:
await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
return {"success": False, "error": "Invalid discord receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username))
return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
@@ -497,20 +495,16 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"}
username = target.config.get("username")
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
client = SlackClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid slack receiver"})
continue
try:
await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
return {"success": False, "error": "Invalid slack receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username))
return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
@@ -526,22 +520,23 @@ class NotificationDispatcher:
try:
await avalidate_outbound_url(server_url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
return {"success": False, "error": f"Unsafe ntfy server_url: {redact_exc(err)}"}
title = f"{event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
client = NtfyClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
results.append({"success": False, "error": "Invalid ntfy receiver"})
continue
return {"success": False, "error": "Invalid ntfy receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(
return await client.send(
server_url, receiver.topic, message,
title=title, priority=receiver.priority, auth_token=auth_token,
))
)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
@@ -557,33 +552,108 @@ class NotificationDispatcher:
try:
await avalidate_outbound_url(homeserver)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
return {"success": False, "error": f"Unsafe matrix homeserver_url: {redact_exc(err)}"}
if not target.receivers:
return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with self._session_ctx() as session:
client = MatrixClient(session, homeserver, access_token)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
results.append({"success": False, "error": "Invalid matrix receiver"})
continue
return {"success": False, "error": "Invalid matrix receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send_message(
receiver.room_id, message, html_message=message,
))
# html_message=None sends only the plain-text ``body``; MatrixClient
# adds ``formatted_body`` only when an HTML body is supplied.
# If templates emit HTML in the future, generate a separate HTML
# body upstream rather than passing the plain text as HTML here.
return await client.send_message(
receiver.room_id, message, html_message=None,
)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results)
# ------------------------------------------------------------------
# Aggregation
# ------------------------------------------------------------------
@staticmethod
async def _fan_out(
    receivers: list[Receiver],
    send_one: Callable[[Receiver], Awaitable[dict[str, Any]]],
) -> list[dict[str, Any]]:
    """Deliver to every receiver concurrently, at most
    ``_RECEIVER_CONCURRENCY`` sends in flight at a time.

    Args:
        receivers: Receivers to deliver to.
        send_one: Coroutine function performing the send for one receiver
            and returning its result dict.

    Returns:
        One result dict per receiver, in the same order as ``receivers``.
        An exception raised by ``send_one`` is logged and replaced by a
        ``{"success": False, "error": ...}`` dict, so one failing receiver
        never aborts the others.
    """
    limiter = asyncio.Semaphore(_RECEIVER_CONCURRENCY)

    async def _deliver(target_receiver: Receiver) -> dict[str, Any]:
        # The semaphore slot is held for the whole send, which is what
        # bounds the number of simultaneous deliveries.
        async with limiter:
            try:
                outcome = await send_one(target_receiver)
            except Exception as err:  # noqa: BLE001
                _LOGGER.error("Receiver send raised: %s", redact_exc(err))
                return {"success": False, "error": redact_exc(err)}
            return outcome

    pending = [_deliver(receiver) for receiver in receivers]
    return await asyncio.gather(*pending)
@staticmethod
def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
"""Aggregate broadcast results into a single result dict."""
"""Aggregate per-receiver results into a single target-level result.
Preserves the per-receiver detail under ``receivers`` so a caller
can see exactly which receivers failed, instead of getting only
the first error.
"""
if not results:
return {"success": False, "error": "No receivers configured"}
successes = sum(1 for r in results if r.get("success"))
if successes == len(results) and results:
return {"success": True, "receivers": len(results)}
elif successes > 0:
return {"success": True, "receivers": len(results), "partial_failures": len(results) - successes}
elif results:
return results[0]
return {"success": False, "error": "No receivers configured"}
failures = len(results) - successes
out: dict[str, Any] = {
"success": successes > 0,
"receivers": len(results),
"successes": successes,
"failures": failures,
"results": results,
}
if failures:
out["errors"] = [
r.get("error") for r in results if not r.get("success")
]
if successes == 0:
# Surface the first error at the top level for back-compat
# with callers that only check ``error``.
out["error"] = results[0].get("error", "All receivers failed")
return out
# ----------------------------------------------------------------------
# Provider registry — replaces the if/elif chain so adding a provider
# means just registering it here, not editing dispatch logic.
# ----------------------------------------------------------------------
# Maps a provider name to the unbound ``NotificationDispatcher`` method
# that sends to that provider. ``register_provider`` adds or replaces
# entries in this dict at runtime.
_PROVIDER_HANDLERS: dict[str, _SendMethod] = {
    "telegram": NotificationDispatcher._send_telegram,
    "webhook": NotificationDispatcher._send_webhook,
    "email": NotificationDispatcher._send_email,
    "discord": NotificationDispatcher._send_discord,
    "slack": NotificationDispatcher._send_slack,
    "ntfy": NotificationDispatcher._send_ntfy,
    "matrix": NotificationDispatcher._send_matrix,
}
def register_provider(name: str, handler: _SendMethod) -> None:
    """Register a new dispatcher provider at runtime.

    Allows out-of-tree providers to extend the dispatcher without
    forking. The handler must follow the
    ``async (dispatcher, target, default_message, event) -> dict`` shape.

    Args:
        name: Provider name used as the key in ``_PROVIDER_HANDLERS``.
        handler: Send callable stored under ``name``.

    Note:
        An existing entry with the same ``name`` (including a built-in
        one) is silently replaced; no validation of ``handler`` is done.
    """
    # Plain dict assignment: last registration for a given name wins.
    _PROVIDER_HANDLERS[name] = handler
@@ -2,14 +2,32 @@
from __future__ import annotations
import html
import logging
import re
import ssl
from dataclasses import dataclass
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any
from email.headerregistry import Address
from email.message import EmailMessage
from typing import Any, Final, Literal
try: # Optional dependency — fail at first send rather than at import.
import aiosmtplib
from aiosmtplib import SMTPException
except ImportError: # pragma: no cover
aiosmtplib = None # type: ignore[assignment]
class SMTPException(Exception): # type: ignore[no-redef]
pass
_LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT_S: Final = 30.0
_TlsMode = Literal["auto", "implicit", "starttls", "none"]
# RFC 5322 lite: catches the obvious-bad addresses ("foo bar", "no-at",
# embedded CRLF) without pretending to fully validate addresses.
_EMAIL_RE: Final = re.compile(r"^[^\s@\r\n,;<>]+@[^\s@\r\n,;<>]+\.[^\s@\r\n,;<>]+$")
@dataclass
class SmtpConfig:
@@ -22,6 +40,55 @@ class SmtpConfig:
from_address: str = ""
from_name: str = "Notify Bridge"
use_tls: bool = True
# Explicit TLS mode. ``auto`` (back-compat) infers from ``use_tls`` and
# ``port``: use_tls=True → implicit TLS on any port; use_tls=False →
# STARTTLS, except on port 25 where no TLS is used.
tls_mode: _TlsMode = "auto"
timeout_s: float = _DEFAULT_TIMEOUT_S
def _strip_header(value: str) -> str:
"""Reject CRLF and bare CR/LF in header-bound strings.
SMTP header injection turns user-controlled subject/name strings into
arbitrary headers (``\\r\\nBcc: attacker@x``). The Python stdlib
accepts CRLF when followed by SP/HT (header folding), so explicit
sanitization is required even though :class:`EmailMessage` does some
validation of its own.
"""
return re.sub(r"[\r\n]+", " ", str(value or "")).strip()
def _validate_email(addr: str) -> str:
    """Return ``addr`` flattened by ``_strip_header`` once it passes a
    basic shape check.

    Raises:
        ValueError: If nothing is left after flattening, or if the result
            does not match ``_EMAIL_RE``.
    """
    cleaned = _strip_header(addr)
    if not cleaned:
        raise ValueError("email address is empty")
    if _EMAIL_RE.match(cleaned) is None:
        raise ValueError("email address is invalid")
    return cleaned
def _resolve_tls(cfg: SmtpConfig) -> tuple[bool, bool]:
"""Resolve ``(use_tls, start_tls)`` flags from the config.
``tls_mode`` overrides ``use_tls``/port heuristics when provided.
"""
mode = cfg.tls_mode
if mode == "implicit":
return True, False
if mode == "starttls":
return False, True
if mode == "none":
return False, False
# auto — preserve the historical "use_tls bool + port heuristic" behavior
# but make the path explicit.
if cfg.use_tls:
return True, False
return False, cfg.port != 25
def _to_html(text: str) -> str:
"""Convert plain text to a minimal HTML body, escaped for safety."""
return "<html><body><pre>" + html.escape(text) + "</pre></body></html>"
class EmailClient:
@@ -30,30 +97,39 @@ class EmailClient:
def __init__(self, smtp_config: SmtpConfig) -> None:
self._config = smtp_config
@staticmethod
def _ssl_context() -> ssl.SSLContext:
    """Return a new TLS context built by ``ssl.create_default_context()``.

    A fresh context is created on every call.
    """
    # Explicit context so the TLS posture is auditable; aiosmtplib
    # defaults look correct today but past regressions (and downstream
    # repackaging) make implicit reliance fragile.
    return ssl.create_default_context()
async def verify_connection(self) -> dict[str, Any]:
"""Test SMTP connection and authentication without sending an email."""
try:
import aiosmtplib
except ImportError:
if aiosmtplib is None:
return {"success": False, "error": "aiosmtplib not installed"}
cfg = self._config
if not cfg.host:
return {"success": False, "error": "SMTP host not configured"}
use_tls, start_tls = _resolve_tls(cfg)
try:
smtp = aiosmtplib.SMTP(
hostname=cfg.host,
port=cfg.port,
use_tls=cfg.use_tls,
start_tls=not cfg.use_tls and cfg.port != 25,
use_tls=use_tls,
start_tls=start_tls,
tls_context=self._ssl_context(),
timeout=cfg.timeout_s,
validate_certs=True,
)
await smtp.connect()
if cfg.username and cfg.password:
await smtp.login(cfg.username, cfg.password)
await smtp.quit()
return {"success": True}
except Exception as e:
except (SMTPException, OSError) as e:
_LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e)
return {"success": False, "error": str(e)}
@@ -65,27 +141,52 @@ class EmailClient:
body_html: str | None = None,
to_name: str = "",
) -> dict[str, Any]:
"""Send an email. Returns {"success": True} or {"success": False, "error": "..."}."""
try:
import aiosmtplib
except ImportError:
"""Send an email.
Returns ``{"success": True}`` or ``{"success": False, "error": "..."}``.
``body_html`` is treated as already-safe markup. Pass ``None`` to
derive a safe HTML alternative from ``body_text`` automatically.
"""
if aiosmtplib is None:
return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"}
cfg = self._config
if not cfg.host or not cfg.from_address:
return {"success": False, "error": "SMTP not configured (missing host or from_address)"}
# Build email message
msg = MIMEMultipart("alternative")
msg["From"] = f"{cfg.from_name} <{cfg.from_address}>" if cfg.from_name else cfg.from_address
msg["To"] = f"{to_name} <{to_email}>" if to_name else to_email
msg["Subject"] = subject
try:
to_addr = _validate_email(to_email)
from_addr = _validate_email(cfg.from_address)
except ValueError as exc:
return {"success": False, "error": f"Invalid email address: {exc}"}
msg.attach(MIMEText(body_text, "plain", "utf-8"))
if body_html:
msg.attach(MIMEText(body_html, "html", "utf-8"))
# EmailMessage with structured Address objects rejects CRLF and
# framework-folds long headers safely. We still strip first because
# EmailMessage's display-name slot is a pure string.
msg = EmailMessage()
from_display = _strip_header(cfg.from_name) or ""
to_display = _strip_header(to_name) or ""
try:
from_user, _, from_domain = from_addr.partition("@")
to_user, _, to_domain = to_addr.partition("@")
msg["From"] = Address(from_display, from_user, from_domain) if from_display else from_addr
msg["To"] = Address(to_display, to_user, to_domain) if to_display else to_addr
except ValueError as exc:
return {"success": False, "error": f"Invalid email address: {exc}"}
msg["Subject"] = _strip_header(subject)
msg.set_content(body_text or "", subtype="plain", charset="utf-8")
# If the caller provided HTML explicitly, honor it; otherwise build a
# safe escaped version so a stray "<" in the rendered template can't
# break the markup.
msg.add_alternative(
body_html if body_html is not None else _to_html(body_text or ""),
subtype="html",
charset="utf-8",
)
use_tls, start_tls = _resolve_tls(cfg)
try:
await aiosmtplib.send(
msg,
@@ -93,11 +194,14 @@ class EmailClient:
port=cfg.port,
username=cfg.username or None,
password=cfg.password or None,
use_tls=cfg.use_tls,
start_tls=not cfg.use_tls and cfg.port != 25,
use_tls=use_tls,
start_tls=start_tls,
tls_context=self._ssl_context(),
timeout=cfg.timeout_s,
validate_certs=True,
)
_LOGGER.info("Email sent to %s", to_email)
_LOGGER.info("Email sent to %s", to_addr)
return {"success": True}
except Exception as e:
_LOGGER.error("Failed to send email to %s: %s", to_email, e)
except (SMTPException, OSError) as e:
_LOGGER.error("Failed to send email to %s: %s", to_addr, e)
return {"success": False, "error": str(e)}
@@ -0,0 +1,196 @@
"""Shared HTTP infrastructure for notification provider clients.
Slack/Discord/ntfy/Matrix/Webhook all follow the same pattern: build a
JSON payload, POST/PUT it, decode 200-range as success, decode 4xx/5xx
into a stable error dict, and retry transient 429/503 responses with a
capped ``Retry-After``. ``HttpProviderClient`` centralizes that pattern
so every provider gets the same SSRF guard, timeouts, secret-redacted
errors, and bounded retry policy by construction — adding a new
provider doesn't get to forget any one of them.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Final, Mapping
import aiohttp
from .redact import redact, redact_exc
from .ssrf import UnsafeURLError, avalidate_outbound_url
_LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
_MAX_RETRIES: Final = 3
_MAX_RETRY_AFTER_S: Final = 60.0
_RETRY_STATUSES: Final = frozenset({429, 503})
# Header names (lower-case) that user-supplied target config is never
# allowed to set. These are hop-by-hop / framing headers; letting a caller
# override them enables request smuggling, host-header bypasses of WAFs,
# and cache poisoning.
_FORBIDDEN_HEADERS: Final = frozenset({
    "host",
    "content-length",
    "transfer-encoding",
    "connection",
    "keep-alive",
    "te",
    "upgrade",
    "expect",
    "proxy-authorization",
    "proxy-connection",
})


def safe_headers(headers: Mapping[str, str] | None) -> dict[str, str]:
    """Filter user-supplied headers down to the ones safe to send.

    An entry is silently dropped (never raised on) when:

    * its name is empty after trimming surrounding whitespace,
    * its name, compared case-insensitively, is in ``_FORBIDDEN_HEADERS``,
    * its name contains CR, LF or ``:``,
    * its value contains CR or LF (e.g. ``"bar\\r\\nHost: evil"``).

    Names are returned trimmed and values are coerced with ``str()``.
    ``None`` or an empty mapping yields ``{}``; the input is not modified.
    """
    cleaned: dict[str, str] = {}
    for key, val in (headers or {}).items():
        header_name = str(key).strip()
        name_rejected = (
            not header_name
            or header_name.lower() in _FORBIDDEN_HEADERS
            or any(ch in header_name for ch in ("\r", "\n", ":"))
        )
        if name_rejected:
            continue
        header_value = str(val)
        if any(ch in header_value for ch in ("\r", "\n")):
            continue
        cleaned[header_name] = header_value
    return cleaned
def make_error(message: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]:
    """Build the failure dict every provider client returns.

    Args:
        message: Error description; stored under ``error`` after ``redact``.
        status: HTTP status, added as ``status_code`` when not ``None``.
        body: Response body; when non-empty it is passed through ``redact``
            and its first 200 characters are stored under ``body``.

    Returns:
        ``{"success": False, "error": ...}`` plus the optional keys above.
    """
    failure: dict[str, Any] = dict(success=False, error=redact(message))
    if status is not None:
        failure.update(status_code=status)
    if body:
        # Redact the full body first, then truncate.
        failure.update(body=redact(body)[:200])
    return failure
def make_success(**extra: Any) -> dict[str, Any]:
    """Build the success dict every provider client returns.

    The result starts as ``{"success": True}`` and is extended with the
    keyword arguments; a keyword named ``success`` replaces the default.
    """
    return {"success": True, **extra}
def _retry_after_seconds(headers: Mapping[str, str], cap_s: float) -> float:
raw = headers.get("Retry-After") or headers.get("retry-after") or "2"
try:
seconds = float(raw)
except (TypeError, ValueError):
seconds = 2.0
return max(0.0, min(seconds, cap_s))
class HttpProviderClient:
    """Base for JSON-over-HTTP notification providers.

    Subclasses call :meth:`request` instead of using ``self._session``
    directly. ``request`` runs the SSRF guard (skippable for known-safe
    upstreams via ``ssrf_validate=False``), enforces a per-request
    timeout, retries 429/503 with a capped ``Retry-After``, retries
    timeouts with a short backoff, and turns transport/HTTP errors into
    the canonical ``{"success": False, ...}`` shape with secrets redacted.
    """

    # Maximum number of attempts per ``request`` call, first try included.
    _max_retries: int = _MAX_RETRIES
    # Settable per-instance so tests / hostile-upstream tuning can
    # tighten the cap. Reads of this attribute fall through to the
    # class default when no instance value has been set.
    _MAX_RETRY_AFTER: float = _MAX_RETRY_AFTER_S

    def __init__(
        self,
        session: aiohttp.ClientSession,
        *,
        timeout: aiohttp.ClientTimeout | None = None,
        provider_name: str = "http",
    ) -> None:
        """Store the shared session and per-client settings.

        Args:
            session: Session used for every request.
            timeout: Per-request timeout; ``_DEFAULT_TIMEOUT`` when omitted.
            provider_name: Label used in retry log lines.
        """
        self._session = session
        self._timeout = timeout or _DEFAULT_TIMEOUT
        self._provider = provider_name

    async def request(
        self,
        method: str,
        url: str,
        *,
        json: Any = None,
        headers: Mapping[str, str] | None = None,
        ssrf_validate: bool = True,
        retry_statuses: frozenset[int] = _RETRY_STATUSES,
    ) -> dict[str, Any]:
        """Send a single request with retry + redaction. Always returns a dict.

        Args:
            method: HTTP method.
            url: Target URL; checked by the SSRF guard unless
                ``ssrf_validate`` is false.
            json: Payload serialized as the JSON request body.
            headers: Extra headers; filtered through ``safe_headers`` and
                merged over the default ``Content-Type: application/json``.
            ssrf_validate: Set to ``False`` to skip the SSRF guard.
            retry_statuses: Status codes that trigger a ``Retry-After``
                retry while attempts remain.

        Returns:
            On 2xx: ``{"success": True, "status_code": int, "json": ...}``
            where ``json`` holds the decoded JSON body, or the raw text
            when the body is not valid JSON. On any other status, or on a
            transport error, the canonical error dict from ``make_error``.
            A retriable status on the final attempt is reported as a
            normal ``HTTP <status>`` error.
        """
        if ssrf_validate:
            try:
                await avalidate_outbound_url(url)
            except UnsafeURLError as err:
                return make_error(f"Unsafe URL: {redact_exc(err)}")

        outbound_headers: dict[str, str] = {"Content-Type": "application/json"}
        outbound_headers.update(safe_headers(headers))

        for attempt in range(1, self._max_retries + 1):
            try:
                async with self._session.request(
                    method,
                    url,
                    json=json,
                    headers=outbound_headers,
                    timeout=self._timeout,
                    allow_redirects=False,
                ) as resp:
                    if resp.status in retry_statuses and attempt < self._max_retries:
                        delay = _retry_after_seconds(resp.headers, self._MAX_RETRY_AFTER)
                        _LOGGER.warning(
                            "%s %s %s: HTTP %d, retrying after %.2fs (attempt %d/%d)",
                            self._provider, method, redact(url), resp.status,
                            delay, attempt, self._max_retries,
                        )
                        await resp.read()  # drain body so connection can return to pool
                        await asyncio.sleep(delay)
                        continue
                    if 200 <= resp.status < 300:
                        try:
                            payload: Any = await resp.json(content_type=None)
                        except (aiohttp.ContentTypeError, ValueError):
                            # errors="replace": a body that is not valid in
                            # its declared charset must not raise
                            # UnicodeDecodeError out of this method.
                            payload = await resp.text(errors="replace")
                        return make_success(status_code=resp.status, json=payload)
                    # Same reasoning: decoding an error body must never
                    # turn an HTTP failure into an unhandled exception.
                    body = await resp.text(errors="replace")
                    return make_error(
                        f"HTTP {resp.status}",
                        status=resp.status,
                        body=body,
                    )
            except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
                # asyncio.CancelledError inherits from BaseException on
                # 3.8+, so it is not caught here — good: cancellation
                # must propagate.
                if attempt < self._max_retries and isinstance(err, asyncio.TimeoutError):
                    _LOGGER.warning(
                        "%s %s %s: timeout, retrying (attempt %d/%d)",
                        self._provider, method, redact(url),
                        attempt, self._max_retries,
                    )
                    await asyncio.sleep(min(2 ** (attempt - 1), 5))
                    continue
                return make_error(redact_exc(err))
        # Only reached when the loop body never returned, i.e. when
        # ``_max_retries`` is below 1 and no attempt was made.
        return make_error("Rate limited (retries exhausted)")
@@ -2,22 +2,36 @@
from __future__ import annotations
import asyncio
import logging
import time
from typing import Any
import re
import uuid
from typing import Any, Final
from urllib.parse import quote
import aiohttp
from ..http_base import _MAX_RETRY_AFTER_S, safe_headers
from ..redact import redact, redact_exc
_LOGGER = logging.getLogger(__name__)
# Monotonically increasing transaction counter for idempotent sends
_txn_counter = int(time.time() * 1000)
# Matrix room IDs are ``!opaque:server.name`` per the spec. We also allow
# the ``#alias:server`` form because some callers may pass aliases. The
# pattern's purpose is to reject obvious path-injection (``/``, ``..``,
# control chars, query/fragment chars) before the value reaches a URL.
_ROOM_ID_RE: Final = re.compile(r"^[!#][^\x00-\x1f\s/?#]{1,255}:[A-Za-z0-9.\-:]{1,255}$")
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
_MAX_RETRIES: Final = 3
def _next_txn_id() -> str:
global _txn_counter
_txn_counter += 1
return str(_txn_counter)
def _validate_room_id(room_id: str) -> str:
    """Return ``room_id`` unchanged if it matches ``_ROOM_ID_RE`` in full.

    ``fullmatch`` is used rather than ``match``: the pattern ends in ``$``,
    which also matches just before a trailing newline, so ``match`` would
    accept ``"!room:server\\n"`` even though the pattern is meant to keep
    control characters out.

    Raises:
        ValueError: If ``room_id`` is empty or does not match the pattern.
    """
    if not room_id:
        raise ValueError("room_id is empty")
    if _ROOM_ID_RE.fullmatch(room_id) is None:
        raise ValueError("room_id format is invalid")
    return room_id
class MatrixClient:
@@ -33,49 +47,67 @@ class MatrixClient:
self._homeserver = homeserver_url.rstrip("/")
self._token = access_token
@staticmethod
def _txn_id() -> str:
    """Return a new random transaction ID: a UUID4 as 32 hex characters."""
    # uuid4 hex is collision-resistant across processes/restarts;
    # eliminates the previous module-level counter race.
    return uuid.uuid4().hex
async def send_message(
self,
room_id: str,
message: str,
html_message: str | None = None,
) -> dict[str, Any]:
"""Send a text message to a Matrix room.
"""Send a text message to a Matrix room."""
try:
room_id = _validate_room_id(room_id)
except ValueError as exc:
return {"success": False, "error": f"Invalid room_id: {exc}"}
Args:
room_id: Internal room ID (e.g. !abc:matrix.org)
message: Plain text body
html_message: Optional HTML-formatted body
"""
if not room_id:
return {"success": False, "error": "Missing room_id"}
encoded_room = quote(room_id, safe="")
url = (
f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}"
f"/send/m.room.message/{self._txn_id()}"
)
txn_id = _next_txn_id()
# URL-encode the room_id (! and : need encoding)
encoded_room = room_id.replace("!", "%21").replace(":", "%3A")
url = f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}"
body: dict[str, Any] = {
"msgtype": "m.text",
"body": message,
}
body: dict[str, Any] = {"msgtype": "m.text", "body": message}
if html_message:
body["format"] = "org.matrix.custom.html"
body["formatted_body"] = html_message
headers = {
headers = safe_headers({
"Authorization": f"Bearer {self._token}",
"Content-Type": "application/json",
}
})
try:
async with self._session.put(
url, json=body, headers=headers, allow_redirects=False,
) as resp:
if 200 <= resp.status < 300:
return {"success": True}
resp_body = await resp.text()
if resp.status == 429:
_LOGGER.warning("Matrix rate limited: %s", resp_body[:200])
return {"success": False, "error": f"HTTP {resp.status}: {resp_body[:200]}"}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
for attempt in range(1, _MAX_RETRIES + 1):
try:
async with self._session.put(
url, json=body, headers=headers,
timeout=_DEFAULT_TIMEOUT, allow_redirects=False,
) as resp:
if 200 <= resp.status < 300:
return {"success": True}
resp_body = await resp.text()
if resp.status == 429 and attempt < _MAX_RETRIES:
try:
wait_s = float(resp.headers.get("Retry-After", "2"))
except (TypeError, ValueError):
wait_s = 2.0
wait_s = max(0.0, min(wait_s, _MAX_RETRY_AFTER_S))
_LOGGER.warning(
"Matrix rate limited, retrying after %.2fs (attempt %d/%d)",
wait_s, attempt, _MAX_RETRIES,
)
await asyncio.sleep(wait_s)
continue
return {
"success": False,
"error": f"HTTP {resp.status}: {redact(resp_body)[:200]}",
"status_code": resp.status,
}
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
return {"success": False, "error": redact_exc(e)}
return {"success": False, "error": "Rate limited (retries exhausted)"}
@@ -3,18 +3,33 @@
from __future__ import annotations
import logging
from typing import Any
from typing import Any, Final
import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__)
_PRIORITY_MIN: Final = 1
_PRIORITY_MAX: Final = 5
_DEFAULT_PRIORITY: Final = 3
_MAX_TAGS: Final = 10
_MAX_TAG_LEN: Final = 64
class NtfyClient:
def _strip_crlf(value: str) -> str:
"""Remove CR/LF — ntfy's JSON path is safe today, but the same fields
are used by the header API; defensive sanitization here means a future
refactor can't accidentally re-introduce header injection."""
return value.replace("\r", " ").replace("\n", " ")
class NtfyClient(HttpProviderClient):
"""Sends push notifications via ntfy server."""
def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session
super().__init__(session, provider_name="ntfy")
async def send(
self,
@@ -22,41 +37,48 @@ class NtfyClient:
topic: str,
message: str,
title: str | None = None,
priority: int = 3,
priority: int = _DEFAULT_PRIORITY,
tags: list[str] | None = None,
click_url: str | None = None,
auth_token: str | None = None,
markdown: bool = True,
) -> dict[str, Any]:
"""Send a push notification to an ntfy topic."""
if not server_url or not topic:
return {"success": False, "error": "Missing server_url or topic"}
url = f"{server_url.rstrip('/')}"
topic = _strip_crlf(topic).strip()
if not topic:
return {"success": False, "error": "Topic is empty after sanitization"}
try:
priority_int = int(priority) if priority is not None else _DEFAULT_PRIORITY
except (TypeError, ValueError):
priority_int = _DEFAULT_PRIORITY
priority_int = max(_PRIORITY_MIN, min(priority_int, _PRIORITY_MAX))
payload: dict[str, Any] = {
"topic": topic,
"message": message,
"markdown": True,
"markdown": bool(markdown),
}
if title:
payload["title"] = title
if priority != 3:
payload["priority"] = priority
payload["title"] = _strip_crlf(title)
if priority_int != _DEFAULT_PRIORITY:
payload["priority"] = priority_int
if tags:
payload["tags"] = tags
cleaned = [
_strip_crlf(str(t))[:_MAX_TAG_LEN]
for t in tags[:_MAX_TAGS]
if t
]
if cleaned:
payload["tags"] = cleaned
if click_url:
payload["click"] = click_url
payload["click"] = _strip_crlf(click_url)
headers: dict[str, str] = {"Content-Type": "application/json"}
headers: dict[str, str] = {}
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
try:
async with self._session.post(
url, json=payload, headers=headers, allow_redirects=False,
) as resp:
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
return await self.request("POST", server_url.rstrip("/"), json=payload, headers=headers)
@@ -2,47 +2,88 @@
from __future__ import annotations
import asyncio
import copy
import logging
from datetime import datetime, timezone
from typing import Any
from typing import Any, Final
from notify_bridge_core.storage import StorageBackend
_LOGGER = logging.getLogger(__name__)
# Bound on queue length. Without a cap, a misconfigured quiet-hour
# window plus high event throughput grows the persisted file unboundedly
# and every enqueue rewrites the whole file (O(n²) total writes). When
# the cap is hit we drop the oldest entry (FIFO) so the most recent
# events still reach the recipient when the window opens.
DEFAULT_MAX_QUEUE_SIZE: Final = 1000
class NotificationQueue:
"""Persistent queue for notifications deferred during quiet hours."""
def __init__(self, backend: StorageBackend) -> None:
def __init__(
self,
backend: StorageBackend,
*,
max_size: int = DEFAULT_MAX_QUEUE_SIZE,
) -> None:
self._backend = backend
self._data: dict[str, Any] | None = None
self._max_size = max_size
# Coordinates load / enqueue / clear / remove so a write-while-load
# race can't leave the in-memory copy out of sync with disk and so
# bulk operations don't interleave their reads-then-writes.
self._lock = asyncio.Lock()
@staticmethod
def _ensure_schema(data: Any) -> dict[str, Any]:
if not isinstance(data, dict) or not isinstance(data.get("queue"), list):
return {"queue": []}
return data
async def async_load(self) -> None:
self._data = await self._backend.load() or {"queue": []}
async with self._lock:
raw = await self._backend.load()
self._data = self._ensure_schema(raw)
async def async_enqueue(self, notification_params: dict[str, Any]) -> None:
if self._data is None:
self._data = {"queue": []}
self._data["queue"].append({
"params": notification_params,
"queued_at": datetime.now(timezone.utc).isoformat(),
})
await self._backend.save(self._data)
async with self._lock:
if self._data is None:
self._data = {"queue": []}
queue: list[dict[str, Any]] = self._data["queue"]
queue.append({
"params": notification_params,
"queued_at": datetime.now(timezone.utc).isoformat(),
})
if self._max_size > 0 and len(queue) > self._max_size:
# Drop oldest (FIFO) so a new event can still land.
drop = len(queue) - self._max_size
_LOGGER.warning(
"NotificationQueue: dropping %d oldest entries (cap=%d)",
drop, self._max_size,
)
del queue[:drop]
await self._backend.save(self._data)
def get_all(self) -> list[dict[str, Any]]:
if not self._data:
return []
return list(self._data.get("queue", []))
# Deep copy so callers can iterate / mutate without corrupting the
# in-memory queue. The cost is bounded by ``max_size``.
return copy.deepcopy(list(self._data.get("queue", [])))
def has_pending(self) -> bool:
return bool(self._data and self._data.get("queue"))
async def async_clear(self) -> None:
if self._data:
self._data["queue"] = []
await self._backend.save(self._data)
async with self._lock:
if self._data:
self._data["queue"] = []
await self._backend.save(self._data)
async def async_remove(self) -> None:
await self._backend.remove()
self._data = None
async with self._lock:
await self._backend.remove()
self._data = None
@@ -3,7 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Callable
@dataclass
@@ -70,51 +70,64 @@ class MatrixReceiver(Receiver):
room_id: str = ""
_ReceiverFactory = Callable[[str, dict[str, Any]], Receiver]
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
# Registry mapping a target-type string to a factory that builds the typed
# receiver for it. Each factory takes ``(locale, config)`` and forwards both
# unchanged, plus the type-specific fields read from ``config``.
# String fields are passed through ``str()`` with an empty-string default.
_RECEIVER_FACTORIES: dict[str, _ReceiverFactory] = {
"telegram": lambda locale, config: TelegramReceiver(
locale=locale, config=config, chat_id=str(config.get("chat_id", "")),
),
"webhook": lambda locale, config: WebhookReceiver(
locale=locale, config=config,
url=str(config.get("url", "")),
# Copy into a fresh dict; a missing or falsy ``headers`` value becomes {}.
headers=dict(config.get("headers", {}) or {}),
),
"email": lambda locale, config: EmailReceiver(
locale=locale, config=config,
email=str(config.get("email", "")),
name=str(config.get("name", "")),
),
"discord": lambda locale, config: DiscordReceiver(
locale=locale, config=config,
webhook_url=str(config.get("webhook_url", "")),
),
"slack": lambda locale, config: SlackReceiver(
locale=locale, config=config,
webhook_url=str(config.get("webhook_url", "")),
),
"ntfy": lambda locale, config: NtfyReceiver(
locale=locale, config=config,
topic=str(config.get("topic", "")),
# Missing or non-numeric priority falls back to 3.
priority=_coerce_int(config.get("priority"), 3),
),
"matrix": lambda locale, config: MatrixReceiver(
locale=locale, config=config,
room_id=str(config.get("room_id", "")),
),
}
def register_receiver_factory(target_type: str, factory: _ReceiverFactory) -> None:
    """Add (or replace) the factory used for ``target_type``.

    Intended for out-of-tree target types. Registering a name that already
    exists overwrites the previous factory.
    """
    _RECEIVER_FACTORIES.update({target_type: factory})
def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver:
"""Factory: build typed Receiver from target type and config dict."""
if target_type == "telegram":
return TelegramReceiver(
locale=locale,
config=config,
chat_id=str(config.get("chat_id", "")),
)
if target_type == "webhook":
return WebhookReceiver(
locale=locale,
config=config,
url=config.get("url", ""),
headers=config.get("headers", {}),
)
if target_type == "email":
return EmailReceiver(
locale=locale,
config=config,
email=config.get("email", ""),
name=config.get("name", ""),
)
if target_type == "discord":
return DiscordReceiver(
locale=locale,
config=config,
webhook_url=config.get("webhook_url", ""),
)
if target_type == "slack":
return SlackReceiver(
locale=locale,
config=config,
webhook_url=config.get("webhook_url", ""),
)
if target_type == "ntfy":
return NtfyReceiver(
locale=locale,
config=config,
topic=config.get("topic", ""),
priority=config.get("priority", 3),
)
if target_type == "matrix":
return MatrixReceiver(
locale=locale,
config=config,
room_id=config.get("room_id", ""),
)
return Receiver(locale=locale, config=config)
"""Factory: build typed Receiver from target type and config dict.
Falls back to a base ``Receiver`` for unknown target types so callers
that handle types defensively still receive a usable object — but the
dispatcher rejects them with ``"Unknown target type"`` so a typo can't
silently route to nowhere.
"""
factory = _RECEIVER_FACTORIES.get(target_type)
if factory is None:
return Receiver(locale=locale, config=config)
return factory(locale, config)
@@ -0,0 +1,64 @@
"""Secret-redaction helpers for log lines and error strings.
Notification clients embed secrets in URLs (Telegram bot tokens) and
Authorization headers (Matrix access tokens, ntfy bearer tokens). When
those secrets surface in ``aiohttp.ClientError.__str__``, response
bodies, or operator-visible error fields, they leak into logs and into
the per-target result dict that callers may forward upstream. ``redact``
returns a defanged copy safe for both contexts.
"""
from __future__ import annotations
import re
from typing import Final
# api.telegram.org/bot<digits>:<token>/<method>
# api.telegram.org/file/bot<digits>:<token>/<file_path>  (file downloads)
_TELEGRAM_BOT_TOKEN_RE: Final = re.compile(
    r"(api\.telegram\.org/(?:file/)?bot)\d+:[A-Za-z0-9_-]+", re.IGNORECASE,
)

# Authorization: Bearer <token>  (header form, case-insensitive)
_BEARER_RE: Final = re.compile(r"(Bearer\s+)[A-Za-z0-9._\-+/=]+", re.IGNORECASE)

# Discord webhook: /api/webhooks/<id>/<token>
_DISCORD_WEBHOOK_RE: Final = re.compile(
    r"(discord(?:app)?\.com/api/webhooks/\d+/)[A-Za-z0-9_-]+",
    re.IGNORECASE,
)

# Slack webhook path: /services/T.../B.../<token>
_SLACK_WEBHOOK_RE: Final = re.compile(
    r"(hooks\.slack\.com/services/[A-Z0-9]+/[A-Z0-9]+/)[A-Za-z0-9]+",
    re.IGNORECASE,
)

# URL userinfo: scheme://user:password@host
_URL_USERINFO_RE: Final = re.compile(
    r"([a-z][a-z0-9+\-.]*://)[^/@\s]+:[^/@\s]+@",
    re.IGNORECASE,
)

# Common token query parameters
_QUERY_TOKEN_RE: Final = re.compile(
    r"([?&](?:token|access_token|api_key|key|secret|password)=)[^&\s]+",
    re.IGNORECASE,
)


def redact(text: str) -> str:
    """Return ``text`` with known secret patterns replaced by ``***``.

    Covers Telegram bot tokens (both the method URL and the
    ``/file/bot<token>/`` download URL), Discord and Slack webhook tokens,
    ``Bearer`` credentials, ``user:password@`` URL userinfo, and common
    token-bearing query parameters.

    Idempotent and safe to call on already-redacted strings. Always
    returns a ``str``; non-strings are coerced via ``str()`` so callers
    can pass exception instances directly.
    """
    if not isinstance(text, str):
        text = str(text)
    # Each substitution keeps its captured prefix (group 1) so the redacted
    # string still shows which service or parameter was involved.
    text = _TELEGRAM_BOT_TOKEN_RE.sub(r"\1***", text)
    text = _DISCORD_WEBHOOK_RE.sub(r"\1***", text)
    text = _SLACK_WEBHOOK_RE.sub(r"\1***", text)
    text = _BEARER_RE.sub(r"\1***", text)
    text = _URL_USERINFO_RE.sub(r"\1***@", text)
    text = _QUERY_TOKEN_RE.sub(r"\1***", text)
    return text
def redact_exc(err: BaseException) -> str:
    """Stringify ``err`` and return the redacted text.

    Shorthand for filling ``error`` fields from a caught exception.
    """
    message = str(err)
    return redact(message)
@@ -7,14 +7,16 @@ from typing import Any
import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__)
class SlackClient:
class SlackClient(HttpProviderClient):
"""Sends messages via Slack incoming webhook URLs."""
def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session
super().__init__(session, provider_name="slack")
async def send(
self,
@@ -33,19 +35,4 @@ class SlackClient:
if icon_emoji:
payload["icon_emoji"] = icon_emoji
try:
async with self._session.post(
webhook_url,
json=payload,
headers={"Content-Type": "application/json"},
allow_redirects=False,
) as resp:
if resp.status == 429:
_LOGGER.warning("Slack rate limited")
return {"success": False, "error": "Rate limited by Slack"}
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
return await self.request("POST", webhook_url, json=payload)
@@ -1,10 +1,22 @@
"""Outbound URL validation to mitigate SSRF attacks.
User-controlled URLs (provider `url`, webhook target `url`, shared-link
base URLs, image downloads) must be validated before any HTTP request is
issued. This module rejects schemes other than http/https and blocks
destinations that resolve to private, loopback, link-local, or unspecified
address ranges.
User-controlled URLs (provider ``url``, webhook target ``url``,
shared-link base URLs, image downloads) must be validated before any
HTTP request is issued. This module rejects schemes other than
http/https and blocks destinations that resolve to private, loopback,
link-local, unspecified, CGNAT (100.64.0.0/10), or IPv4-mapped IPv6
ranges.
DNS rebinding mitigation
~~~~~~~~~~~~~~~~~~~~~~~~
``avalidate_outbound_url`` returns only the original URL on success; the
IP it validated is discarded. Callers that pass the validated URL
straight into ``aiohttp`` are vulnerable to a DNS-rebinding attack: the
validator's ``getaddrinfo`` returns a public IP; aiohttp's connect-time
resolution returns ``127.0.0.1``. To close that gap, call
:func:`avalidate_outbound_url_full`, which also returns the resolved IP,
and use :func:`build_ssrf_safe_session` (or :class:`PinnedResolver`) so
the resolved IP from the validation step is the one aiohttp connects to.
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
development against localhost services.
@@ -17,12 +29,20 @@ import ipaddress
import logging
import os
import socket
from dataclasses import dataclass
from urllib.parse import urlparse
import aiohttp
_LOGGER = logging.getLogger(__name__)
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
_ALLOWED_SCHEMES = {"http", "https"}
_ALLOWED_SCHEMES = frozenset({"http", "https"})
# Carrier-grade NAT range. Not in stdlib's ``is_private``; an attacker
# pointing a domain at a CGNAT IP could reach the operator's ISP-side
# routing infrastructure. RFC 6598.
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
_LOGGER.warning(
@@ -36,7 +56,29 @@ class UnsafeURLError(ValueError):
"""Raised when a URL targets a disallowed network destination."""
@dataclass(frozen=True)
class ValidatedURL:
    """Result of validating an outbound URL.

    Immutable (``frozen=True``).

    Attributes:
        url: The original URL string (unchanged).
        host: Hostname extracted from the URL (lower-cased, IDN-encoded).
        ip: Resolved IP address, as a string. In normal operation this is
            the address that passed the block-range check; pass it to
            :class:`PinnedResolver` to defeat DNS rebinding by reusing this
            exact IP at connect time. When private URLs are allowed
            (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``) the address is NOT
            range-checked, and it falls back to the bare host string if
            resolution fails.
    """

    url: str
    host: str
    ip: str
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
# An IPv4-mapped IPv6 like ``::ffff:127.0.0.1`` is NOT considered
# ``is_private`` etc. by stdlib — the v4 view holds those flags. So
# we unwrap before checking.
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
ip = ip.ipv4_mapped
return (
ip.is_private
or ip.is_loopback
@@ -44,22 +86,54 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
or (isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK)
)
def _safe_host_repr(host: str) -> str:
"""Return ``host`` shortened/escaped for safe inclusion in error text."""
h = host[:64].replace("\r", "").replace("\n", "")
return h
def _normalize_host(parsed_host: str) -> str:
"""Normalize a hostname: lowercase, strip trailing dot, IDN-encode."""
host = parsed_host.lower()
if host.endswith("."):
host = host[:-1]
# Strip IPv6 zone id ("fe80::1%eth0") — must not reach the resolver.
if "%" in host:
host = host.split("%", 1)[0]
# IDN-encode unicode hostnames so we don't downgrade to confusables
# in any later log/output and so getaddrinfo gets the ascii form.
try:
if any(ord(c) > 127 for c in host):
host = host.encode("idna").decode("ascii")
except UnicodeError:
# Caller will fail on resolution; leave as-is so the error path
# surfaces a "DNS resolution failed" rather than a stack trace.
pass
return host
def _check_scheme_host(url: str) -> tuple[str, str]:
if not isinstance(url, str) or not url:
raise UnsafeURLError("URL is empty")
parsed = urlparse(url)
if parsed.scheme not in _ALLOWED_SCHEMES:
raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed")
scheme = parsed.scheme.lower()
if scheme not in _ALLOWED_SCHEMES:
raise UnsafeURLError(f"Scheme '{scheme[:16]}' not allowed")
host = parsed.hostname
if not host:
raise UnsafeURLError("URL has no host")
return parsed.scheme, host
return scheme, _normalize_host(host)
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
def _select_addresses(
host: str, infos: list[tuple],
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
"""Return parsed, non-blocked IPs from ``getaddrinfo`` results."""
addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
for info in infos:
sockaddr = info[4]
try:
@@ -67,64 +141,143 @@ def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
except ValueError:
continue
if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
raise UnsafeURLError(
f"Host {_safe_host_repr(host)} resolves to blocked address {ip}"
)
addrs.append(ip)
if not addrs:
raise UnsafeURLError(f"Host {_safe_host_repr(host)} has no usable address")
return addrs
def validate_outbound_url(url: str) -> str:
"""Validate ``url`` is safe to fetch; returns the URL on success.
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
private addresses are permitted but the scheme check still applies.
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
:func:`avalidate_outbound_url` from async code paths.
.. deprecated::
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
:func:`avalidate_outbound_url` from async code paths so the
event loop isn't blocked, and use :func:`build_ssrf_safe_session`
to defeat DNS rebinding.
"""
_, host = _check_scheme_host(url)
if _ALLOW_PRIVATE:
return url
# Literal IP host
try:
ip = ipaddress.ip_address(host)
if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} is in a blocked range")
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
return url
except ValueError:
pass
try:
infos = socket.getaddrinfo(host, None)
except socket.gaierror as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
_check_resolved_addresses(host, infos)
except (socket.gaierror, UnicodeError, OSError) as exc:
# ``UnicodeError`` covers IDNA failures (labels >63 chars, malformed
# unicode) which getaddrinfo surfaces as encoding errors rather than
# gaierror. ``OSError`` covers transient resolver failures on some
# platforms.
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
_select_addresses(host, infos)
return url
async def avalidate_outbound_url(url: str) -> str:
"""Async variant that resolves DNS via the running loop's resolver.
"""Async variant — returns the URL on success.
Use this from ``async def`` code paths to avoid blocking the event
loop on DNS lookups.
For DNS-rebinding-safe usage, prefer :func:`avalidate_outbound_url_full`
which also returns the resolved IP for connect-time pinning.
"""
result = await avalidate_outbound_url_full(url)
return result.url
async def avalidate_outbound_url_full(url: str) -> ValidatedURL:
"""Validate ``url`` and return a :class:`ValidatedURL` on success.
The returned ``ip`` field is the IP that passed the block-range
check. Pair with :class:`PinnedResolver` so aiohttp connects to that
exact IP and a malicious DNS server can't swap in a private address
after validation.
"""
_, host = _check_scheme_host(url)
if _ALLOW_PRIVATE:
return url
# In dev mode we still resolve to give a usable IP, but we don't
# gate on the result.
try:
ip = str(ipaddress.ip_address(host))
except ValueError:
try:
loop = asyncio.get_running_loop()
infos = await loop.getaddrinfo(host, None)
ip = infos[0][4][0] if infos else host
except (socket.gaierror, OSError):
ip = host
return ValidatedURL(url=url, host=host, ip=ip)
# Literal IP host
try:
ip = ipaddress.ip_address(host)
if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} is in a blocked range")
return url
ip_obj = ipaddress.ip_address(host)
if _is_blocked_ip(ip_obj):
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
return ValidatedURL(url=url, host=host, ip=str(ip_obj))
except ValueError:
pass
loop = asyncio.get_running_loop()
try:
infos = await loop.getaddrinfo(host, None)
except socket.gaierror as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
_check_resolved_addresses(host, infos)
return url
except (socket.gaierror, UnicodeError, OSError) as exc:
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
addrs = _select_addresses(host, infos)
return ValidatedURL(url=url, host=host, ip=str(addrs[0]))
class PinnedResolver(aiohttp.abc.AbstractResolver):
    """aiohttp resolver that returns a fixed (host, ip) mapping.

    Used to pin the resolved IP from :func:`avalidate_outbound_url_full`
    so aiohttp's connect-time resolution can't be tricked by DNS
    rebinding into using a different IP than the one we validated.

    Hosts are matched case-insensitively. Any host that is not pinned, or
    whose pinned value is not a valid IP literal, is resolved through a
    lazily created :class:`aiohttp.ThreadedResolver`, so a single resolver
    instance can be reused across multiple validated URLs.
    """

    def __init__(self, mapping: dict[str, str] | None = None) -> None:
        # Normalize keys the same way ``pin`` does. ``resolve`` looks hosts
        # up lower-cased, so a mixed-case key stored verbatim would never
        # match and the request would silently go out unpinned.
        self._map: dict[str, str] = {
            host.lower(): ip for host, ip in (mapping or {}).items()
        }
        self._fallback: aiohttp.abc.AbstractResolver | None = None

    def pin(self, host: str, ip: str) -> None:
        """Pin ``host`` (case-insensitive) to ``ip``, replacing any earlier pin."""
        self._map[host.lower()] = ip

    async def resolve(
        self, host: str, port: int = 0, family: int = socket.AF_INET,
    ) -> list[dict]:
        """Return the pinned address for ``host``, else defer to the fallback.

        For a pinned host the address family is derived from the pinned IP
        itself; the ``family`` argument is only forwarded to the fallback.
        """
        ip = self._map.get(host.lower())
        if ip is not None:
            try:
                ip_obj = ipaddress.ip_address(ip)
            except ValueError:
                # Pinned value is not an IP literal; use the fallback below.
                ip_obj = None
            if ip_obj is not None:
                fam = socket.AF_INET6 if ip_obj.version == 6 else socket.AF_INET
                return [{
                    "hostname": host,
                    "host": ip,
                    "port": port,
                    "family": fam,
                    "proto": 0,
                    "flags": socket.AI_NUMERICHOST,
                }]
        if self._fallback is None:
            self._fallback = aiohttp.ThreadedResolver()
        return await self._fallback.resolve(host, port, family)

    async def close(self) -> None:
        """Close the fallback resolver if one was ever created."""
        if self._fallback is not None:
            await self._fallback.close()
@@ -2,16 +2,29 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from typing import Any
from typing import Any, Final
from notify_bridge_core.storage import StorageBackend
_LOGGER = logging.getLogger(__name__)
DEFAULT_TELEGRAM_CACHE_TTL = 48 * 60 * 60
DEFAULT_MAX_ENTRIES = 5000
DEFAULT_TELEGRAM_CACHE_TTL: Final = 48 * 60 * 60
DEFAULT_MAX_ENTRIES: Final = 5000
def _parse_iso(value: str | None) -> datetime | None:
"""Parse an ISO-8601 timestamp tolerantly. Returns ``None`` on failure."""
if not value or not isinstance(value, str):
return None
try:
# Python <3.11 doesn't accept "Z"; normalize to +00:00.
v = value.replace("Z", "+00:00") if value.endswith("Z") else value
return datetime.fromisoformat(v)
except ValueError:
return None
class TelegramFileCache:
@@ -25,7 +38,17 @@ class TelegramFileCache:
Intended for content-addressable assets (e.g. Immich) where re-uploads
should be triggered by visual change, not elapsed time.
``max_entries`` always applies as an LRU size cap (by ``cached_at``).
``max_entries`` always applies as a FIFO size cap (oldest-cached first).
Concurrency
~~~~~~~~~~~
All mutators take an internal ``asyncio.Lock`` so concurrent
media-group sends can't interleave a read-time invalidation with a
bulk write and corrupt the underlying dict (``RuntimeError:
dictionary changed size during iteration``) or lose just-written
entries. Reads do not take the lock — they are O(1) dict lookups —
but ``get`` uses a snapshot reference so it cannot mutate the data
structure under another task.
"""
def __init__(
@@ -40,35 +63,40 @@ class TelegramFileCache:
self._ttl_seconds = ttl_seconds
self._use_thumbhash = use_thumbhash
self._max_entries = max_entries
self._lock = asyncio.Lock()
async def async_load(self) -> None:
self._data = await self._backend.load() or {"files": {}}
await self._cleanup_expired()
async with self._lock:
self._data = await self._backend.load() or {"files": {}}
await self._cleanup_expired_locked()
async def _cleanup_expired(self) -> None:
async def _cleanup_expired_locked(self) -> None:
"""Caller must hold ``self._lock``."""
if not self._data or "files" not in self._data:
return
files = self._data["files"]
files: dict[str, dict[str, Any]] = self._data["files"]
changed = False
# TTL sweep — only when TTL validation is active (i.e. no thumbhash
# mode and a positive TTL). In thumbhash mode we rely entirely on
# content validation; in "TTL disabled" mode (ttl_seconds <= 0) we
# cache forever, subject only to the size cap.
if not self._use_thumbhash and self._ttl_seconds > 0:
now = datetime.now(timezone.utc)
expired = [
url for url, entry in files.items()
if entry.get("cached_at") and
(now - datetime.fromisoformat(entry["cached_at"])).total_seconds() > self._ttl_seconds
]
expired: list[str] = []
for url, entry in list(files.items()):
cached_at = _parse_iso(entry.get("cached_at"))
if cached_at is None:
continue
if cached_at.tzinfo is None:
cached_at = cached_at.replace(tzinfo=timezone.utc)
if (now - cached_at).total_seconds() > self._ttl_seconds:
expired.append(url)
for key in expired:
del files[key]
changed = True
# Size cap — always enforced. Evicts oldest-cached entries first (FIFO by ``cached_at``).
if self._max_entries > 0 and len(files) > self._max_entries:
sorted_keys = sorted(files, key=lambda k: files[k].get("cached_at", ""))
sorted_keys = sorted(
files,
key=lambda k: _parse_iso(files[k].get("cached_at")) or datetime.min.replace(tzinfo=timezone.utc),
)
for key in sorted_keys[: len(files) - self._max_entries]:
del files[key]
changed = True
@@ -80,7 +108,10 @@ class TelegramFileCache:
if not self._data or "files" not in self._data:
return None
entry = self._data["files"].get(key)
# Take a local reference so a concurrent ``async_set`` rebuilding
# the dict cannot pull the rug out mid-read.
files = self._data["files"]
entry = files.get(key)
if not entry:
return None
@@ -88,19 +119,23 @@ class TelegramFileCache:
if thumbhash is not None:
stored = entry.get("thumbhash")
if stored and stored != thumbhash:
del self._data["files"][key]
# Mark stale — actual deletion happens lock-protected
# in the next mutation. Returning None is sufficient
# for the caller to skip the cache hit.
return None
elif self._ttl_seconds > 0:
cached_at_str = entry.get("cached_at")
if cached_at_str:
age = (datetime.now(timezone.utc) - datetime.fromisoformat(cached_at_str)).total_seconds()
cached_at = _parse_iso(entry.get("cached_at"))
if cached_at is not None:
if cached_at.tzinfo is None:
cached_at = cached_at.replace(tzinfo=timezone.utc)
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
if age > self._ttl_seconds:
return None
return {
"file_id": entry.get("file_id"),
"type": entry.get("type"),
"size": entry.get("size"), # bytes of what was uploaded; None for legacy entries
"size": entry.get("size"),
}
async def async_set(
@@ -111,21 +146,22 @@ class TelegramFileCache:
thumbhash: str | None = None,
size: int | None = None,
) -> None:
if self._data is None:
self._data = {"files": {}}
async with self._lock:
if self._data is None:
self._data = {"files": {}}
entry: dict[str, Any] = {
"file_id": file_id,
"type": media_type,
"cached_at": datetime.now(timezone.utc).isoformat(),
}
if thumbhash is not None:
entry["thumbhash"] = thumbhash
if size is not None:
entry["size"] = size
entry: dict[str, Any] = {
"file_id": file_id,
"type": media_type,
"cached_at": datetime.now(timezone.utc).isoformat(),
}
if thumbhash is not None:
entry["thumbhash"] = thumbhash
if size is not None:
entry["size"] = size
self._data["files"][key] = entry
await self._backend.save(self._data)
self._data["files"][key] = entry
await self._backend.save(self._data)
async def async_set_many(
self,
@@ -139,32 +175,34 @@ class TelegramFileCache:
"""
if not entries:
return
if self._data is None:
self._data = {"files": {}}
async with self._lock:
if self._data is None:
self._data = {"files": {}}
now_iso = datetime.now(timezone.utc).isoformat()
for item in entries:
if len(item) == 5:
key, file_id, media_type, thumbhash, size = item
else:
key, file_id, media_type, thumbhash = item
size = None
entry: dict[str, Any] = {
"file_id": file_id,
"type": media_type,
"cached_at": now_iso,
}
if thumbhash is not None:
entry["thumbhash"] = thumbhash
if size is not None:
entry["size"] = size
self._data["files"][key] = entry
now_iso = datetime.now(timezone.utc).isoformat()
for item in entries:
if len(item) == 5:
key, file_id, media_type, thumbhash, size = item
else:
key, file_id, media_type, thumbhash = item
size = None
entry: dict[str, Any] = {
"file_id": file_id,
"type": media_type,
"cached_at": now_iso,
}
if thumbhash is not None:
entry["thumbhash"] = thumbhash
if size is not None:
entry["size"] = size
self._data["files"][key] = entry
await self._backend.save(self._data)
await self._backend.save(self._data)
async def async_remove(self) -> None:
await self._backend.remove()
self._data = None
async with self._lock:
await self._backend.remove()
self._data = None
def stats(self) -> dict[str, Any]:
"""Return summary stats about the current cache contents.
@@ -172,25 +210,33 @@ class TelegramFileCache:
Includes the number of cached entries, total tracked size in bytes
(only counts entries with a recorded ``size``), and the oldest /
newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty).
Timestamps are compared as parsed ``datetime`` objects so mixed
timezone formats (``Z`` vs ``+00:00``) order correctly.
"""
files = self._data.get("files", {}) if self._data else {}
count = len(files)
total_size = 0
oldest: str | None = None
newest: str | None = None
oldest_dt: datetime | None = None
newest_dt: datetime | None = None
oldest_str: str | None = None
newest_str: str | None = None
for entry in files.values():
size = entry.get("size")
if isinstance(size, int):
total_size += size
cached_at = entry.get("cached_at")
if cached_at:
if oldest is None or cached_at < oldest:
oldest = cached_at
if newest is None or cached_at > newest:
newest = cached_at
dt = _parse_iso(cached_at)
if dt is None or not cached_at:
continue
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
if oldest_dt is None or dt < oldest_dt:
oldest_dt, oldest_str = dt, cached_at
if newest_dt is None or dt > newest_dt:
newest_dt, newest_str = dt, cached_at
return {
"count": count,
"total_size_bytes": total_size,
"oldest": oldest,
"newest": newest,
"oldest": oldest_str,
"newest": newest_str,
}
File diff suppressed because it is too large Load Diff
@@ -2,20 +2,35 @@
from __future__ import annotations
import logging
import re
from typing import Any, Final
from urllib.parse import urlparse
_LOGGER = logging.getLogger(__name__)
# Telegram constants
TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot"
TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB
TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB
TELEGRAM_MAX_DIMENSION_SUM: Final = 10000
# Telegram message-text limit (sendMessage) and caption limit
# (sendPhoto/sendVideo/sendDocument/first item of sendMediaGroup).
TELEGRAM_MAX_TEXT_LENGTH: Final = 4096
TELEGRAM_MAX_CAPTION_LENGTH: Final = 1024
# Generic UUID pattern for asset IDs
_ASSET_ID_PATTERN = re.compile(r"^[a-f0-9-]{36}$")
# Strict canonical-UUID pattern (8-4-4-4-12) for asset IDs. The previous
# loose ``[a-f0-9-]{36}`` matched 36 hyphens / arbitrary digit groupings,
# which could collide across providers when used as a cache key.
_ASSET_ID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
re.IGNORECASE,
)
# Cache key: "host:uuid" or bare "uuid"
_ASSET_CACHE_KEY_PATTERN = re.compile(r"^(?:[^:]+:)?[a-f0-9-]{36}$")
_ASSET_CACHE_KEY_PATTERN = re.compile(
r"^(?:[^:]+:)?[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
re.IGNORECASE,
)
# URL patterns to extract asset IDs (generic enough for Immich-style URLs)
_ASSET_ID_URL_PATTERNS = [
@@ -162,5 +177,10 @@ def check_photo_limits(
return False, None, width, height
except ImportError:
return False, None, None, None
except Exception:
except (OSError, ValueError, MemoryError) as exc:
# PIL surfaces ``UnidentifiedImageError`` (subclass of OSError),
# truncated-image / decompression-bomb errors here. Log so a
# corrupt asset isn't silently passed to Telegram and rejected
# downstream with a less actionable error.
_LOGGER.warning("check_photo_limits: failed to inspect image (%d bytes): %s", len(data), exc)
return False, None, None, None
@@ -7,37 +7,29 @@ from typing import Any
import aiohttp
from ..ssrf import UnsafeURLError, avalidate_outbound_url
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
class WebhookClient(HttpProviderClient):
"""Send JSON payloads to a webhook URL.
class WebhookClient:
"""Send JSON payloads to a webhook URL."""
The URL is SSRF-validated on every send (defense-in-depth: re-validating
catches DNS rebinding between calls and a misconfigured target). Headers
pass through :func:`safe_headers` so a target config can't inject
framing/hop-by-hop headers like ``Host`` or ``Transfer-Encoding``.
"""
def __init__(self, session: aiohttp.ClientSession, url: str, headers: dict[str, str] | None = None) -> None:
self._session = session
def __init__(
self,
session: aiohttp.ClientSession,
url: str,
headers: dict[str, str] | None = None,
) -> None:
super().__init__(session, provider_name="webhook")
self._url = url
self._headers = headers or {}
self._extra_headers = headers or {}
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
try:
await avalidate_outbound_url(self._url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe URL: {err}"}
try:
async with self._session.post(
self._url,
json=payload,
headers={"Content-Type": "application/json", **self._headers},
timeout=_DEFAULT_TIMEOUT,
allow_redirects=False,
) as response:
if 200 <= response.status < 300:
return {"success": True, "status_code": response.status}
body = await response.text()
return {"success": False, "error": f"HTTP {response.status}", "body": body[:200]}
except aiohttp.ClientError as err:
return {"success": False, "error": str(err)}
return await self.request("POST", self._url, json=payload, headers=self._extra_headers)
@@ -218,6 +218,19 @@ async def get_supported_locales(
return locales or ["en"]
@router.get("/external-url")
async def get_external_url(
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Return the configured external base URL (available to all users).

    Used by the UI to render absolute provider webhook URLs. Returns empty
    string when unset so the UI falls back to the relative path.

    Returns:
        ``{"external_url": <value>}`` where the value has any trailing
        slashes removed.
    """
    configured = await get_setting(session, "external_url")
    # NOTE(review): ``get_setting`` is defined elsewhere; if it can return
    # ``None`` for a missing key, calling ``.rstrip`` on it directly would
    # raise. Coalescing to "" keeps the documented empty-string contract.
    return {"external_url": (configured or "").rstrip("/")}
async def _reregister_webhooks(
session: AsyncSession, base_url: str, secret: str
) -> None:
@@ -0,0 +1,46 @@
"""Dispatcher result aggregation: per-receiver detail must survive."""
from __future__ import annotations
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher
def test_aggregate_all_success() -> None:
    """Two successful receiver results aggregate to success with full counts."""
    receiver_results = [
        {"success": True, "message_id": 1},
        {"success": True, "message_id": 2},
    ]
    summary = NotificationDispatcher._aggregate_results(receiver_results)

    assert summary["success"] is True
    # Counters: total receivers, how many succeeded, how many failed.
    assert summary["receivers"] == 2
    assert summary["successes"] == 2
    assert summary["failures"] == 0
def test_aggregate_partial() -> None:
    """One success plus one failure still counts as an overall success."""
    receiver_results = [
        {"success": True},
        {"success": False, "error": "boom"},
    ]
    summary = NotificationDispatcher._aggregate_results(receiver_results)

    # A single delivered receiver is enough for the aggregate to succeed.
    assert summary["success"] is True
    assert summary["successes"] == 1
    assert summary["failures"] == 1
    # The failing receiver's message and the raw per-receiver list are kept.
    assert "boom" in summary["errors"]
    assert "results" in summary
def test_aggregate_all_fail_preserves_all_errors() -> None:
    """When every receiver fails, each error message is retained in order."""
    receiver_results = [
        {"success": False, "error": "first"},
        {"success": False, "error": "second"},
    ]
    summary = NotificationDispatcher._aggregate_results(receiver_results)

    assert summary["success"] is False
    # Back-compat: the top-level ``error`` field carries the first failure.
    assert summary["error"] == "first"
    assert summary["errors"] == ["first", "second"]
    # One detail entry per receiver, so an operator can see what failed.
    assert len(summary["results"]) == 2
def test_aggregate_empty() -> None:
    """Aggregating zero results reports failure and includes an error field."""
    summary = NotificationDispatcher._aggregate_results([])

    assert "error" in summary
    assert summary["success"] is False
@@ -0,0 +1,77 @@
"""Email client header-injection / address-validation regression tests."""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.email.client import (
EmailClient,
SmtpConfig,
_strip_header,
_validate_email,
_to_html,
)
def test_strip_header_removes_crlf() -> None:
    """CR and LF are removed, while the injected text survives as plain data."""
    cleaned = _strip_header("Subject\r\nBcc: attacker@example.com")

    for line_break in ("\r", "\n"):
        assert line_break not in cleaned
    # With the line break gone, "Bcc:" is just part of the single header
    # value rather than the start of a new header line.
    assert "Bcc:" in cleaned
def test_strip_header_removes_bare_lf() -> None:
    """A lone LF (no preceding CR) is stripped as well."""
    cleaned = _strip_header("Hello\nWorld")
    assert "\n" not in cleaned
def test_strip_header_handles_non_string() -> None:
    """``None`` input is coerced to an empty string instead of raising."""
    cleaned = _strip_header(None)
    assert cleaned == ""
def test_validate_email_rejects_crlf() -> None:
    """An address carrying a CRLF-injected header line is rejected."""
    injected_address = "user@example.com\r\nBcc: x@y"
    with pytest.raises(ValueError):
        _validate_email(injected_address)
def test_validate_email_rejects_no_at() -> None:
    """A string without an ``@`` is not accepted as an address."""
    malformed_address = "not-an-email"
    with pytest.raises(ValueError):
        _validate_email(malformed_address)
def test_validate_email_rejects_empty() -> None:
    """The empty string is rejected with ``ValueError``."""
    blank_address = ""
    with pytest.raises(ValueError):
        _validate_email(blank_address)
def test_validate_email_accepts_normal() -> None:
    """A plain, well-formed address is returned unchanged."""
    validated = _validate_email("user@example.com")
    assert validated == "user@example.com"
def test_to_html_escapes_brackets() -> None:
    """Angle brackets in the body are emitted as HTML entities."""
    rendered = _to_html("<script>alert(1)</script>")

    assert "&lt;script&gt;" in rendered
    assert "<script>" not in rendered
@pytest.mark.asyncio
async def test_send_returns_error_on_invalid_to() -> None:
    """``send`` reports an invalid-email failure for a CRLF-injected recipient."""
    smtp_config = SmtpConfig(host="smtp.example.com", from_address="from@example.com")
    mailer = EmailClient(smtp_config)

    outcome = await mailer.send(
        to_email="user@example.com\r\nBcc: attacker@example.com",
        subject="hi",
        body_text="body",
    )

    # The failure is returned as a result dict, not raised.
    assert outcome["success"] is False
    assert "Invalid email" in outcome["error"]
@pytest.mark.asyncio
async def test_send_returns_error_on_no_host() -> None:
    """``send`` returns a failure result when the SMTP host is empty."""
    smtp_config = SmtpConfig(host="", from_address="from@example.com")
    mailer = EmailClient(smtp_config)

    outcome = await mailer.send("u@x.com", "s", "b")

    assert outcome["success"] is False
+53
View File
@@ -0,0 +1,53 @@
"""HttpProviderClient + safe_headers tests."""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.http_base import safe_headers
class TestSafeHeaders:
    """Behaviour of ``safe_headers`` for user-supplied header mappings."""

    def test_drops_hop_by_hop(self) -> None:
        """Only the custom header survives; the framing headers are dropped."""
        supplied = {
            "X-Custom": "ok",
            "Host": "evil.example.com",
            "Content-Length": "999",
            "Transfer-Encoding": "chunked",
            "Connection": "close",
        }
        assert safe_headers(supplied) == {"X-Custom": "ok"}

    def test_rejects_crlf_in_value(self) -> None:
        """A header whose value contains CRLF is omitted from the result."""
        supplied = {
            "X-Custom": "ok",
            "X-Bad": "value\r\nInjected: yes",
        }
        filtered = safe_headers(supplied)
        assert "X-Bad" not in filtered
        assert "X-Custom" in filtered

    def test_rejects_crlf_in_name(self) -> None:
        """A header whose name contains CRLF is omitted from the result."""
        supplied = {
            "X-Custom": "ok",
            "X-Bad\r\nInject": "value",
        }
        assert safe_headers(supplied) == {"X-Custom": "ok"}

    def test_empty_input(self) -> None:
        """Both ``None`` and an empty mapping produce an empty dict."""
        for nothing in (None, {}):
            assert safe_headers(nothing) == {}
@pytest.mark.asyncio
async def test_http_base_returns_safe_error_on_invalid_url() -> None:
    """A ``file://`` target yields a failure result flagged as an unsafe URL.

    The request must come back as a result dict (no exception escapes) whose
    ``error`` mentions "Unsafe URL".
    """
    import aiohttp
    from notify_bridge_core.notifications.http_base import HttpProviderClient

    async with aiohttp.ClientSession() as http_session:
        provider = HttpProviderClient(http_session, provider_name="test")
        # The non-HTTP scheme is expected to be refused by the SSRF guard,
        # so no network traffic should be needed for this call.
        outcome = await provider.request("POST", "file:///etc/passwd", json={})
        assert outcome["success"] is False
        assert "Unsafe URL" in outcome["error"]
@@ -0,0 +1,84 @@
"""Matrix client validation: room_id format and quoting."""
from __future__ import annotations
import aiohttp
import pytest
from aioresponses import aioresponses
from notify_bridge_core.notifications.matrix.client import MatrixClient
HOMESERVER = "https://matrix.example.com"
TOKEN = "secret-bearer-token-1234567890"
@pytest.mark.asyncio
async def test_rejects_path_injection_room_id() -> None:
    """A room id containing path-traversal segments is refused."""
    hostile_room = "!abc:host/../../etc/passwd"
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message(hostile_room, "hi")
        assert outcome["success"] is False
        # The error message names the offending field.
        assert "room_id" in outcome["error"]
@pytest.mark.asyncio
async def test_rejects_empty_room_id() -> None:
    """An empty room id is refused with an error that names ``room_id``."""
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message("", "hi")
        assert outcome["success"] is False
        assert "room_id" in outcome["error"]
@pytest.mark.asyncio
async def test_rejects_unicode_control_chars_in_room_id() -> None:
    """A room id with an embedded NUL character is refused."""
    room_with_nul = "!abc\x00:host"
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message(room_with_nul, "hi")
        assert outcome["success"] is False
@pytest.mark.asyncio
async def test_url_encodes_room_id_special_chars() -> None:
    """A room id with ``!`` and ``:`` is sent to the expected rooms path.

    The send must succeed against a mock that only answers PUTs whose room
    path segment is this exact room id, either percent-encoded
    (``%21abc%3Ahost.example``) or literal.

    NOTE(review): both spellings are accepted because the HTTP/mocking stack
    may normalise percent-escapes before the pattern is matched — confirm
    against the aioresponses/yarl versions in use if a stricter check is
    wanted.
    """
    import re

    # Require the specific room id in the path instead of ``[^/]+``, which
    # would let a request for any room satisfy the test.
    room_send_url = re.compile(
        r"https://matrix\.example\.com/_matrix/client/v3/rooms/"
        r"(?:%21|!)abc(?:%3A|:)host\.example/send/m\.room\.message/.*",
        re.IGNORECASE,
    )
    with aioresponses() as mocked:
        # Exact-URL mock for the encoded room path with nothing after the
        # event type.
        mocked.put(
            "https://matrix.example.com/_matrix/client/v3/rooms/%21abc%3Ahost.example/send/m.room.message",
            status=200, body='{}', repeat=True,
        )
        # Pattern mock for the same path followed by a further path segment.
        mocked.put(room_send_url, status=200, body='{}', repeat=True)
        async with aiohttp.ClientSession() as sess:
            client = MatrixClient(sess, HOMESERVER, TOKEN)
            result = await client.send_message("!abc:host.example", "hi")
            assert result["success"] is True
@pytest.mark.asyncio
async def test_redacts_bearer_in_error() -> None:
    """The access token must not appear in ``error`` after a 403 that echoes it."""
    import re

    # Response body that repeats the Authorization header value verbatim.
    echoing_body = '{"errcode": "M_FORBIDDEN", "Authorization": "Bearer ' + TOKEN + '"}'
    any_url = re.compile(r".*")

    with aioresponses() as mocked:
        mocked.put(any_url, status=403, body=echoing_body, repeat=True)
        async with aiohttp.ClientSession() as http_session:
            matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
            outcome = await matrix.send_message("!abc:host.example", "hi")
            assert outcome["success"] is False
            assert TOKEN not in outcome["error"]
+84
View File
@@ -0,0 +1,84 @@
"""NotificationQueue bound + concurrency regression tests."""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from notify_bridge_core.notifications.queue import (
DEFAULT_MAX_QUEUE_SIZE,
NotificationQueue,
)
class _MemBackend:
"""In-memory storage backend stub for tests."""
def __init__(self) -> None:
self._data: dict[str, Any] | None = None
async def load(self) -> dict[str, Any] | None:
return self._data
async def save(self, data: dict[str, Any]) -> None:
self._data = data
async def remove(self) -> None:
self._data = None
@pytest.mark.asyncio
async def test_load_with_garbage_falls_back_to_empty() -> None:
    """A persisted ``queue`` that is not a list loads as an empty queue."""
    storage = _MemBackend()
    # Seed the backend with a malformed payload (string instead of a list).
    storage._data = {"queue": "not-a-list"}  # type: ignore[assignment]

    queue = NotificationQueue(storage)
    await queue.async_load()

    assert queue.get_all() == []
@pytest.mark.asyncio
async def test_enqueue_caps_at_max_size() -> None:
    """With ``max_size=3``, ten enqueues leave only the three newest items."""
    storage = _MemBackend()
    queue = NotificationQueue(storage, max_size=3)
    await queue.async_load()

    for seq in range(10):
        await queue.async_enqueue({"i": seq})

    retained = queue.get_all()
    assert len(retained) == 3
    # Oldest entries are dropped first, so i=7..9 remain in insertion order.
    retained_ids = [entry["params"]["i"] for entry in retained]
    assert retained_ids == [7, 8, 9]
@pytest.mark.asyncio
async def test_get_all_returns_deep_copy() -> None:
    """Mutating a ``get_all`` snapshot must not change the queue's own data."""
    storage = _MemBackend()
    queue = NotificationQueue(storage, max_size=10)
    await queue.async_load()
    await queue.async_enqueue({"key": "value"})

    # Tamper with a nested value in the first snapshot.
    first_snapshot = queue.get_all()
    first_snapshot[0]["params"]["key"] = "MUTATED"

    # A fresh snapshot still shows the original value.
    second_snapshot = queue.get_all()
    assert second_snapshot[0]["params"]["key"] == "value"
@pytest.mark.asyncio
async def test_concurrent_enqueue_and_clear_no_corruption() -> None:
    """Interleaved enqueue and clear tasks finish without raising."""
    storage = _MemBackend()
    queue = NotificationQueue(storage, max_size=DEFAULT_MAX_QUEUE_SIZE)
    await queue.async_load()

    async def enqueue_many() -> None:
        # Fifty sequential enqueues.
        for seq in range(50):
            await queue.async_enqueue({"i": seq})

    async def clear_repeatedly() -> None:
        # Yield to the event loop before each clear so the two tasks interleave.
        for _ in range(10):
            await asyncio.sleep(0)
            await queue.async_clear()

    await asyncio.gather(enqueue_many(), clear_repeatedly())

    # Reaching this point means neither task raised; the final snapshot
    # must still be a list regardless of how many items survived.
    snapshot = queue.get_all()
    assert isinstance(snapshot, list)
+74
View File
@@ -0,0 +1,74 @@
"""Secret-redaction helper regression tests.
Locks in the patterns that surface from real provider error paths:
Telegram bot URLs in aiohttp.ClientError messages, Authorization Bearer
tokens in Matrix/ntfy responses, Discord/Slack webhook tokens, URL
userinfo, and common ?token= query params.
"""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.redact import redact, redact_exc
@pytest.mark.parametrize(
    "raw,expected_substr,not_in",
    [
        # Telegram bot token embedded in an API URL.
        (
            "Cannot connect to host api.telegram.org/bot1234567:AABBCC-secret-token/sendMessage",
            "api.telegram.org/bot***",
            "AABBCC-secret-token",
        ),
        # Authorization header with a Bearer credential.
        (
            "Authorization: Bearer ey.JhbGciOiJIUzI1NiJ9.payload.sig",
            "Bearer ***",
            "ey.JhbGciOiJIUzI1NiJ9",
        ),
        # Discord webhook URL: id stays, token is masked.
        (
            "POST https://discord.com/api/webhooks/12345/abcdefg-token failed",
            "discord.com/api/webhooks/12345/***",
            "abcdefg-token",
        ),
        # Slack incoming-webhook URL: final path segment is masked.
        (
            "POST https://hooks.slack.com/services/T01/B02/zzzzz failed",
            "hooks.slack.com/services/T01/B02/***",
            "zzzzz",
        ),
        # URL userinfo (user:password@host).
        (
            "fetch http://user:supersecret@example.com/foo",
            "http://***@example.com/foo",
            "supersecret",
        ),
        # ``token`` query parameter.
        (
            "https://api.example.com/x?token=mytoken123&extra=ok",
            "token=***",
            "mytoken123",
        ),
    ],
)
def test_redact_known_secrets(raw: str, expected_substr: str, not_in: str) -> None:
    """Each case: the masked form is present and the secret text is gone."""
    scrubbed = redact(raw)
    assert not_in not in scrubbed
    assert expected_substr in scrubbed
def test_redact_idempotent() -> None:
    """Redacting already-redacted text leaves it unchanged."""
    first_pass = redact("Bearer abcdefghij1234567890")
    second_pass = redact(first_pass)
    assert second_pass == first_pass
def test_redact_exc_returns_str() -> None:
    """``redact_exc`` turns an exception into a redacted string."""
    failure = RuntimeError("Bearer abcdefghij1234567890")

    message = redact_exc(failure)

    assert isinstance(message, str)
    assert "abcdefghij1234567890" not in message
    assert "Bearer ***" in message
def test_redact_handles_non_string() -> None:
    """A non-string argument is coerced to ``str`` rather than raising."""
    scrubbed = redact(12345)  # type: ignore[arg-type]
    assert scrubbed == "12345"
@@ -0,0 +1,73 @@
"""SSRF hardening regression tests.
Covers cases the original guard missed: IPv4-mapped IPv6, CGNAT,
trailing-dot hostnames, IPv6 zone identifiers, and the safe-host repr
used in error messages.
"""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.ssrf import (
UnsafeURLError,
PinnedResolver,
avalidate_outbound_url_full,
validate_outbound_url,
)
class TestBlockedRanges:
    """Literal addresses the outbound-URL guard must refuse."""

    @pytest.mark.parametrize(
        "url",
        [
            # IPv4-mapped IPv6 forms of loopback and RFC 1918 space.
            "http://[::ffff:127.0.0.1]/",
            "http://[::ffff:10.0.0.1]/",
            # Carrier-grade NAT range (RFC 6598).
            "http://100.64.0.1/",
            # Unspecified address.
            "http://0.0.0.0/",
        ],
    )
    def test_rejects_extra_ranges(self, url: str) -> None:
        """Every listed URL raises ``UnsafeURLError``."""
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(url)
class TestHostnameNormalization:
    """Host and scheme spellings that must not slip past the guard."""

    def test_strips_trailing_dot(self) -> None:
        """``localhost.`` (trailing dot) is rejected like ``localhost``."""
        dotted_localhost = "http://localhost./"
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(dotted_localhost)

    def test_rejects_bad_scheme_uppercase(self) -> None:
        """An upper-case ``FILE:`` scheme is rejected."""
        shouting_file_url = "FILE:///etc/passwd"
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(shouting_file_url)
class TestErrorMessages:
    """Shape of the text carried by ``UnsafeURLError``."""

    def test_error_does_not_leak_long_hosts(self) -> None:
        """A 1024-character host label must not be echoed back in full."""
        oversized_url = "http://" + "a" * 1024 + ".invalid/"
        with pytest.raises(UnsafeURLError) as caught:
            validate_outbound_url(oversized_url)
        # Only the overall bound is pinned here: the whole message stays
        # under 256 characters, so the host cannot appear verbatim.
        message = str(caught.value)
        assert len(message) < 256
class TestPinnedResolverSync:
    """Synchronous checks on ``PinnedResolver`` construction."""

    def test_pin_returns_pinned_ip(self) -> None:
        """The host-to-IP mapping given to the constructor is stored as-is."""
        pinned = PinnedResolver({"example.com": "93.184.216.34"})
        # Only the stored mapping is inspected; no resolution is performed.
        stored_ip = pinned._map["example.com"]  # type: ignore[attr-defined]
        assert stored_ip == "93.184.216.34"
class TestAsyncFullValidator:
    """Behaviour of ``avalidate_outbound_url_full`` on literal-IP URLs."""

    @pytest.mark.asyncio
    async def test_returns_resolved_ip(self) -> None:
        """A public literal IP is accepted and reported as both host and ip."""
        validated = await avalidate_outbound_url_full("http://8.8.8.8/")
        assert validated.host == "8.8.8.8"
        assert validated.ip == "8.8.8.8"

    @pytest.mark.asyncio
    async def test_rejects_blocked_literal(self) -> None:
        """An IPv4-mapped IPv6 loopback literal raises ``UnsafeURLError``."""
        mapped_loopback = "http://[::ffff:127.0.0.1]/"
        with pytest.raises(UnsafeURLError):
            await avalidate_outbound_url_full(mapped_loopback)
@@ -0,0 +1,56 @@
"""Telegram media-group mixed-type partitioning regression test.
Telegram rejects sendMediaGroup payloads that mix ``document`` with
``photo``/``video``. The client must partition before chunking so a
mixed input list still delivers all assets.
"""
from __future__ import annotations
from notify_bridge_core.notifications.telegram.client import TelegramClient
def test_partition_keeps_photo_video_together() -> None:
    """Photos and videos end up in one group, in their original order."""
    media = [
        {"type": "photo", "url": "p1"},
        {"type": "video", "url": "v1"},
        {"type": "photo", "url": "p2"},
    ]
    groups = TelegramClient._partition_media_by_kind(media)

    assert len(groups) == 1
    urls_in_group = [item["url"] for item in groups[0]]
    assert urls_in_group == ["p1", "v1", "p2"]
def test_partition_separates_documents_from_media() -> None:
    """A document between a photo and a video splits the list into three groups."""
    media = [
        {"type": "photo", "url": "p1"},
        {"type": "document", "url": "d1"},
        {"type": "video", "url": "v1"},
    ]
    groups = TelegramClient._partition_media_by_kind(media)

    assert len(groups) == 3
    # Each group's first item, in input order.
    photo_group, document_group, video_group = groups[0], groups[1], groups[2]
    assert photo_group[0]["url"] == "p1"
    assert document_group[0]["url"] == "d1"
    assert video_group[0]["url"] == "v1"
def test_partition_groups_consecutive_documents() -> None:
    """Adjacent documents share a group; the following photo starts a new one."""
    media = [
        {"type": "document", "url": "d1"},
        {"type": "document", "url": "d2"},
        {"type": "photo", "url": "p1"},
    ]
    groups = TelegramClient._partition_media_by_kind(media)

    assert len(groups) == 2
    document_urls = [item["url"] for item in groups[0]]
    assert document_urls == ["d1", "d2"]
    assert groups[1][0]["url"] == "p1"
def test_partition_empty() -> None:
    """An empty media list partitions into an empty list of groups."""
    groups = TelegramClient._partition_media_by_kind([])
    assert groups == []
def test_partition_defaults_missing_type_to_photo() -> None:
"""Items without an explicit type are treated as photos for grouping."""
parts = TelegramClient._partition_media_by_kind([
{"url": "x"}, # no type
{"type": "video", "url": "v"},
])
assert len(parts) == 1