feat: harden notification stack and switch logging selectors to icon grid

Notifications:
- Add shared http_base, redact, and SSRF hardening modules
- Refactor dispatcher, queue, receiver and per-provider clients
  (telegram, discord, email, matrix, ntfy, slack, webhook) to use
  the shared base, with bounded queue and redacted error logs
- Tests for ssrf, redact, http_base, queue bounds, dispatcher
  aggregation, telegram media partition, email and matrix clients

Frontend:
- Settings: log level / log format selectors now use IconGridSelect
  with per-option icons and i18n descriptions
- Minor providers page and entity-cache store updates

Tooling:
- Document code-review-graph MCP usage in CLAUDE.md
- Ignore .code-review-graph/, register .mcp.json
This commit is contained in:
2026-05-07 13:53:26 +03:00
parent 5bd63a2191
commit 0eb899afb9
33 changed files with 2623 additions and 1033 deletions
+2
View File
@@ -56,3 +56,5 @@ frontend/.svelte-kit/
# Logs # Logs
*.log *.log
# Added by code-review-graph
.code-review-graph/
+12
View File
@@ -0,0 +1,12 @@
{
"mcpServers": {
"code-review-graph": {
"command": "uvx",
"args": [
"code-review-graph",
"serve"
],
"type": "stdio"
}
}
}
+39
View File
@@ -43,3 +43,42 @@ Detailed context is split into focused documents under `.claude/docs/`. Read the
- Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`) - Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`)
- Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`) - Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`)
- Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py` - Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py`
<!-- code-review-graph MCP tools -->
## MCP Tools: code-review-graph
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
you structural context (callers, dependents, test coverage) that file
scanning cannot.
### When to use graph tools FIRST
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
- **Architecture questions**: `get_architecture_overview` + `list_communities`
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
### Key Tools
| Tool | Use when |
|------|----------|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
| `get_review_context` | Need source snippets for review — token-efficient |
| `get_impact_radius` | Understanding blast radius of a change |
| `get_affected_flows` | Finding which execution paths are impacted |
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
| `get_architecture_overview` | Understanding high-level codebase structure |
| `refactor_tool` | Planning renames, finding dead code |
### Workflow
1. The graph auto-updates on file changes (via hooks).
2. Use `detect_changes` for code review.
3. Use `get_affected_flows` to understand impact.
4. Use `query_graph` pattern="tests_for" to check coverage.
+16
View File
@@ -73,6 +73,22 @@ export const localeItems = (): GridItem[] => [
{ value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') }, { value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') },
]; ];
// --- Log level ---
export const logLevelItems = (): GridItem[] => [
{ value: 'DEBUG', icon: 'mdiBugOutline', label: 'DEBUG', desc: t('gridDesc.logLevelDebug') },
{ value: 'INFO', icon: 'mdiInformationOutline', label: 'INFO', desc: t('gridDesc.logLevelInfo') },
{ value: 'WARNING', icon: 'mdiAlertOutline', label: 'WARNING', desc: t('gridDesc.logLevelWarning') },
{ value: 'ERROR', icon: 'mdiAlertOctagonOutline', label: 'ERROR', desc: t('gridDesc.logLevelError') },
];
// --- Log format ---
export const logFormatItems = (): GridItem[] => [
{ value: 'text', icon: 'mdiFormatText', label: 'text', desc: t('gridDesc.logFormatText') },
{ value: 'json', icon: 'mdiCodeJson', label: 'json', desc: t('gridDesc.logFormatJson') },
];
// --- Response mode --- // --- Response mode ---
export const responseModeItems = (tFn: typeof t): GridItem[] => [ export const responseModeItems = (tFn: typeof t): GridItem[] => [
+8 -1
View File
@@ -192,7 +192,8 @@
"apiToken": "API Token", "apiToken": "API Token",
"apiTokenHint": "Optional. Needed for connection testing and repository listing.", "apiTokenHint": "Optional. Needed for connection testing and repository listing.",
"webhookUrl": "Webhook URL", "webhookUrl": "Webhook URL",
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).", "webhookUrlHint": "Set this as the Target URL in Gitea webhook settings. The full URL is shown when an external base URL is configured in Settings; otherwise it is relative to your bridge host.",
"webhookUrlCopyTitle": "Click to copy",
"nutHost": "NUT Server Host", "nutHost": "NUT Server Host",
"nutHostPlaceholder": "192.168.1.100 or ups.local", "nutHostPlaceholder": "192.168.1.100 or ups.local",
"nutPort": "NUT Server Port", "nutPort": "NUT Server Port",
@@ -1131,6 +1132,12 @@
"memorySourceNative": "Use Immich native memories API", "memorySourceNative": "Use Immich native memories API",
"localeEn": "English interface", "localeEn": "English interface",
"localeRu": "Russian interface", "localeRu": "Russian interface",
"logLevelDebug": "Verbose — show every step",
"logLevelInfo": "Default — high-level events",
"logLevelWarning": "Warnings and errors only",
"logLevelError": "Errors only — quietest",
"logFormatText": "Human-readable plain text",
"logFormatJson": "One JSON object per line",
"modeMedia": "Send actual photo/video files", "modeMedia": "Send actual photo/video files",
"modeText": "Send file names and links only", "modeText": "Send file names and links only",
"allEvents": "Show all event types", "allEvents": "Show all event types",
+8 -1
View File
@@ -192,7 +192,8 @@
"apiToken": "API токен", "apiToken": "API токен",
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.", "apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
"webhookUrl": "URL вебхука", "webhookUrl": "URL вебхука",
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).", "webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea. Полный URL показывается, если в настройках задан внешний адрес; иначе путь указан относительно хоста bridge.",
"webhookUrlCopyTitle": "Нажмите, чтобы скопировать",
"nutHost": "Хост NUT-сервера", "nutHost": "Хост NUT-сервера",
"nutHostPlaceholder": "192.168.1.100 или ups.local", "nutHostPlaceholder": "192.168.1.100 или ups.local",
"nutPort": "Порт NUT-сервера", "nutPort": "Порт NUT-сервера",
@@ -1131,6 +1132,12 @@
"memorySourceNative": "Использовать API воспоминаний Immich", "memorySourceNative": "Использовать API воспоминаний Immich",
"localeEn": "Английский интерфейс", "localeEn": "Английский интерфейс",
"localeRu": "Русский интерфейс", "localeRu": "Русский интерфейс",
"logLevelDebug": "Подробный — каждый шаг",
"logLevelInfo": "По умолчанию — ключевые события",
"logLevelWarning": "Только предупреждения и ошибки",
"logLevelError": "Только ошибки — самый тихий",
"logFormatText": "Читаемый человеком текст",
"logFormatJson": "Один JSON-объект на строку",
"modeMedia": "Отправка файлов фото/видео", "modeMedia": "Отправка файлов фото/видео",
"modeText": "Только имена файлов и ссылки", "modeText": "Только имена файлов и ссылки",
"allEvents": "Показать все типы событий", "allEvents": "Показать все типы событий",
+28
View File
@@ -112,6 +112,34 @@ export const capabilitiesCache = (() => {
}; };
})(); })();
/** Configured external base URL — used to render absolute webhook URLs.
* Available to all authenticated users. Empty string when unset. */
export const externalUrlCache = (() => {
let data = $state<string>('');
let fetchedAt = $state(0);
let inflight: Promise<string> | null = null;
const TTL = 300_000;
return {
get value() { return data; },
invalidate() { fetchedAt = 0; },
async fetch(force = false): Promise<string> {
if (!force && fetchedAt > 0 && Date.now() - fetchedAt < TTL) return data;
if (inflight) return inflight;
inflight = (async () => {
try {
const res = await api<{ external_url: string }>('/settings/external-url');
data = (res?.external_url || '').replace(/\/+$/, '');
fetchedAt = Date.now();
return data;
} finally {
inflight = null;
}
})();
return inflight;
},
};
})();
/** Supported template locales — fetched from app settings. */ /** Supported template locales — fetched from app settings. */
export const supportedLocalesCache = (() => { export const supportedLocalesCache = (() => {
let data = $state<string[]>(['en', 'ru']); let data = $state<string[]>(['en', 'ru']);
+42 -4
View File
@@ -3,7 +3,7 @@
import { slide } from 'svelte/transition'; import { slide } from 'svelte/transition';
import { api, getBlockedBy, type BlockedByDetail } from '$lib/api'; import { api, getBlockedBy, type BlockedByDetail } from '$lib/api';
import { t } from '$lib/i18n'; import { t } from '$lib/i18n';
import { providersCache } from '$lib/stores/caches.svelte'; import { providersCache, externalUrlCache } from '$lib/stores/caches.svelte';
import PageHeader from '$lib/components/PageHeader.svelte'; import PageHeader from '$lib/components/PageHeader.svelte';
import Card from '$lib/components/Card.svelte'; import Card from '$lib/components/Card.svelte';
import Loading from '$lib/components/Loading.svelte'; import Loading from '$lib/components/Loading.svelte';
@@ -21,7 +21,7 @@
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { topbarAction } from '$lib/stores/topbar-action.svelte'; import { topbarAction } from '$lib/stores/topbar-action.svelte';
import { onDestroy } from 'svelte'; import { onDestroy } from 'svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight'; import { highlightFromUrl } from '$lib/highlight';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers'; import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte'; import Button from '$lib/components/Button.svelte';
@@ -45,6 +45,30 @@
let confirmDelete = $state<ServiceProvider | null>(null); let confirmDelete = $state<ServiceProvider | null>(null);
let descriptor = $derived(getDescriptor(form.type)); let descriptor = $derived(getDescriptor(form.type));
let externalUrl = $derived(externalUrlCache.value);
function buildWebhookUrl(pattern: string, token: string): string {
const path = pattern.replace('{token}', token ?? '');
return externalUrl ? `${externalUrl}${path}` : path;
}
function copyWebhookUrl(e: Event, url: string) {
e.preventDefault();
e.stopPropagation();
if (navigator.clipboard?.writeText) {
navigator.clipboard.writeText(url);
} else {
const ta = document.createElement('textarea');
ta.value = url;
ta.style.position = 'fixed';
ta.style.opacity = '0';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
}
snackInfo(`${t('snack.copied')}: ${url}`);
}
// Auto-update name when provider type changes (unless user manually edited) // Auto-update name when provider type changes (unless user manually edited)
$effect(() => { $effect(() => {
@@ -76,6 +100,7 @@
onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); }, onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); },
}); });
load(); load();
externalUrlCache.fetch().catch(() => { /* fall back to relative URLs */ });
}); });
onDestroy(() => topbarAction.clear()); onDestroy(() => topbarAction.clear());
async function load() { async function load() {
@@ -246,9 +271,15 @@
</div> </div>
{/each} {/each}
{#if descriptor?.webhookUrlPattern && editing} {#if descriptor?.webhookUrlPattern && editing}
{@const editingWebhookUrl = buildWebhookUrl(descriptor.webhookUrlPattern, providers.find(p => p.id === editing)?.webhook_token ?? '')}
<div class="bg-[var(--color-muted)] rounded-md p-3"> <div class="bg-[var(--color-muted)] rounded-md p-3">
<div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div> <div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div>
<code class="text-xs select-all break-all">{descriptor.webhookUrlPattern.replace('{token}', providers.find(p => p.id === editing)?.webhook_token ?? '')}</code> <button type="button"
onclick={(e) => copyWebhookUrl(e, editingWebhookUrl)}
title={t('providers.webhookUrlCopyTitle')}
class="text-xs break-all text-left hover:text-[var(--color-primary)] cursor-pointer font-mono w-full">
<code class="bg-transparent">{editingWebhookUrl}</code>
</button>
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p> <p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
</div> </div>
{/if} {/if}
@@ -295,7 +326,14 @@
<p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p> <p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p>
{/if} {/if}
{#if provDesc?.webhookUrlPattern} {#if provDesc?.webhookUrlPattern}
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">{t('providers.webhookUrl')}: <span class="select-all">{provDesc.webhookUrlPattern.replace('{token}', provider.webhook_token)}</span></p> {@const webhookUrl = buildWebhookUrl(provDesc.webhookUrlPattern, provider.webhook_token)}
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">
{t('providers.webhookUrl')}:
<button type="button"
onclick={(e) => copyWebhookUrl(e, webhookUrl)}
title={t('providers.webhookUrlCopyTitle')}
class="hover:text-[var(--color-primary)] cursor-pointer break-all text-left">{webhookUrl}</button>
</p>
{/if} {/if}
</div> </div>
</div> </div>
+6 -12
View File
@@ -12,7 +12,10 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import LocaleSelector from '$lib/components/LocaleSelector.svelte'; import LocaleSelector from '$lib/components/LocaleSelector.svelte';
import TimezoneSelector from '$lib/components/TimezoneSelector.svelte'; import TimezoneSelector from '$lib/components/TimezoneSelector.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { logLevelItems, logFormatItems } from '$lib/grid-items';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { externalUrlCache } from '$lib/stores/caches.svelte';
interface CacheBucketStats { interface CacheBucketStats {
count: number; count: number;
@@ -76,6 +79,7 @@
saving = true; error = ''; saving = true; error = '';
try { try {
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) }); settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
externalUrlCache.invalidate();
snackSuccess(t('settings.saved')); snackSuccess(t('settings.saved'));
} catch (err: any) { error = err.message; snackError(err.message); } } catch (err: any) { error = err.message; snackError(err.message); }
saving = false; saving = false;
@@ -221,21 +225,11 @@
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4"> <div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
<div> <div>
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label> <label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
<select bind:value={settings.log_level} <IconGridSelect items={logLevelItems()} bind:value={settings.log_level} columns={2} />
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
<option value="DEBUG">DEBUG</option>
<option value="INFO">INFO</option>
<option value="WARNING">WARNING</option>
<option value="ERROR">ERROR</option>
</select>
</div> </div>
<div> <div>
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label> <label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
<select bind:value={settings.log_format} <IconGridSelect items={logFormatItems()} bind:value={settings.log_format} columns={2} />
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
<option value="text">text</option>
<option value="json">json</option>
</select>
</div> </div>
<div class="sm:col-span-2"> <div class="sm:col-span-2">
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label> <label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
@@ -4,21 +4,24 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Any from typing import Any, Final
import aiohttp import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Discord webhook content limit # Discord API constraints (per webhook docs).
MAX_CONTENT_LENGTH = 2000 MAX_CONTENT_LENGTH: Final = 2000
MAX_USERNAME_LENGTH: Final = 80
class DiscordClient: class DiscordClient(HttpProviderClient):
"""Sends messages via Discord webhook URLs.""" """Sends messages via Discord webhook URLs."""
def __init__(self, session: aiohttp.ClientSession) -> None: def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session super().__init__(session, provider_name="discord")
async def send( async def send(
self, self,
@@ -33,6 +36,8 @@ class DiscordClient:
""" """
if not webhook_url: if not webhook_url:
return {"success": False, "error": "Missing webhook_url"} return {"success": False, "error": "Missing webhook_url"}
if username and len(username) > MAX_USERNAME_LENGTH:
return {"success": False, "error": f"username exceeds {MAX_USERNAME_LENGTH} chars"}
chunks = _split_message(message, MAX_CONTENT_LENGTH) chunks = _split_message(message, MAX_CONTENT_LENGTH)
for chunk in chunks: for chunk in chunks:
@@ -42,71 +47,34 @@ class DiscordClient:
if avatar_url: if avatar_url:
payload["avatar_url"] = avatar_url payload["avatar_url"] = avatar_url
result = await self._post(webhook_url, payload) result = await self.request("POST", webhook_url, json=payload)
if not result["success"]: if not result.get("success"):
return result return result
# Small delay between chunks to respect rate limits
if len(chunks) > 1: if len(chunks) > 1:
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
return {"success": True} return {"success": True}
_MAX_RETRIES = 3
_MAX_RETRY_AFTER = 60.0
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
"""POST with bounded 429 retry.
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
pin the dispatch task indefinitely.
"""
for attempt in range(self._MAX_RETRIES + 1):
try:
async with self._session.post(
url,
json=payload,
headers={"Content-Type": "application/json"},
allow_redirects=False,
) as resp:
if resp.status == 429 and attempt < self._MAX_RETRIES:
try:
retry_after = float(resp.headers.get("Retry-After", "2"))
except (TypeError, ValueError):
retry_after = 2.0
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
_LOGGER.warning(
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
retry_after, attempt + 1, self._MAX_RETRIES,
)
await asyncio.sleep(retry_after)
continue
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {
"success": False,
"error": f"HTTP {resp.status}: {body[:200]}",
}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
return {"success": False, "error": "Rate limited (retries exhausted)"}
def _split_message(text: str, limit: int) -> list[str]: def _split_message(text: str, limit: int) -> list[str]:
"""Split message into chunks respecting the character limit.""" """Split message into chunks respecting the character limit.
Drops chunks that contain only whitespace — Discord rejects those.
"""
if len(text) <= limit: if len(text) <= limit:
return [text] return [text]
chunks = [] chunks: list[str] = []
while text: while text:
if len(text) <= limit: if len(text) <= limit:
chunks.append(text) piece = text
break text = ""
# Try to split at newline else:
split_at = text.rfind("\n", 0, limit) split_at = text.rfind("\n", 0, limit)
if split_at <= 0: if split_at <= 0:
split_at = limit split_at = limit
chunks.append(text[:split_at]) piece = text[:split_at]
text = text[split_at:].lstrip("\n") text = text[split_at:].lstrip("\n")
return chunks if piece.strip():
chunks.append(piece)
return chunks or [text]
@@ -7,7 +7,7 @@ import contextlib
import logging import logging
import uuid import uuid
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, AsyncIterator from typing import Any, AsyncIterator, Awaitable, Callable, Final
import aiohttp import aiohttp
@@ -15,37 +15,20 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
from notify_bridge_core.models.events import ServiceEvent from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.templates.context import build_template_context from notify_bridge_core.templates.context import build_template_context
from notify_bridge_core.templates.renderer import render_template from notify_bridge_core.templates.renderer import render_template
from .ssrf import UnsafeURLError, avalidate_outbound_url
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
# Cap on how many asset downloads run concurrently inside
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
# max_asset_size``, which matters on small-RAM Docker hosts when a batch
# contains many large videos.
_PRELOAD_CONCURRENCY = 6
def _new_session() -> aiohttp.ClientSession:
"""Per-dispatch aiohttp session with a sane default timeout.
We still open a short-lived session per dispatch (connection reuse across
dispatches lives in the server-side shared session), but we always attach
a total timeout so a hung peer cannot wedge the task forever.
"""
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
from .http_base import safe_headers
from .receiver import ( from .receiver import (
DiscordReceiver,
EmailReceiver,
MatrixReceiver,
NtfyReceiver,
Receiver, Receiver,
SlackReceiver,
TelegramReceiver, TelegramReceiver,
WebhookReceiver, WebhookReceiver,
EmailReceiver,
DiscordReceiver,
SlackReceiver,
NtfyReceiver,
MatrixReceiver,
) )
from .redact import redact_exc
from .ssrf import UnsafeURLError, avalidate_outbound_url
from .telegram.cache import TelegramFileCache from .telegram.cache import TelegramFileCache
from .telegram.client import TelegramClient from .telegram.client import TelegramClient
from .telegram.media import ( from .telegram.media import (
@@ -58,7 +41,33 @@ from .webhook.client import WebhookClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_TEMPLATE = '{{ event_type }}: "{{ collection_name }}"' DEFAULT_TEMPLATE: Final = '{{ event_type }}: "{{ collection_name }}"'
_HTTP_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
# Cap on how many asset downloads run concurrently inside
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
# max_asset_size``.
_PRELOAD_CONCURRENCY: Final = 6
# Cap on how many targets the dispatcher fans out to at once. With dozens
# of targets and a single hung peer, unbounded ``gather`` can pin the
# dispatch task. The cap also protects against credential-reuse rate
# limits on shared providers.
_DISPATCH_CONCURRENCY: Final = 16
# Cap on parallel per-receiver sends within a single target.
_RECEIVER_CONCURRENCY: Final = 8
# Per-target soft timeout — at the top of the dispatch tree so a single
# misbehaving target can't hold the whole batch open. Individual provider
# clients carry their own per-request timeouts on top of this.
_TARGET_TIMEOUT_S: Final = 120.0
def _new_session() -> aiohttp.ClientSession:
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
@dataclass @dataclass
@@ -66,17 +75,23 @@ class TargetConfig:
"""Configuration for a notification target.""" """Configuration for a notification target."""
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix" type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
config: dict[str, Any] # target-level config (bot_token, settings, etc.) config: dict[str, Any]
template_slots: dict[str, dict[str, str]] | None = None # event_type -> {locale -> template} template_slots: dict[str, dict[str, str]] | None = None
locale: str = "en" # default locale for template resolution locale: str = "en"
date_format: str = "%d.%m.%Y, %H:%M UTC" date_format: str = "%d.%m.%Y, %H:%M UTC"
date_only_format: str = "%d.%m.%Y" date_only_format: str = "%d.%m.%Y"
provider_api_key: str | None = None # API key for downloading assets from provider provider_api_key: str | None = None
provider_internal_url: str | None = None # Internal provider URL for API key scoping provider_internal_url: str | None = None
provider_external_url: str | None = None # External domain for API key scoping provider_external_url: str | None = None
receivers: list[Receiver] = field(default_factory=list) receivers: list[Receiver] = field(default_factory=list)
_SendMethod = Callable[
["NotificationDispatcher", TargetConfig, str, ServiceEvent],
Awaitable[dict[str, Any]],
]
class NotificationDispatcher: class NotificationDispatcher:
"""Dispatches ServiceEvent notifications to configured targets.""" """Dispatches ServiceEvent notifications to configured targets."""
@@ -90,18 +105,11 @@ class NotificationDispatcher:
self._url_cache = url_cache self._url_cache = url_cache
self._asset_cache = asset_cache self._asset_cache = asset_cache
# Optional shared session owned by the caller; when supplied we reuse # Optional shared session owned by the caller; when supplied we reuse
# its connection pool instead of opening a fresh per-dispatch session # its connection pool instead of opening a fresh per-dispatch session.
# (saves a TLS handshake per outbound call).
self._shared_session = session self._shared_session = session
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]: async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
"""Yield an aiohttp session, reusing the shared one if provided.
When a shared session was passed in ``__init__`` we yield it without
closing (the caller owns its lifetime). Otherwise we open a
short-lived session with our default timeout and close it on exit.
"""
if self._shared_session is not None and not self._shared_session.closed: if self._shared_session is not None and not self._shared_session.closed:
yield self._shared_session yield self._shared_session
return return
@@ -115,11 +123,9 @@ class NotificationDispatcher:
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Send event notification to all targets. """Send event notification to all targets.
Returns list of results (one per target). Returns one result per target. Per-target failures are isolated;
a single bad target cannot poison the batch.
""" """
# Bind a dispatch_id so every log line emitted by the target sends
# (including deep in TelegramClient) can be correlated to the same
# upstream event.
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}" new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
with bind_log_context(dispatch_id=new_id): with bind_log_context(dispatch_id=new_id):
@@ -128,20 +134,36 @@ class NotificationDispatcher:
event.event_type.value if hasattr(event.event_type, "value") else event.event_type, event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
getattr(event, "collection_name", None), len(targets), getattr(event, "collection_name", None), len(targets),
) )
sem = asyncio.Semaphore(_DISPATCH_CONCURRENCY)
async def run_one(t: TargetConfig) -> dict[str, Any]:
async with sem:
try:
return await asyncio.wait_for(
self._send_to_target(event, t),
timeout=_TARGET_TIMEOUT_S,
)
except asyncio.TimeoutError:
return {
"success": False,
"error": f"Target dispatch timed out after {_TARGET_TIMEOUT_S}s",
}
raw_results = await asyncio.gather( raw_results = await asyncio.gather(
*[self._send_to_target(event, t) for t in targets], *[run_one(t) for t in targets],
return_exceptions=True, return_exceptions=True,
) )
results = [] results: list[dict[str, Any]] = []
failures = 0 failures = 0
for target, raw in zip(targets, raw_results): for target, raw in zip(targets, raw_results):
if isinstance(raw, Exception): if isinstance(raw, Exception):
failures += 1 failures += 1
_LOGGER.error( _LOGGER.error(
"Dispatch to target type=%s failed: %s", "Dispatch to target type=%s failed: %s",
target.type, raw, exc_info=raw, target.type, redact_exc(raw),
) )
results.append({"success": False, "error": str(raw)}) results.append({"success": False, "error": redact_exc(raw)})
else: else:
if isinstance(raw, dict) and not raw.get("success"): if isinstance(raw, dict) and not raw.get("success"):
failures += 1 failures += 1
@@ -155,7 +177,6 @@ class NotificationDispatcher:
def _resolve_template( def _resolve_template(
self, event: ServiceEvent, target: TargetConfig, locale: str, self, event: ServiceEvent, target: TargetConfig, locale: str,
) -> str: ) -> str:
"""Resolve template string for an event, with locale fallback."""
template_str = DEFAULT_TEMPLATE template_str = DEFAULT_TEMPLATE
if target.template_slots: if target.template_slots:
locale_map = target.template_slots.get(event.event_type.value) locale_map = target.template_slots.get(event.event_type.value)
@@ -166,7 +187,6 @@ class NotificationDispatcher:
def _render_message( def _render_message(
self, event: ServiceEvent, target: TargetConfig, locale: str, self, event: ServiceEvent, target: TargetConfig, locale: str,
) -> str: ) -> str:
"""Resolve template and render message for a given locale."""
template_str = self._resolve_template(event, target, locale) template_str = self._resolve_template(event, target, locale)
ctx = build_template_context( ctx = build_template_context(
event, target_type=target.type, event, target_type=target.type,
@@ -179,7 +199,6 @@ class NotificationDispatcher:
self, receiver: Receiver, default_message: str, self, receiver: Receiver, default_message: str,
event: ServiceEvent, target: TargetConfig, event: ServiceEvent, target: TargetConfig,
) -> str: ) -> str:
"""Return per-receiver message, re-rendering if receiver has a different locale."""
if receiver.locale and receiver.locale != target.locale: if receiver.locale and receiver.locale != target.locale:
return self._render_message(event, target, receiver.locale) return self._render_message(event, target, receiver.locale)
return default_message return default_message
@@ -187,21 +206,16 @@ class NotificationDispatcher:
async def _send_to_target( async def _send_to_target(
self, event: ServiceEvent, target: TargetConfig self, event: ServiceEvent, target: TargetConfig
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Send event to a single target (potentially multiple receivers).""" """Dispatch to a single target via the registered handler."""
default_message = self._render_message(event, target, target.locale) default_message = self._render_message(event, target, target.locale)
send_method = _PROVIDER_HANDLERS.get(target.type)
if send_method is None:
return {"success": False, "error": f"Unknown target type: {target.type}"}
return await send_method(self, target, default_message, event)
send_method = { # ------------------------------------------------------------------
"telegram": self._send_telegram, # Asset preload (Telegram-specific)
"webhook": self._send_webhook, # ------------------------------------------------------------------
"email": self._send_email,
"discord": self._send_discord,
"slack": self._send_slack,
"ntfy": self._send_ntfy,
"matrix": self._send_matrix,
}.get(target.type)
if send_method:
return await send_method(target, default_message, event)
return {"success": False, "error": f"Unknown target type: {target.type}"}
async def _preload_asset_data( async def _preload_asset_data(
self, self,
@@ -210,36 +224,13 @@ class NotificationDispatcher:
session: aiohttp.ClientSession, session: aiohttp.ClientSession,
max_size: int | None, max_size: int | None,
) -> None: ) -> None:
"""Download each non-cached asset's bytes once and attach to the entry. """Download each non-cached asset's bytes once, with SSRF guard."""
Three benefits:
* ``TelegramClient`` sees ``entry["data"]`` and skips its own download,
so we don't fetch each URL twice.
* We know the exact upload size, which lets the oversize warning in
the rendered text compare against real bytes (for Immich videos,
the transcoded ``/video/playback``), not the original ``file_size``.
* Assets already in the Telegram file_id cache are skipped, and their
stored size (if any) is used to populate ``playback_size`` — so
templates see consistent sizes for repeat sends without re-download.
Entries whose download fails or exceeds ``max_size`` are left without
``data``; ``TelegramClient`` will then fall back to its own download
path and apply the same checks — no regression, just no preload win.
Concurrency is bounded by ``_PRELOAD_CONCURRENCY`` so peak memory
stays predictable: at most N assets worth of bytes held in RAM at
once, regardless of ``max_media_to_send``. Total wall-clock is
unchanged for small batches and only marginally slower for large
ones (most assets fit in a single RTT and SSL negotiation cost
dominates, so 6-way parallelism is sufficient).
"""
if not assets: if not assets:
return return
sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY) sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY)
async def _fetch(entry: dict[str, Any], media: Any) -> None: async def fetch(entry: dict[str, Any], media: Any) -> None:
# Cache hit → skip download; populate playback_size from stored size.
cache, key = self._cache_for_entry(entry) cache, key = self._cache_for_entry(entry)
if cache and key: if cache and key:
cached = cache.get(key) cached = cache.get(key)
@@ -251,28 +242,40 @@ class NotificationDispatcher:
url = entry["url"] url = entry["url"]
headers = entry.get("headers") or {} headers = entry.get("headers") or {}
try:
# Defense-in-depth: validate even though TelegramClient
# also validates. The dispatcher is what triggers the
# download, so the guard belongs here too.
await avalidate_outbound_url(url)
except UnsafeURLError as err:
_LOGGER.warning(
"Asset preload skipped: unsafe URL (%s)", redact_exc(err),
)
return
async with sem: async with sem:
try: try:
async with session.get(url, headers=headers) as resp: async with session.get(url, headers=headers) as resp:
if resp.status != 200: if resp.status != 200:
return return
data = await resp.read() data = await resp.read()
except aiohttp.ClientError: except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
return return
if max_size is not None and len(data) > max_size: if max_size is not None and len(data) > max_size:
return return
entry["data"] = data entry["data"] = data
media.extra["playback_size"] = len(data) media.extra["playback_size"] = len(data)
await asyncio.gather(*(_fetch(e, m) for e, m in zip(assets, media_assets))) raw = await asyncio.gather(
*(fetch(e, m) for e, m in zip(assets, media_assets)),
return_exceptions=True,
)
for r in raw:
if isinstance(r, Exception):
_LOGGER.warning("Asset preload raised: %s", redact_exc(r))
def _cache_for_entry( def _cache_for_entry(
self, entry: dict[str, Any], self, entry: dict[str, Any],
) -> tuple[TelegramFileCache | None, str | None]: ) -> tuple[TelegramFileCache | None, str | None]:
"""Resolve (cache, key) for an asset entry — mirrors TelegramClient logic.
Returns (None, None) if no cache is configured or no key can be derived.
"""
cache_key = entry.get("cache_key") cache_key = entry.get("cache_key")
if cache_key: if cache_key:
cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache
@@ -287,6 +290,10 @@ class NotificationDispatcher:
return self._url_cache, url return self._url_cache, url
return None, None return None, None
# ------------------------------------------------------------------
# Per-provider handlers
# ------------------------------------------------------------------
async def _send_telegram( async def _send_telegram(
self, target: TargetConfig, default_message: str, event: ServiceEvent self, target: TargetConfig, default_message: str, event: ServiceEvent
) -> dict[str, Any]: ) -> dict[str, Any]:
@@ -296,27 +303,25 @@ class NotificationDispatcher:
max_media = target.config.get("max_media_to_send", 50) max_media = target.config.get("max_media_to_send", 50)
max_group = target.config.get("max_media_per_group", 10) max_group = target.config.get("max_media_per_group", 10)
chunk_delay = target.config.get("media_delay", 500) chunk_delay = target.config.get("media_delay", 500)
max_size = target.config.get("max_asset_size") max_size_mb = target.config.get("max_asset_size")
if max_size: max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None
max_size = max_size * 1024 * 1024 # MB to bytes
send_large_as_docs = target.config.get("send_large_photos_as_documents", False) send_large_as_docs = target.config.get("send_large_photos_as_documents", False)
if not bot_token: if not bot_token:
return {"success": False, "error": "Missing bot_token"} return {"success": False, "error": "Missing bot_token"}
if not target.receivers: if not target.receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
# Prepare assets list once (shared across receivers)
# Prefer internal URL for fetching (LAN speed vs public internet)
internal_url = (target.provider_internal_url or "").rstrip("/") internal_url = (target.provider_internal_url or "").rstrip("/")
external_url = (target.provider_external_url or "").rstrip("/") external_url = (target.provider_external_url or "").rstrip("/")
assets = [] assets: list[dict[str, Any]] = []
media_assets: list[Any] = [] # aligned with `assets` for preload media_assets: list[Any] = []
for asset in event.added_assets[:max_media]: for asset in event.added_assets[:max_media]:
url = asset.preview_url or asset.thumbnail_url or asset.full_url url = asset.preview_url or asset.thumbnail_url or asset.full_url
if not url:
continue
asset_entry = build_telegram_asset_entry( asset_entry = build_telegram_asset_entry(
url=url or "", url=url,
media_type=asset.type.value, media_type=asset.type.value,
api_key=target.provider_api_key, api_key=target.provider_api_key,
internal_url=internal_url, internal_url=internal_url,
@@ -327,26 +332,15 @@ class NotificationDispatcher:
assets.append(asset_entry) assets.append(asset_entry)
media_assets.append(asset) media_assets.append(asset)
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
# Preload all asset bytes once so (a) TelegramClient can skip its await self._preload_asset_data(assets, media_assets, session, max_size_bytes)
# own download and (b) we know exact upload sizes in time for the
# oversize warning in the rendered text.
await self._preload_asset_data(assets, media_assets, session, max_size)
default_message = self._render_message(event, target, target.locale)
# Asset cache (when in thumbhash mode) invalidates entries when the
# asset's visual content changes. The resolver maps asset id → its
# current thumbhash. Providers that expose thumbhash put it in
# ``asset.extra["thumbhash"]`` (currently Immich).
thumbhash_map = { thumbhash_map = {
asset.id: asset.extra.get("thumbhash") asset.id: asset.extra.get("thumbhash")
for asset in event.added_assets for asset in event.added_assets
if asset.extra.get("thumbhash") if asset.extra.get("thumbhash")
} }
thumbhash_resolver = ( thumbhash_resolver = thumbhash_map.get if thumbhash_map else None
(lambda key: thumbhash_map.get(key)) if thumbhash_map else None
)
client = TelegramClient( client = TelegramClient(
session, bot_token, session, bot_token,
@@ -355,39 +349,51 @@ class NotificationDispatcher:
thumbhash_resolver=thumbhash_resolver, thumbhash_resolver=thumbhash_resolver,
) )
for receiver in target.receivers: async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id: if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
results.append({"success": False, "error": "Invalid telegram receiver"}) return {"success": False, "error": "Invalid telegram receiver"}
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
text_result = await client.send_message( text_result = await client.send_message(
chat_id=receiver.chat_id, chat_id=receiver.chat_id,
text=message, text=message,
disable_web_page_preview=bool(disable_preview), disable_web_page_preview=bool(disable_preview),
) )
if not text_result.get("success"): if not text_result.get("success"):
_LOGGER.warning("Failed to send to chat %s: %s", receiver.chat_id, text_result.get("error")) _LOGGER.warning(
results.append(text_result) "Failed to send to chat %s: %s",
continue receiver.chat_id, text_result.get("error"),
)
return text_result
if assets: if assets:
reply_to = text_result.get("message_id")
media_result = await client.send_notification( media_result = await client.send_notification(
chat_id=receiver.chat_id, chat_id=receiver.chat_id,
assets=assets, assets=assets,
reply_to_message_id=reply_to, reply_to_message_id=text_result.get("message_id"),
max_group_size=max_group, max_group_size=max_group,
chunk_delay=chunk_delay, chunk_delay=chunk_delay,
max_asset_data_size=max_size, max_asset_data_size=max_size_bytes,
send_large_photos_as_documents=send_large_as_docs, send_large_photos_as_documents=send_large_as_docs,
chat_action=chat_action or None, chat_action=chat_action or None,
) )
if not media_result.get("success"): if not media_result.get("success"):
_LOGGER.warning("Text sent OK but media failed for chat %s: %s", receiver.chat_id, media_result.get("error")) _LOGGER.warning(
"Text sent OK but media failed for chat %s: %s",
receiver.chat_id, media_result.get("error"),
)
# Preserve both outcomes — text succeeded, media
# didn't. Operators losing media-failure detail
# in the result dict made root-cause analysis
# impossible.
return {
"success": True,
"message_id": text_result.get("message_id"),
"media_error": media_result.get("error"),
"media_failed_at_chunk": media_result.get("failed_at_chunk"),
}
return text_result
results.append(text_result) results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
@@ -397,17 +403,10 @@ class NotificationDispatcher:
if not target.receivers: if not target.receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
for receiver in target.receivers: async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, WebhookReceiver) or not receiver.url: if not isinstance(receiver, WebhookReceiver) or not receiver.url:
results.append({"success": False, "error": "Invalid webhook receiver"}) return {"success": False, "error": "Invalid webhook receiver"}
continue
try:
await avalidate_outbound_url(receiver.url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
payload = { payload = {
"message": message, "message": message,
@@ -417,8 +416,10 @@ class NotificationDispatcher:
"collection_id": event.collection_id, "collection_id": event.collection_id,
"timestamp": event.timestamp.isoformat(), "timestamp": event.timestamp.isoformat(),
} }
client = WebhookClient(session, receiver.url, receiver.headers) client = WebhookClient(session, receiver.url, safe_headers(receiver.headers))
results.append(await client.send(payload)) return await client.send(payload)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
@@ -431,7 +432,7 @@ class NotificationDispatcher:
if not smtp_cfg.get("host"): if not smtp_cfg.get("host"):
return {"success": False, "error": "SMTP not configured"} return {"success": False, "error": "SMTP not configured"}
client = EmailClient(SmtpConfig( email_client = EmailClient(SmtpConfig(
host=smtp_cfg["host"], host=smtp_cfg["host"],
port=int(smtp_cfg.get("port", 587)), port=int(smtp_cfg.get("port", 587)),
username=smtp_cfg.get("username", ""), username=smtp_cfg.get("username", ""),
@@ -439,27 +440,28 @@ class NotificationDispatcher:
from_address=smtp_cfg.get("from_address", ""), from_address=smtp_cfg.get("from_address", ""),
from_name=smtp_cfg.get("from_name", "Notify Bridge"), from_name=smtp_cfg.get("from_name", "Notify Bridge"),
use_tls=smtp_cfg.get("use_tls", True), use_tls=smtp_cfg.get("use_tls", True),
tls_mode=smtp_cfg.get("tls_mode", "auto"),
)) ))
if not target.receivers: if not target.receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}" subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = [] async def send_one(receiver: Receiver) -> dict[str, Any]:
for receiver in target.receivers:
if not isinstance(receiver, EmailReceiver) or not receiver.email: if not isinstance(receiver, EmailReceiver) or not receiver.email:
results.append({"success": False, "error": "Invalid email receiver"}) return {"success": False, "error": "Invalid email receiver"}
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
result = await client.send( # body_html=None lets EmailClient build a safely-escaped HTML
# alternative from body_text instead of trusting user content.
return await email_client.send(
to_email=receiver.email, to_email=receiver.email,
subject=subject, subject=subject,
body_text=message, body_text=message,
body_html=message, body_html=None,
to_name=receiver.name, to_name=receiver.name,
) )
results.append(result)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
async def _send_discord( async def _send_discord(
@@ -471,20 +473,16 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
username = target.config.get("username") username = target.config.get("username")
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
client = DiscordClient(session) client = DiscordClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url: if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid discord receiver"}) return {"success": False, "error": "Invalid discord receiver"}
continue
try:
await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username)) return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
@@ -497,20 +495,16 @@ class NotificationDispatcher:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
username = target.config.get("username") username = target.config.get("username")
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
client = SlackClient(session) client = SlackClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url: if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
results.append({"success": False, "error": "Invalid slack receiver"}) return {"success": False, "error": "Invalid slack receiver"}
continue
try:
await avalidate_outbound_url(receiver.webhook_url)
except UnsafeURLError as err:
results.append({"success": False, "error": f"Unsafe URL: {err}"})
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send(receiver.webhook_url, message, username=username)) return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
@@ -526,22 +520,23 @@ class NotificationDispatcher:
try: try:
await avalidate_outbound_url(server_url) await avalidate_outbound_url(server_url)
except UnsafeURLError as err: except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"} return {"success": False, "error": f"Unsafe ntfy server_url: {redact_exc(err)}"}
title = f"{event.event_type.value}: {event.collection_name}" title = f"{event.event_type.value}: {event.collection_name}"
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
client = NtfyClient(session) client = NtfyClient(session)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, NtfyReceiver) or not receiver.topic: if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
results.append({"success": False, "error": "Invalid ntfy receiver"}) return {"success": False, "error": "Invalid ntfy receiver"}
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send( return await client.send(
server_url, receiver.topic, message, server_url, receiver.topic, message,
title=title, priority=receiver.priority, auth_token=auth_token, title=title, priority=receiver.priority, auth_token=auth_token,
)) )
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
@@ -557,33 +552,108 @@ class NotificationDispatcher:
try: try:
await avalidate_outbound_url(homeserver) await avalidate_outbound_url(homeserver)
except UnsafeURLError as err: except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"} return {"success": False, "error": f"Unsafe matrix homeserver_url: {redact_exc(err)}"}
if not target.receivers: if not target.receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
results: list[dict[str, Any]] = []
async with self._session_ctx() as session: async with self._session_ctx() as session:
client = MatrixClient(session, homeserver, access_token) client = MatrixClient(session, homeserver, access_token)
for receiver in target.receivers:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id: if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
results.append({"success": False, "error": "Invalid matrix receiver"}) return {"success": False, "error": "Invalid matrix receiver"}
continue
message = self._message_for_receiver(receiver, default_message, event, target) message = self._message_for_receiver(receiver, default_message, event, target)
results.append(await client.send_message( # body_html is the same plain text — Matrix accepts the
receiver.room_id, message, html_message=message, # raw message as both ``body`` and ``formatted_body``.
)) # If templates emit HTML in the future, generate a
# separate HTML body upstream rather than aliasing here.
return await client.send_message(
receiver.room_id, message, html_message=None,
)
results = await self._fan_out(target.receivers, send_one)
return self._aggregate_results(results) return self._aggregate_results(results)
# ------------------------------------------------------------------
# Aggregation
# ------------------------------------------------------------------
@staticmethod
async def _fan_out(
receivers: list[Receiver],
send_one: Callable[[Receiver], Awaitable[dict[str, Any]]],
) -> list[dict[str, Any]]:
"""Run ``send_one`` per receiver with bounded concurrency.
Per-receiver exceptions are converted to failure dicts so a single
bad receiver can't cancel its peers.
"""
sem = asyncio.Semaphore(_RECEIVER_CONCURRENCY)
async def guarded(receiver: Receiver) -> dict[str, Any]:
async with sem:
try:
return await send_one(receiver)
except Exception as exc: # noqa: BLE001
_LOGGER.error("Receiver send raised: %s", redact_exc(exc))
return {"success": False, "error": redact_exc(exc)}
return await asyncio.gather(*(guarded(r) for r in receivers))
@staticmethod @staticmethod
def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]: def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
"""Aggregate broadcast results into a single result dict.""" """Aggregate per-receiver results into a single target-level result.
Preserves the per-receiver detail under ``receivers`` so a caller
can see exactly which receivers failed, instead of getting only
the first error.
"""
if not results:
return {"success": False, "error": "No receivers configured"}
successes = sum(1 for r in results if r.get("success")) successes = sum(1 for r in results if r.get("success"))
if successes == len(results) and results: failures = len(results) - successes
return {"success": True, "receivers": len(results)} out: dict[str, Any] = {
elif successes > 0: "success": successes > 0,
return {"success": True, "receivers": len(results), "partial_failures": len(results) - successes} "receivers": len(results),
elif results: "successes": successes,
return results[0] "failures": failures,
return {"success": False, "error": "No receivers configured"} "results": results,
}
if failures:
out["errors"] = [
r.get("error") for r in results if not r.get("success")
]
if successes == 0:
# Surface the first error at the top level for back-compat
# with callers that only check ``error``.
out["error"] = results[0].get("error", "All receivers failed")
return out
# ----------------------------------------------------------------------
# Provider registry — replaces the if/elif chain so adding a provider
# means just registering it here, not editing dispatch logic.
# ----------------------------------------------------------------------
_PROVIDER_HANDLERS: dict[str, _SendMethod] = {
"telegram": NotificationDispatcher._send_telegram,
"webhook": NotificationDispatcher._send_webhook,
"email": NotificationDispatcher._send_email,
"discord": NotificationDispatcher._send_discord,
"slack": NotificationDispatcher._send_slack,
"ntfy": NotificationDispatcher._send_ntfy,
"matrix": NotificationDispatcher._send_matrix,
}
def register_provider(name: str, handler: _SendMethod) -> None:
"""Register a new dispatcher provider at runtime.
Allows out-of-tree providers to extend the dispatcher without
forking. The handler must follow the
``async (dispatcher, target, default_message, event) -> dict`` shape.
"""
_PROVIDER_HANDLERS[name] = handler
@@ -2,14 +2,32 @@
from __future__ import annotations from __future__ import annotations
import html
import logging import logging
import re
import ssl
from dataclasses import dataclass from dataclasses import dataclass
from email.mime.multipart import MIMEMultipart from email.headerregistry import Address
from email.mime.text import MIMEText from email.message import EmailMessage
from typing import Any from typing import Any, Final, Literal
try: # Optional dependency — fail at first send rather than at import.
import aiosmtplib
from aiosmtplib import SMTPException
except ImportError: # pragma: no cover
aiosmtplib = None # type: ignore[assignment]
class SMTPException(Exception): # type: ignore[no-redef]
pass
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT_S: Final = 30.0
_TlsMode = Literal["auto", "implicit", "starttls", "none"]
# RFC 5322 lite: catches the obvious-bad addresses ("foo bar", "no-at",
# embedded CRLF) without pretending to fully validate addresses.
_EMAIL_RE: Final = re.compile(r"^[^\s@\r\n,;<>]+@[^\s@\r\n,;<>]+\.[^\s@\r\n,;<>]+$")
@dataclass @dataclass
class SmtpConfig: class SmtpConfig:
@@ -22,6 +40,55 @@ class SmtpConfig:
from_address: str = "" from_address: str = ""
from_name: str = "Notify Bridge" from_name: str = "Notify Bridge"
use_tls: bool = True use_tls: bool = True
# Explicit TLS mode. ``auto`` (back-compat) infers from ``use_tls`` and
# ``port``: 465 → implicit; 587 with use_tls=False → starttls; 25 → none.
tls_mode: _TlsMode = "auto"
timeout_s: float = _DEFAULT_TIMEOUT_S
def _strip_header(value: str) -> str:
"""Reject CRLF and bare CR/LF in header-bound strings.
SMTP header injection turns user-controlled subject/name strings into
arbitrary headers (``\\r\\nBcc: attacker@x``). The Python stdlib
accepts CRLF when followed by SP/HT (header folding), so explicit
sanitization is required even though :class:`EmailMessage` does some
validation of its own.
"""
return re.sub(r"[\r\n]+", " ", str(value or "")).strip()
def _validate_email(addr: str) -> str:
addr = _strip_header(addr)
if not addr:
raise ValueError("email address is empty")
if not _EMAIL_RE.match(addr):
raise ValueError("email address is invalid")
return addr
def _resolve_tls(cfg: SmtpConfig) -> tuple[bool, bool]:
"""Resolve ``(use_tls, start_tls)`` flags from the config.
``tls_mode`` overrides ``use_tls``/port heuristics when provided.
"""
mode = cfg.tls_mode
if mode == "implicit":
return True, False
if mode == "starttls":
return False, True
if mode == "none":
return False, False
# auto — preserve the historical "use_tls bool + port heuristic" behavior
# but make the path explicit.
if cfg.use_tls:
return True, False
return False, cfg.port != 25
def _to_html(text: str) -> str:
"""Convert plain text to a minimal HTML body, escaped for safety."""
return "<html><body><pre>" + html.escape(text) + "</pre></body></html>"
class EmailClient: class EmailClient:
@@ -30,30 +97,39 @@ class EmailClient:
def __init__(self, smtp_config: SmtpConfig) -> None: def __init__(self, smtp_config: SmtpConfig) -> None:
self._config = smtp_config self._config = smtp_config
@staticmethod
def _ssl_context() -> ssl.SSLContext:
# Explicit context so the TLS posture is auditable; aiosmtplib
# defaults look correct today but past regressions (and downstream
# repackaging) make implicit reliance fragile.
return ssl.create_default_context()
async def verify_connection(self) -> dict[str, Any]: async def verify_connection(self) -> dict[str, Any]:
"""Test SMTP connection and authentication without sending an email.""" """Test SMTP connection and authentication without sending an email."""
try: if aiosmtplib is None:
import aiosmtplib
except ImportError:
return {"success": False, "error": "aiosmtplib not installed"} return {"success": False, "error": "aiosmtplib not installed"}
cfg = self._config cfg = self._config
if not cfg.host: if not cfg.host:
return {"success": False, "error": "SMTP host not configured"} return {"success": False, "error": "SMTP host not configured"}
use_tls, start_tls = _resolve_tls(cfg)
try: try:
smtp = aiosmtplib.SMTP( smtp = aiosmtplib.SMTP(
hostname=cfg.host, hostname=cfg.host,
port=cfg.port, port=cfg.port,
use_tls=cfg.use_tls, use_tls=use_tls,
start_tls=not cfg.use_tls and cfg.port != 25, start_tls=start_tls,
tls_context=self._ssl_context(),
timeout=cfg.timeout_s,
validate_certs=True,
) )
await smtp.connect() await smtp.connect()
if cfg.username and cfg.password: if cfg.username and cfg.password:
await smtp.login(cfg.username, cfg.password) await smtp.login(cfg.username, cfg.password)
await smtp.quit() await smtp.quit()
return {"success": True} return {"success": True}
except Exception as e: except (SMTPException, OSError) as e:
_LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e) _LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -65,27 +141,52 @@ class EmailClient:
body_html: str | None = None, body_html: str | None = None,
to_name: str = "", to_name: str = "",
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Send an email. Returns {"success": True} or {"success": False, "error": "..."}.""" """Send an email.
try:
import aiosmtplib Returns ``{"success": True}`` or ``{"success": False, "error": "..."}``.
except ImportError:
``body_html`` is treated as already-safe markup. Pass ``None`` to
derive a safe HTML alternative from ``body_text`` automatically.
"""
if aiosmtplib is None:
return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"} return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"}
cfg = self._config cfg = self._config
if not cfg.host or not cfg.from_address: if not cfg.host or not cfg.from_address:
return {"success": False, "error": "SMTP not configured (missing host or from_address)"} return {"success": False, "error": "SMTP not configured (missing host or from_address)"}
# Build email message try:
msg = MIMEMultipart("alternative") to_addr = _validate_email(to_email)
msg["From"] = f"{cfg.from_name} <{cfg.from_address}>" if cfg.from_name else cfg.from_address from_addr = _validate_email(cfg.from_address)
msg["To"] = f"{to_name} <{to_email}>" if to_name else to_email except ValueError as exc:
msg["Subject"] = subject return {"success": False, "error": f"Invalid email address: {exc}"}
msg.attach(MIMEText(body_text, "plain", "utf-8")) # EmailMessage with structured Address objects rejects CRLF and
if body_html: # framework-folds long headers safely. We still strip first because
msg.attach(MIMEText(body_html, "html", "utf-8")) # EmailMessage's display-name slot is a pure string.
msg = EmailMessage()
from_display = _strip_header(cfg.from_name) or ""
to_display = _strip_header(to_name) or ""
try:
from_user, _, from_domain = from_addr.partition("@")
to_user, _, to_domain = to_addr.partition("@")
msg["From"] = Address(from_display, from_user, from_domain) if from_display else from_addr
msg["To"] = Address(to_display, to_user, to_domain) if to_display else to_addr
except ValueError as exc:
return {"success": False, "error": f"Invalid email address: {exc}"}
msg["Subject"] = _strip_header(subject)
msg.set_content(body_text or "", subtype="plain", charset="utf-8")
# If the caller provided HTML explicitly, honor it; otherwise build a
# safe escaped version so a stray "<" in the rendered template can't
# break the markup.
msg.add_alternative(
body_html if body_html is not None else _to_html(body_text or ""),
subtype="html",
charset="utf-8",
)
use_tls, start_tls = _resolve_tls(cfg)
try: try:
await aiosmtplib.send( await aiosmtplib.send(
msg, msg,
@@ -93,11 +194,14 @@ class EmailClient:
port=cfg.port, port=cfg.port,
username=cfg.username or None, username=cfg.username or None,
password=cfg.password or None, password=cfg.password or None,
use_tls=cfg.use_tls, use_tls=use_tls,
start_tls=not cfg.use_tls and cfg.port != 25, start_tls=start_tls,
tls_context=self._ssl_context(),
timeout=cfg.timeout_s,
validate_certs=True,
) )
_LOGGER.info("Email sent to %s", to_email) _LOGGER.info("Email sent to %s", to_addr)
return {"success": True} return {"success": True}
except Exception as e: except (SMTPException, OSError) as e:
_LOGGER.error("Failed to send email to %s: %s", to_email, e) _LOGGER.error("Failed to send email to %s: %s", to_addr, e)
return {"success": False, "error": str(e)} return {"success": False, "error": str(e)}
@@ -0,0 +1,196 @@
"""Shared HTTP infrastructure for notification provider clients.
Slack/Discord/ntfy/Matrix/Webhook all follow the same pattern: build a
JSON payload, POST/PUT it, decode 200-range as success, decode 4xx/5xx
into a stable error dict, and retry transient 429/503 responses with a
capped ``Retry-After``. ``HttpProviderClient`` centralizes that pattern
so every provider gets the same SSRF guard, timeouts, secret-redacted
errors, and bounded retry policy by construction — adding a new
provider doesn't get to forget any one of them.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Final, Mapping
import aiohttp
from .redact import redact, redact_exc
from .ssrf import UnsafeURLError, avalidate_outbound_url
_LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
_MAX_RETRIES: Final = 3
_MAX_RETRY_AFTER_S: Final = 60.0
_RETRY_STATUSES: Final = frozenset({429, 503})
# Hop-by-hop / framing headers a caller must not be able to override via
# user-supplied target config. Letting them through enables request
# smuggling, host-header bypasses of WAFs, and cache poisoning.
_FORBIDDEN_HEADERS: Final = frozenset({
"host",
"content-length",
"transfer-encoding",
"connection",
"keep-alive",
"te",
"upgrade",
"expect",
"proxy-authorization",
"proxy-connection",
})
def safe_headers(headers: Mapping[str, str] | None) -> dict[str, str]:
"""Return a copy of ``headers`` with hop-by-hop/forbidden names dropped
and CRLF-bearing values rejected.
A target config that lets a user inject ``"X-Foo": "bar\\r\\nHost: evil"``
can perform request smuggling depending on aiohttp's framing. We strip
those values rather than letting them reach the wire.
"""
if not headers:
return {}
safe: dict[str, str] = {}
for raw_name, raw_value in headers.items():
name = str(raw_name).strip()
if not name or name.lower() in _FORBIDDEN_HEADERS:
continue
if any(c in name for c in "\r\n:"):
continue
value = str(raw_value)
if "\r" in value or "\n" in value:
continue
safe[name] = value
return safe
def make_error(message: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]:
"""Build a stable failure dict shape used by every provider client."""
err: dict[str, Any] = {"success": False, "error": redact(message)}
if status is not None:
err["status_code"] = status
if body:
err["body"] = redact(body)[:200]
return err
def make_success(**extra: Any) -> dict[str, Any]:
"""Build a stable success dict shape used by every provider client."""
out: dict[str, Any] = {"success": True}
out.update(extra)
return out
def _retry_after_seconds(headers: Mapping[str, str], cap_s: float) -> float:
raw = headers.get("Retry-After") or headers.get("retry-after") or "2"
try:
seconds = float(raw)
except (TypeError, ValueError):
seconds = 2.0
return max(0.0, min(seconds, cap_s))
class HttpProviderClient:
"""Base for JSON-over-HTTP notification providers.
Subclasses call :meth:`request` instead of using ``self._session``
directly. ``request`` runs the SSRF guard (skippable for known-safe
upstreams via ``ssrf_validate=False``), enforces a per-request
timeout, retries 429/503 with a capped ``Retry-After``, and turns
transport/HTTP errors into the canonical ``{"success": False, ...}``
shape with secrets redacted.
"""
_max_retries: int = _MAX_RETRIES
# Settable per-instance so tests / hostile-upstream tuning can
# tighten the cap. Reads of this attribute fall through to the
# class default when no instance value has been set.
_MAX_RETRY_AFTER: float = _MAX_RETRY_AFTER_S
def __init__(
self,
session: aiohttp.ClientSession,
*,
timeout: aiohttp.ClientTimeout | None = None,
provider_name: str = "http",
) -> None:
self._session = session
self._timeout = timeout or _DEFAULT_TIMEOUT
self._provider = provider_name
async def request(
self,
method: str,
url: str,
*,
json: Any = None,
headers: Mapping[str, str] | None = None,
ssrf_validate: bool = True,
retry_statuses: frozenset[int] = _RETRY_STATUSES,
) -> dict[str, Any]:
"""Send a single request with retry + redaction. Always returns a dict.
On 2xx returns ``{"success": True, "status_code": int, "json": ...
OR "body": str}``. On non-2xx returns the canonical error dict.
"""
if ssrf_validate:
try:
await avalidate_outbound_url(url)
except UnsafeURLError as err:
return make_error(f"Unsafe URL: {redact_exc(err)}")
outbound_headers: dict[str, str] = {"Content-Type": "application/json"}
outbound_headers.update(safe_headers(headers))
for attempt in range(1, self._max_retries + 1):
try:
async with self._session.request(
method,
url,
json=json,
headers=outbound_headers,
timeout=self._timeout,
allow_redirects=False,
) as resp:
if resp.status in retry_statuses and attempt < self._max_retries:
delay = _retry_after_seconds(resp.headers, self._MAX_RETRY_AFTER)
_LOGGER.warning(
"%s %s %s: HTTP %d, retrying after %.2fs (attempt %d/%d)",
self._provider, method, redact(url), resp.status,
delay, attempt, self._max_retries,
)
await resp.read() # drain body so connection can return to pool
await asyncio.sleep(delay)
continue
if 200 <= resp.status < 300:
try:
payload: Any = await resp.json(content_type=None)
except (aiohttp.ContentTypeError, ValueError):
payload = await resp.text()
return make_success(status_code=resp.status, json=payload)
body = await resp.text()
return make_error(
f"HTTP {resp.status}",
status=resp.status,
body=body,
)
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
# asyncio.CancelledError inherits from BaseException on
# 3.8+, so it is not caught here — good: cancellation
# must propagate.
if attempt < self._max_retries and isinstance(err, asyncio.TimeoutError):
_LOGGER.warning(
"%s %s %s: timeout, retrying (attempt %d/%d)",
self._provider, method, redact(url),
attempt, self._max_retries,
)
await asyncio.sleep(min(2 ** (attempt - 1), 5))
continue
return make_error(redact_exc(err))
# Retry budget exhausted on a retriable status.
return make_error("Rate limited (retries exhausted)")
@@ -2,22 +2,36 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import time import re
from typing import Any import uuid
from typing import Any, Final
from urllib.parse import quote
import aiohttp import aiohttp
from ..http_base import _MAX_RETRY_AFTER_S, safe_headers
from ..redact import redact, redact_exc
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Monotonically increasing transaction counter for idempotent sends # Matrix room IDs are ``!opaque:server.name`` per the spec. We also allow
_txn_counter = int(time.time() * 1000) # the ``#alias:server`` form because some callers may pass aliases. The
# pattern's purpose is to reject obvious path-injection (``/``, ``..``,
# control chars, query/fragment chars) before the value reaches a URL.
_ROOM_ID_RE: Final = re.compile(r"^[!#][^\x00-\x1f\s/?#]{1,255}:[A-Za-z0-9.\-:]{1,255}$")
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
_MAX_RETRIES: Final = 3
def _next_txn_id() -> str: def _validate_room_id(room_id: str) -> str:
global _txn_counter if not room_id:
_txn_counter += 1 raise ValueError("room_id is empty")
return str(_txn_counter) if not _ROOM_ID_RE.match(room_id):
raise ValueError("room_id format is invalid")
return room_id
class MatrixClient: class MatrixClient:
@@ -33,49 +47,67 @@ class MatrixClient:
self._homeserver = homeserver_url.rstrip("/") self._homeserver = homeserver_url.rstrip("/")
self._token = access_token self._token = access_token
@staticmethod
def _txn_id() -> str:
# uuid4 hex is collision-resistant across processes/restarts;
# eliminates the previous module-level counter race.
return uuid.uuid4().hex
async def send_message( async def send_message(
self, self,
room_id: str, room_id: str,
message: str, message: str,
html_message: str | None = None, html_message: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Send a text message to a Matrix room. """Send a text message to a Matrix room."""
try:
room_id = _validate_room_id(room_id)
except ValueError as exc:
return {"success": False, "error": f"Invalid room_id: {exc}"}
Args: encoded_room = quote(room_id, safe="")
room_id: Internal room ID (e.g. !abc:matrix.org) url = (
message: Plain text body f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}"
html_message: Optional HTML-formatted body f"/send/m.room.message/{self._txn_id()}"
""" )
if not room_id:
return {"success": False, "error": "Missing room_id"}
txn_id = _next_txn_id() body: dict[str, Any] = {"msgtype": "m.text", "body": message}
# URL-encode the room_id (! and : need encoding)
encoded_room = room_id.replace("!", "%21").replace(":", "%3A")
url = f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}"
body: dict[str, Any] = {
"msgtype": "m.text",
"body": message,
}
if html_message: if html_message:
body["format"] = "org.matrix.custom.html" body["format"] = "org.matrix.custom.html"
body["formatted_body"] = html_message body["formatted_body"] = html_message
headers = { headers = safe_headers({
"Authorization": f"Bearer {self._token}", "Authorization": f"Bearer {self._token}",
"Content-Type": "application/json", "Content-Type": "application/json",
} })
try: for attempt in range(1, _MAX_RETRIES + 1):
async with self._session.put( try:
url, json=body, headers=headers, allow_redirects=False, async with self._session.put(
) as resp: url, json=body, headers=headers,
if 200 <= resp.status < 300: timeout=_DEFAULT_TIMEOUT, allow_redirects=False,
return {"success": True} ) as resp:
resp_body = await resp.text() if 200 <= resp.status < 300:
if resp.status == 429: return {"success": True}
_LOGGER.warning("Matrix rate limited: %s", resp_body[:200]) resp_body = await resp.text()
return {"success": False, "error": f"HTTP {resp.status}: {resp_body[:200]}"} if resp.status == 429 and attempt < _MAX_RETRIES:
except aiohttp.ClientError as e: try:
return {"success": False, "error": str(e)} wait_s = float(resp.headers.get("Retry-After", "2"))
except (TypeError, ValueError):
wait_s = 2.0
wait_s = max(0.0, min(wait_s, _MAX_RETRY_AFTER_S))
_LOGGER.warning(
"Matrix rate limited, retrying after %.2fs (attempt %d/%d)",
wait_s, attempt, _MAX_RETRIES,
)
await asyncio.sleep(wait_s)
continue
return {
"success": False,
"error": f"HTTP {resp.status}: {redact(resp_body)[:200]}",
"status_code": resp.status,
}
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
return {"success": False, "error": redact_exc(e)}
return {"success": False, "error": "Rate limited (retries exhausted)"}
@@ -3,18 +3,33 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import Any, Final
import aiohttp import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_PRIORITY_MIN: Final = 1
_PRIORITY_MAX: Final = 5
_DEFAULT_PRIORITY: Final = 3
_MAX_TAGS: Final = 10
_MAX_TAG_LEN: Final = 64
class NtfyClient:
def _strip_crlf(value: str) -> str:
"""Remove CR/LF — ntfy's JSON path is safe today, but the same fields
are used by the header API; defensive sanitization here means a future
refactor can't accidentally re-introduce header injection."""
return value.replace("\r", " ").replace("\n", " ")
class NtfyClient(HttpProviderClient):
"""Sends push notifications via ntfy server.""" """Sends push notifications via ntfy server."""
def __init__(self, session: aiohttp.ClientSession) -> None: def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session super().__init__(session, provider_name="ntfy")
async def send( async def send(
self, self,
@@ -22,41 +37,48 @@ class NtfyClient:
topic: str, topic: str,
message: str, message: str,
title: str | None = None, title: str | None = None,
priority: int = 3, priority: int = _DEFAULT_PRIORITY,
tags: list[str] | None = None, tags: list[str] | None = None,
click_url: str | None = None, click_url: str | None = None,
auth_token: str | None = None, auth_token: str | None = None,
markdown: bool = True,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Send a push notification to an ntfy topic.""" """Send a push notification to an ntfy topic."""
if not server_url or not topic: if not server_url or not topic:
return {"success": False, "error": "Missing server_url or topic"} return {"success": False, "error": "Missing server_url or topic"}
url = f"{server_url.rstrip('/')}" topic = _strip_crlf(topic).strip()
if not topic:
return {"success": False, "error": "Topic is empty after sanitization"}
try:
priority_int = int(priority) if priority is not None else _DEFAULT_PRIORITY
except (TypeError, ValueError):
priority_int = _DEFAULT_PRIORITY
priority_int = max(_PRIORITY_MIN, min(priority_int, _PRIORITY_MAX))
payload: dict[str, Any] = { payload: dict[str, Any] = {
"topic": topic, "topic": topic,
"message": message, "message": message,
"markdown": True, "markdown": bool(markdown),
} }
if title: if title:
payload["title"] = title payload["title"] = _strip_crlf(title)
if priority != 3: if priority_int != _DEFAULT_PRIORITY:
payload["priority"] = priority payload["priority"] = priority_int
if tags: if tags:
payload["tags"] = tags cleaned = [
_strip_crlf(str(t))[:_MAX_TAG_LEN]
for t in tags[:_MAX_TAGS]
if t
]
if cleaned:
payload["tags"] = cleaned
if click_url: if click_url:
payload["click"] = click_url payload["click"] = _strip_crlf(click_url)
headers: dict[str, str] = {"Content-Type": "application/json"} headers: dict[str, str] = {}
if auth_token: if auth_token:
headers["Authorization"] = f"Bearer {auth_token}" headers["Authorization"] = f"Bearer {auth_token}"
try: return await self.request("POST", server_url.rstrip("/"), json=payload, headers=headers)
async with self._session.post(
url, json=payload, headers=headers, allow_redirects=False,
) as resp:
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
@@ -2,47 +2,88 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import copy
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any, Final
from notify_bridge_core.storage import StorageBackend from notify_bridge_core.storage import StorageBackend
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Bound on queue length. Without a cap, a misconfigured quiet-hour
# window plus high event throughput grows the persisted file unboundedly
# and every enqueue rewrites the whole file (O(n²) total writes). When
# the cap is hit we drop the oldest entry (FIFO) so the most recent
# events still reach the recipient when the window opens.
DEFAULT_MAX_QUEUE_SIZE: Final = 1000
class NotificationQueue: class NotificationQueue:
"""Persistent queue for notifications deferred during quiet hours.""" """Persistent queue for notifications deferred during quiet hours."""
def __init__(self, backend: StorageBackend) -> None: def __init__(
self,
backend: StorageBackend,
*,
max_size: int = DEFAULT_MAX_QUEUE_SIZE,
) -> None:
self._backend = backend self._backend = backend
self._data: dict[str, Any] | None = None self._data: dict[str, Any] | None = None
self._max_size = max_size
# Coordinates load / enqueue / clear / remove so a write-while-load
# race can't leave the in-memory copy out of sync with disk and so
# bulk operations don't interleave their reads-then-writes.
self._lock = asyncio.Lock()
@staticmethod
def _ensure_schema(data: Any) -> dict[str, Any]:
if not isinstance(data, dict) or not isinstance(data.get("queue"), list):
return {"queue": []}
return data
async def async_load(self) -> None: async def async_load(self) -> None:
self._data = await self._backend.load() or {"queue": []} async with self._lock:
raw = await self._backend.load()
self._data = self._ensure_schema(raw)
async def async_enqueue(self, notification_params: dict[str, Any]) -> None: async def async_enqueue(self, notification_params: dict[str, Any]) -> None:
if self._data is None: async with self._lock:
self._data = {"queue": []} if self._data is None:
self._data["queue"].append({ self._data = {"queue": []}
"params": notification_params, queue: list[dict[str, Any]] = self._data["queue"]
"queued_at": datetime.now(timezone.utc).isoformat(), queue.append({
}) "params": notification_params,
await self._backend.save(self._data) "queued_at": datetime.now(timezone.utc).isoformat(),
})
if self._max_size > 0 and len(queue) > self._max_size:
# Drop oldest (FIFO) so a new event can still land.
drop = len(queue) - self._max_size
_LOGGER.warning(
"NotificationQueue: dropping %d oldest entries (cap=%d)",
drop, self._max_size,
)
del queue[:drop]
await self._backend.save(self._data)
def get_all(self) -> list[dict[str, Any]]: def get_all(self) -> list[dict[str, Any]]:
if not self._data: if not self._data:
return [] return []
return list(self._data.get("queue", [])) # Deep copy so callers can iterate / mutate without corrupting the
# in-memory queue. The cost is bounded by ``max_size``.
return copy.deepcopy(list(self._data.get("queue", [])))
def has_pending(self) -> bool: def has_pending(self) -> bool:
return bool(self._data and self._data.get("queue")) return bool(self._data and self._data.get("queue"))
async def async_clear(self) -> None: async def async_clear(self) -> None:
if self._data: async with self._lock:
self._data["queue"] = [] if self._data:
await self._backend.save(self._data) self._data["queue"] = []
await self._backend.save(self._data)
async def async_remove(self) -> None: async def async_remove(self) -> None:
await self._backend.remove() async with self._lock:
self._data = None await self._backend.remove()
self._data = None
@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any, Callable
@dataclass @dataclass
@@ -70,51 +70,64 @@ class MatrixReceiver(Receiver):
room_id: str = "" room_id: str = ""
_ReceiverFactory = Callable[[str, dict[str, Any]], Receiver]
def _coerce_int(value: Any, default: int) -> int:
try:
return int(value)
except (TypeError, ValueError):
return default
_RECEIVER_FACTORIES: dict[str, _ReceiverFactory] = {
"telegram": lambda locale, config: TelegramReceiver(
locale=locale, config=config, chat_id=str(config.get("chat_id", "")),
),
"webhook": lambda locale, config: WebhookReceiver(
locale=locale, config=config,
url=str(config.get("url", "")),
headers=dict(config.get("headers", {}) or {}),
),
"email": lambda locale, config: EmailReceiver(
locale=locale, config=config,
email=str(config.get("email", "")),
name=str(config.get("name", "")),
),
"discord": lambda locale, config: DiscordReceiver(
locale=locale, config=config,
webhook_url=str(config.get("webhook_url", "")),
),
"slack": lambda locale, config: SlackReceiver(
locale=locale, config=config,
webhook_url=str(config.get("webhook_url", "")),
),
"ntfy": lambda locale, config: NtfyReceiver(
locale=locale, config=config,
topic=str(config.get("topic", "")),
priority=_coerce_int(config.get("priority"), 3),
),
"matrix": lambda locale, config: MatrixReceiver(
locale=locale, config=config,
room_id=str(config.get("room_id", "")),
),
}
def register_receiver_factory(target_type: str, factory: _ReceiverFactory) -> None:
"""Register a receiver factory for an out-of-tree target type."""
_RECEIVER_FACTORIES[target_type] = factory
def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver: def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver:
"""Factory: build typed Receiver from target type and config dict.""" """Factory: build typed Receiver from target type and config dict.
if target_type == "telegram":
return TelegramReceiver( Falls back to a base ``Receiver`` for unknown target types so callers
locale=locale, that handle types defensively still receive a usable object — but the
config=config, dispatcher rejects them with ``"Unknown target type"`` so a typo can't
chat_id=str(config.get("chat_id", "")), silently route to nowhere.
) """
if target_type == "webhook": factory = _RECEIVER_FACTORIES.get(target_type)
return WebhookReceiver( if factory is None:
locale=locale, return Receiver(locale=locale, config=config)
config=config, return factory(locale, config)
url=config.get("url", ""),
headers=config.get("headers", {}),
)
if target_type == "email":
return EmailReceiver(
locale=locale,
config=config,
email=config.get("email", ""),
name=config.get("name", ""),
)
if target_type == "discord":
return DiscordReceiver(
locale=locale,
config=config,
webhook_url=config.get("webhook_url", ""),
)
if target_type == "slack":
return SlackReceiver(
locale=locale,
config=config,
webhook_url=config.get("webhook_url", ""),
)
if target_type == "ntfy":
return NtfyReceiver(
locale=locale,
config=config,
topic=config.get("topic", ""),
priority=config.get("priority", 3),
)
if target_type == "matrix":
return MatrixReceiver(
locale=locale,
config=config,
room_id=config.get("room_id", ""),
)
return Receiver(locale=locale, config=config)
@@ -0,0 +1,64 @@
"""Secret-redaction helpers for log lines and error strings.
Notification clients embed secrets in URLs (Telegram bot tokens) and
Authorization headers (Matrix access tokens, ntfy bearer tokens). When
those secrets surface in ``aiohttp.ClientError.__str__``, response
bodies, or operator-visible error fields, they leak into logs and into
the per-target result dict that callers may forward upstream. ``redact``
returns a defanged copy safe for both contexts.
"""
from __future__ import annotations
import re
from typing import Final
# api.telegram.org/bot<digits>:<token>/<method>
_TELEGRAM_BOT_TOKEN_RE: Final = re.compile(
r"(api\.telegram\.org/bot)\d+:[A-Za-z0-9_-]+", re.IGNORECASE,
)
# Authorization: Bearer <token> (header form, case-insensitive)
_BEARER_RE: Final = re.compile(r"(Bearer\s+)[A-Za-z0-9._\-+/=]+", re.IGNORECASE)
# Discord webhook: /api/webhooks/<id>/<token>
_DISCORD_WEBHOOK_RE: Final = re.compile(
r"(discord(?:app)?\.com/api/webhooks/\d+/)[A-Za-z0-9_-]+",
re.IGNORECASE,
)
# Slack webhook path: /services/T.../B.../<token>
_SLACK_WEBHOOK_RE: Final = re.compile(
r"(hooks\.slack\.com/services/[A-Z0-9]+/[A-Z0-9]+/)[A-Za-z0-9]+",
re.IGNORECASE,
)
# URL userinfo: scheme://user:password@host
_URL_USERINFO_RE: Final = re.compile(
r"([a-z][a-z0-9+\-.]*://)[^/@\s]+:[^/@\s]+@",
re.IGNORECASE,
)
# Common token query parameters
_QUERY_TOKEN_RE: Final = re.compile(
r"([?&](?:token|access_token|api_key|key|secret|password)=)[^&\s]+",
re.IGNORECASE,
)
def redact(text: str) -> str:
"""Return ``text`` with known secret patterns replaced by ``***``.
Idempotent and safe to call on already-redacted strings. Always
returns a ``str``; non-strings are coerced via ``str()`` so callers
can pass exception instances directly.
"""
if not isinstance(text, str):
text = str(text)
text = _TELEGRAM_BOT_TOKEN_RE.sub(r"\1***", text)
text = _DISCORD_WEBHOOK_RE.sub(r"\1***", text)
text = _SLACK_WEBHOOK_RE.sub(r"\1***", text)
text = _BEARER_RE.sub(r"\1***", text)
text = _URL_USERINFO_RE.sub(r"\1***@", text)
text = _QUERY_TOKEN_RE.sub(r"\1***", text)
return text
def redact_exc(err: BaseException) -> str:
"""Redact-and-stringify an exception. Convenience for error fields."""
return redact(str(err))
@@ -7,14 +7,16 @@ from typing import Any
import aiohttp import aiohttp
from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class SlackClient: class SlackClient(HttpProviderClient):
"""Sends messages via Slack incoming webhook URLs.""" """Sends messages via Slack incoming webhook URLs."""
def __init__(self, session: aiohttp.ClientSession) -> None: def __init__(self, session: aiohttp.ClientSession) -> None:
self._session = session super().__init__(session, provider_name="slack")
async def send( async def send(
self, self,
@@ -33,19 +35,4 @@ class SlackClient:
if icon_emoji: if icon_emoji:
payload["icon_emoji"] = icon_emoji payload["icon_emoji"] = icon_emoji
try: return await self.request("POST", webhook_url, json=payload)
async with self._session.post(
webhook_url,
json=payload,
headers={"Content-Type": "application/json"},
allow_redirects=False,
) as resp:
if resp.status == 429:
_LOGGER.warning("Slack rate limited")
return {"success": False, "error": "Rate limited by Slack"}
if 200 <= resp.status < 300:
return {"success": True}
body = await resp.text()
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
except aiohttp.ClientError as e:
return {"success": False, "error": str(e)}
@@ -1,10 +1,22 @@
"""Outbound URL validation to mitigate SSRF attacks. """Outbound URL validation to mitigate SSRF attacks.
User-controlled URLs (provider `url`, webhook target `url`, shared-link User-controlled URLs (provider ``url``, webhook target ``url``,
base URLs, image downloads) must be validated before any HTTP request is shared-link base URLs, image downloads) must be validated before any
issued. This module rejects schemes other than http/https and blocks HTTP request is issued. This module rejects schemes other than
destinations that resolve to private, loopback, link-local, or unspecified http/https and blocks destinations that resolve to private, loopback,
address ranges. link-local, unspecified, CGNAT (100.64.0.0/10), or IPv4-mapped IPv6
ranges.
DNS rebinding mitigation
~~~~~~~~~~~~~~~~~~~~~~~~
``avalidate_outbound_url`` returns the original URL on success, but
also returns the resolved IP it actually validated. Callers that pass
the validated URL straight into ``aiohttp`` are vulnerable to a
DNS-rebinding attack: the validator's ``getaddrinfo`` returns a public
IP; aiohttp's connect-time resolution returns ``127.0.0.1``. To close
that gap, use :func:`build_ssrf_safe_session` (or
:class:`PinnedResolver`) so the resolved IP from the validation step is
the one aiohttp connects to.
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
development against localhost services. development against localhost services.
@@ -17,12 +29,20 @@ import ipaddress
import logging import logging
import os import os
import socket import socket
from dataclasses import dataclass
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1" _ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
_ALLOWED_SCHEMES = {"http", "https"} _ALLOWED_SCHEMES = frozenset({"http", "https"})
# Carrier-grade NAT range. Not in stdlib's ``is_private``; an attacker
# pointing a domain at a CGNAT IP could reach the operator's ISP-side
# routing infrastructure. RFC 6598.
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
_LOGGER.warning( _LOGGER.warning(
@@ -36,7 +56,29 @@ class UnsafeURLError(ValueError):
"""Raised when a URL targets a disallowed network destination.""" """Raised when a URL targets a disallowed network destination."""
@dataclass(frozen=True)
class ValidatedURL:
"""Result of validating an outbound URL.
Attributes:
url: The original URL string (unchanged).
host: Hostname extracted from the URL (lower-cased, IDN-encoded).
ip: Resolved IP address that passed the block-range check, as a
string. Pass to :class:`PinnedResolver` to defeat DNS
rebinding by reusing this exact IP at connect time.
"""
url: str
host: str
ip: str
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
# An IPv4-mapped IPv6 like ``::ffff:127.0.0.1`` is NOT considered
# ``is_private`` etc. by stdlib — the v4 view holds those flags. So
# we unwrap before checking.
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
ip = ip.ipv4_mapped
return ( return (
ip.is_private ip.is_private
or ip.is_loopback or ip.is_loopback
@@ -44,22 +86,54 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
or ip.is_multicast or ip.is_multicast
or ip.is_reserved or ip.is_reserved
or ip.is_unspecified or ip.is_unspecified
or (isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK)
) )
def _safe_host_repr(host: str) -> str:
"""Return ``host`` shortened/escaped for safe inclusion in error text."""
h = host[:64].replace("\r", "").replace("\n", "")
return h
def _normalize_host(parsed_host: str) -> str:
"""Normalize a hostname: lowercase, strip trailing dot, IDN-encode."""
host = parsed_host.lower()
if host.endswith("."):
host = host[:-1]
# Strip IPv6 zone id ("fe80::1%eth0") — must not reach the resolver.
if "%" in host:
host = host.split("%", 1)[0]
# IDN-encode unicode hostnames so we don't downgrade to confusables
# in any later log/output and so getaddrinfo gets the ascii form.
try:
if any(ord(c) > 127 for c in host):
host = host.encode("idna").decode("ascii")
except UnicodeError:
# Caller will fail on resolution; leave as-is so the error path
# surfaces a "DNS resolution failed" rather than a stack trace.
pass
return host
def _check_scheme_host(url: str) -> tuple[str, str]: def _check_scheme_host(url: str) -> tuple[str, str]:
if not isinstance(url, str) or not url: if not isinstance(url, str) or not url:
raise UnsafeURLError("URL is empty") raise UnsafeURLError("URL is empty")
parsed = urlparse(url) parsed = urlparse(url)
if parsed.scheme not in _ALLOWED_SCHEMES: scheme = parsed.scheme.lower()
raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed") if scheme not in _ALLOWED_SCHEMES:
raise UnsafeURLError(f"Scheme '{scheme[:16]}' not allowed")
host = parsed.hostname host = parsed.hostname
if not host: if not host:
raise UnsafeURLError("URL has no host") raise UnsafeURLError("URL has no host")
return parsed.scheme, host return scheme, _normalize_host(host)
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None: def _select_addresses(
host: str, infos: list[tuple],
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
"""Return parsed, non-blocked IPs from ``getaddrinfo`` results."""
addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
for info in infos: for info in infos:
sockaddr = info[4] sockaddr = info[4]
try: try:
@@ -67,64 +141,143 @@ def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
except ValueError: except ValueError:
continue continue
if _is_blocked_ip(ip): if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}") raise UnsafeURLError(
f"Host {_safe_host_repr(host)} resolves to blocked address {ip}"
)
addrs.append(ip)
if not addrs:
raise UnsafeURLError(f"Host {_safe_host_repr(host)} has no usable address")
return addrs
def validate_outbound_url(url: str) -> str: def validate_outbound_url(url: str) -> str:
"""Validate ``url`` is safe to fetch; returns the URL on success. """Validate ``url`` is safe to fetch; returns the URL on success.
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP .. deprecated::
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``) Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
private addresses are permitted but the scheme check still applies. :func:`avalidate_outbound_url` from async code paths so the
event loop isn't blocked, and use :func:`build_ssrf_safe_session`
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer to defeat DNS rebinding.
:func:`avalidate_outbound_url` from async code paths.
""" """
_, host = _check_scheme_host(url) _, host = _check_scheme_host(url)
if _ALLOW_PRIVATE: if _ALLOW_PRIVATE:
return url return url
# Literal IP host
try: try:
ip = ipaddress.ip_address(host) ip = ipaddress.ip_address(host)
if _is_blocked_ip(ip): if _is_blocked_ip(ip):
raise UnsafeURLError(f"Host {host} is in a blocked range") raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
return url return url
except ValueError: except ValueError:
pass pass
try: try:
infos = socket.getaddrinfo(host, None) infos = socket.getaddrinfo(host, None)
except socket.gaierror as exc: except (socket.gaierror, UnicodeError, OSError) as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc # ``UnicodeError`` covers IDNA failures (labels >63 chars, malformed
_check_resolved_addresses(host, infos) # unicode) which getaddrinfo surfaces as encoding errors rather than
# gaierror. ``OSError`` covers transient resolver failures on some
# platforms.
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
_select_addresses(host, infos)
return url return url
async def avalidate_outbound_url(url: str) -> str: async def avalidate_outbound_url(url: str) -> str:
"""Async variant that resolves DNS via the running loop's resolver. """Async variant — returns the URL on success.
Use this from ``async def`` code paths to avoid blocking the event For DNS-rebinding-safe usage, prefer :func:`avalidate_outbound_url_full`
loop on DNS lookups. which also returns the resolved IP for connect-time pinning.
"""
result = await avalidate_outbound_url_full(url)
return result.url
async def avalidate_outbound_url_full(url: str) -> ValidatedURL:
"""Validate ``url`` and return a :class:`ValidatedURL` on success.
The returned ``ip`` field is the IP that passed the block-range
check. Pair with :class:`PinnedResolver` so aiohttp connects to that
exact IP and a malicious DNS server can't swap in a private address
after validation.
""" """
_, host = _check_scheme_host(url) _, host = _check_scheme_host(url)
if _ALLOW_PRIVATE: if _ALLOW_PRIVATE:
return url # In dev mode we still resolve to give a usable IP, but we don't
# gate on the result.
try:
ip = str(ipaddress.ip_address(host))
except ValueError:
try:
loop = asyncio.get_running_loop()
infos = await loop.getaddrinfo(host, None)
ip = infos[0][4][0] if infos else host
except (socket.gaierror, OSError):
ip = host
return ValidatedURL(url=url, host=host, ip=ip)
# Literal IP host
try: try:
ip = ipaddress.ip_address(host) ip_obj = ipaddress.ip_address(host)
if _is_blocked_ip(ip): if _is_blocked_ip(ip_obj):
raise UnsafeURLError(f"Host {host} is in a blocked range") raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
return url return ValidatedURL(url=url, host=host, ip=str(ip_obj))
except ValueError: except ValueError:
pass pass
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
try: try:
infos = await loop.getaddrinfo(host, None) infos = await loop.getaddrinfo(host, None)
except socket.gaierror as exc: except (socket.gaierror, UnicodeError, OSError) as exc:
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
_check_resolved_addresses(host, infos) addrs = _select_addresses(host, infos)
return url return ValidatedURL(url=url, host=host, ip=str(addrs[0]))
class PinnedResolver(aiohttp.abc.AbstractResolver):
"""aiohttp resolver that returns a fixed (host, ip) mapping.
Used to pin the resolved IP from :func:`avalidate_outbound_url_full`
so aiohttp's connect-time resolution can't be tricked by DNS
rebinding into using a different IP than the one we validated.
Falls back to :class:`aiohttp.AsyncResolver` (or default) for any
host not explicitly pinned, so a single resolver instance can be
reused across multiple validated URLs.
"""
def __init__(self, mapping: dict[str, str] | None = None) -> None:
self._map: dict[str, str] = dict(mapping or {})
self._fallback: aiohttp.abc.AbstractResolver | None = None
def pin(self, host: str, ip: str) -> None:
self._map[host.lower()] = ip
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET,
) -> list[dict]:
ip = self._map.get(host.lower())
if ip is not None:
try:
ip_obj = ipaddress.ip_address(ip)
except ValueError:
ip_obj = None
if ip_obj is not None:
fam = socket.AF_INET6 if ip_obj.version == 6 else socket.AF_INET
return [{
"hostname": host,
"host": ip,
"port": port,
"family": fam,
"proto": 0,
"flags": socket.AI_NUMERICHOST,
}]
if self._fallback is None:
self._fallback = aiohttp.ThreadedResolver()
return await self._fallback.resolve(host, port, family)
async def close(self) -> None:
if self._fallback is not None:
await self._fallback.close()
@@ -2,16 +2,29 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any, Final
from notify_bridge_core.storage import StorageBackend from notify_bridge_core.storage import StorageBackend
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_TELEGRAM_CACHE_TTL = 48 * 60 * 60 DEFAULT_TELEGRAM_CACHE_TTL: Final = 48 * 60 * 60
DEFAULT_MAX_ENTRIES = 5000 DEFAULT_MAX_ENTRIES: Final = 5000
def _parse_iso(value: str | None) -> datetime | None:
"""Parse an ISO-8601 timestamp tolerantly. Returns ``None`` on failure."""
if not value or not isinstance(value, str):
return None
try:
# Python <3.11 doesn't accept "Z"; normalize to +00:00.
v = value.replace("Z", "+00:00") if value.endswith("Z") else value
return datetime.fromisoformat(v)
except ValueError:
return None
class TelegramFileCache: class TelegramFileCache:
@@ -25,7 +38,17 @@ class TelegramFileCache:
Intended for content-addressable assets (e.g. Immich) where re-uploads Intended for content-addressable assets (e.g. Immich) where re-uploads
should be triggered by visual change, not elapsed time. should be triggered by visual change, not elapsed time.
``max_entries`` always applies as an LRU size cap (by ``cached_at``). ``max_entries`` always applies as a FIFO size cap (oldest-cached first).
Concurrency
~~~~~~~~~~~
All mutators take an internal ``asyncio.Lock`` so concurrent
media-group sends can't interleave a read-time invalidation with a
bulk write and corrupt the underlying dict (``RuntimeError:
dictionary changed size during iteration``) or lose just-written
entries. Reads do not take the lock — they are O(1) dict lookups —
but ``get`` uses a snapshot reference so it cannot mutate the data
structure under another task.
""" """
def __init__( def __init__(
@@ -40,35 +63,40 @@ class TelegramFileCache:
self._ttl_seconds = ttl_seconds self._ttl_seconds = ttl_seconds
self._use_thumbhash = use_thumbhash self._use_thumbhash = use_thumbhash
self._max_entries = max_entries self._max_entries = max_entries
self._lock = asyncio.Lock()
async def async_load(self) -> None: async def async_load(self) -> None:
self._data = await self._backend.load() or {"files": {}} async with self._lock:
await self._cleanup_expired() self._data = await self._backend.load() or {"files": {}}
await self._cleanup_expired_locked()
async def _cleanup_expired(self) -> None: async def _cleanup_expired_locked(self) -> None:
"""Caller must hold ``self._lock``."""
if not self._data or "files" not in self._data: if not self._data or "files" not in self._data:
return return
files = self._data["files"] files: dict[str, dict[str, Any]] = self._data["files"]
changed = False changed = False
# TTL sweep — only when TTL validation is active (i.e. no thumbhash
# mode and a positive TTL). In thumbhash mode we rely entirely on
# content validation; in "TTL disabled" mode (ttl_seconds <= 0) we
# cache forever, subject only to the size cap.
if not self._use_thumbhash and self._ttl_seconds > 0: if not self._use_thumbhash and self._ttl_seconds > 0:
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
expired = [ expired: list[str] = []
url for url, entry in files.items() for url, entry in list(files.items()):
if entry.get("cached_at") and cached_at = _parse_iso(entry.get("cached_at"))
(now - datetime.fromisoformat(entry["cached_at"])).total_seconds() > self._ttl_seconds if cached_at is None:
] continue
if cached_at.tzinfo is None:
cached_at = cached_at.replace(tzinfo=timezone.utc)
if (now - cached_at).total_seconds() > self._ttl_seconds:
expired.append(url)
for key in expired: for key in expired:
del files[key] del files[key]
changed = True changed = True
# LRU cap — always enforced. Evicts oldest-cached entries first.
if self._max_entries > 0 and len(files) > self._max_entries: if self._max_entries > 0 and len(files) > self._max_entries:
sorted_keys = sorted(files, key=lambda k: files[k].get("cached_at", "")) sorted_keys = sorted(
files,
key=lambda k: _parse_iso(files[k].get("cached_at")) or datetime.min.replace(tzinfo=timezone.utc),
)
for key in sorted_keys[: len(files) - self._max_entries]: for key in sorted_keys[: len(files) - self._max_entries]:
del files[key] del files[key]
changed = True changed = True
@@ -80,7 +108,10 @@ class TelegramFileCache:
if not self._data or "files" not in self._data: if not self._data or "files" not in self._data:
return None return None
entry = self._data["files"].get(key) # Take a local reference so a concurrent ``async_set`` rebuilding
# the dict cannot pull the rug out mid-read.
files = self._data["files"]
entry = files.get(key)
if not entry: if not entry:
return None return None
@@ -88,19 +119,23 @@ class TelegramFileCache:
if thumbhash is not None: if thumbhash is not None:
stored = entry.get("thumbhash") stored = entry.get("thumbhash")
if stored and stored != thumbhash: if stored and stored != thumbhash:
del self._data["files"][key] # Mark stale — actual deletion happens lock-protected
# in the next mutation. Returning None is sufficient
# for the caller to skip the cache hit.
return None return None
elif self._ttl_seconds > 0: elif self._ttl_seconds > 0:
cached_at_str = entry.get("cached_at") cached_at = _parse_iso(entry.get("cached_at"))
if cached_at_str: if cached_at is not None:
age = (datetime.now(timezone.utc) - datetime.fromisoformat(cached_at_str)).total_seconds() if cached_at.tzinfo is None:
cached_at = cached_at.replace(tzinfo=timezone.utc)
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
if age > self._ttl_seconds: if age > self._ttl_seconds:
return None return None
return { return {
"file_id": entry.get("file_id"), "file_id": entry.get("file_id"),
"type": entry.get("type"), "type": entry.get("type"),
"size": entry.get("size"), # bytes of what was uploaded; None for legacy entries "size": entry.get("size"),
} }
async def async_set( async def async_set(
@@ -111,21 +146,22 @@ class TelegramFileCache:
thumbhash: str | None = None, thumbhash: str | None = None,
size: int | None = None, size: int | None = None,
) -> None: ) -> None:
if self._data is None: async with self._lock:
self._data = {"files": {}} if self._data is None:
self._data = {"files": {}}
entry: dict[str, Any] = { entry: dict[str, Any] = {
"file_id": file_id, "file_id": file_id,
"type": media_type, "type": media_type,
"cached_at": datetime.now(timezone.utc).isoformat(), "cached_at": datetime.now(timezone.utc).isoformat(),
} }
if thumbhash is not None: if thumbhash is not None:
entry["thumbhash"] = thumbhash entry["thumbhash"] = thumbhash
if size is not None: if size is not None:
entry["size"] = size entry["size"] = size
self._data["files"][key] = entry self._data["files"][key] = entry
await self._backend.save(self._data) await self._backend.save(self._data)
async def async_set_many( async def async_set_many(
self, self,
@@ -139,32 +175,34 @@ class TelegramFileCache:
""" """
if not entries: if not entries:
return return
if self._data is None: async with self._lock:
self._data = {"files": {}} if self._data is None:
self._data = {"files": {}}
now_iso = datetime.now(timezone.utc).isoformat() now_iso = datetime.now(timezone.utc).isoformat()
for item in entries: for item in entries:
if len(item) == 5: if len(item) == 5:
key, file_id, media_type, thumbhash, size = item key, file_id, media_type, thumbhash, size = item
else: else:
key, file_id, media_type, thumbhash = item key, file_id, media_type, thumbhash = item
size = None size = None
entry: dict[str, Any] = { entry: dict[str, Any] = {
"file_id": file_id, "file_id": file_id,
"type": media_type, "type": media_type,
"cached_at": now_iso, "cached_at": now_iso,
} }
if thumbhash is not None: if thumbhash is not None:
entry["thumbhash"] = thumbhash entry["thumbhash"] = thumbhash
if size is not None: if size is not None:
entry["size"] = size entry["size"] = size
self._data["files"][key] = entry self._data["files"][key] = entry
await self._backend.save(self._data) await self._backend.save(self._data)
async def async_remove(self) -> None: async def async_remove(self) -> None:
await self._backend.remove() async with self._lock:
self._data = None await self._backend.remove()
self._data = None
def stats(self) -> dict[str, Any]: def stats(self) -> dict[str, Any]:
"""Return summary stats about the current cache contents. """Return summary stats about the current cache contents.
@@ -172,25 +210,33 @@ class TelegramFileCache:
Includes the number of cached entries, total tracked size in bytes Includes the number of cached entries, total tracked size in bytes
(only counts entries with a recorded ``size``), and the oldest / (only counts entries with a recorded ``size``), and the oldest /
newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty). newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty).
Timestamps are compared as parsed ``datetime`` objects so mixed
timezone formats (``Z`` vs ``+00:00``) order correctly.
""" """
files = self._data.get("files", {}) if self._data else {} files = self._data.get("files", {}) if self._data else {}
count = len(files) count = len(files)
total_size = 0 total_size = 0
oldest: str | None = None oldest_dt: datetime | None = None
newest: str | None = None newest_dt: datetime | None = None
oldest_str: str | None = None
newest_str: str | None = None
for entry in files.values(): for entry in files.values():
size = entry.get("size") size = entry.get("size")
if isinstance(size, int): if isinstance(size, int):
total_size += size total_size += size
cached_at = entry.get("cached_at") cached_at = entry.get("cached_at")
if cached_at: dt = _parse_iso(cached_at)
if oldest is None or cached_at < oldest: if dt is None or not cached_at:
oldest = cached_at continue
if newest is None or cached_at > newest: if dt.tzinfo is None:
newest = cached_at dt = dt.replace(tzinfo=timezone.utc)
if oldest_dt is None or dt < oldest_dt:
oldest_dt, oldest_str = dt, cached_at
if newest_dt is None or dt > newest_dt:
newest_dt, newest_str = dt, cached_at
return { return {
"count": count, "count": count,
"total_size_bytes": total_size, "total_size_bytes": total_size,
"oldest": oldest, "oldest": oldest_str,
"newest": newest, "newest": newest_str,
} }
File diff suppressed because it is too large Load Diff
@@ -2,20 +2,35 @@
from __future__ import annotations from __future__ import annotations
import logging
import re import re
from typing import Any, Final from typing import Any, Final
from urllib.parse import urlparse from urllib.parse import urlparse
_LOGGER = logging.getLogger(__name__)
# Telegram constants # Telegram constants
TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot" TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot"
TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB
TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB
TELEGRAM_MAX_DIMENSION_SUM: Final = 10000 TELEGRAM_MAX_DIMENSION_SUM: Final = 10000
# Telegram message-text limit (sendMessage) and caption limit
# (sendPhoto/sendVideo/sendDocument/first item of sendMediaGroup).
TELEGRAM_MAX_TEXT_LENGTH: Final = 4096
TELEGRAM_MAX_CAPTION_LENGTH: Final = 1024
# Generic UUID pattern for asset IDs # Strict canonical-UUID pattern (8-4-4-4-12) for asset IDs. The previous
_ASSET_ID_PATTERN = re.compile(r"^[a-f0-9-]{36}$") # loose ``[a-f0-9-]{36}`` matched 36 hyphens / arbitrary digit groupings,
# which could collide across providers when used as a cache key.
_ASSET_ID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
re.IGNORECASE,
)
# Cache key: "host:uuid" or bare "uuid" # Cache key: "host:uuid" or bare "uuid"
_ASSET_CACHE_KEY_PATTERN = re.compile(r"^(?:[^:]+:)?[a-f0-9-]{36}$") _ASSET_CACHE_KEY_PATTERN = re.compile(
r"^(?:[^:]+:)?[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
re.IGNORECASE,
)
# URL patterns to extract asset IDs (generic enough for Immich-style URLs) # URL patterns to extract asset IDs (generic enough for Immich-style URLs)
_ASSET_ID_URL_PATTERNS = [ _ASSET_ID_URL_PATTERNS = [
@@ -162,5 +177,10 @@ def check_photo_limits(
return False, None, width, height return False, None, width, height
except ImportError: except ImportError:
return False, None, None, None return False, None, None, None
except Exception: except (OSError, ValueError, MemoryError) as exc:
# PIL surfaces ``UnidentifiedImageError`` (subclass of OSError),
# truncated-image / decompression-bomb errors here. Log so a
# corrupt asset isn't silently passed to Telegram and rejected
# downstream with a less actionable error.
_LOGGER.warning("check_photo_limits: failed to inspect image (%d bytes): %s", len(data), exc)
return False, None, None, None return False, None, None, None
@@ -7,37 +7,29 @@ from typing import Any
import aiohttp import aiohttp
from ..ssrf import UnsafeURLError, avalidate_outbound_url from ..http_base import HttpProviderClient
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
class WebhookClient(HttpProviderClient):
"""Send JSON payloads to a webhook URL.
class WebhookClient: The URL is SSRF-validated on every send (defense-in-depth: re-validating
"""Send JSON payloads to a webhook URL.""" catches DNS rebinding between calls and a misconfigured target). Headers
pass through :func:`safe_headers` so a target config can't inject
framing/hop-by-hop headers like ``Host`` or ``Transfer-Encoding``.
"""
def __init__(self, session: aiohttp.ClientSession, url: str, headers: dict[str, str] | None = None) -> None: def __init__(
self._session = session self,
session: aiohttp.ClientSession,
url: str,
headers: dict[str, str] | None = None,
) -> None:
super().__init__(session, provider_name="webhook")
self._url = url self._url = url
self._headers = headers or {} self._extra_headers = headers or {}
async def send(self, payload: dict[str, Any]) -> dict[str, Any]: async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
try: return await self.request("POST", self._url, json=payload, headers=self._extra_headers)
await avalidate_outbound_url(self._url)
except UnsafeURLError as err:
return {"success": False, "error": f"Unsafe URL: {err}"}
try:
async with self._session.post(
self._url,
json=payload,
headers={"Content-Type": "application/json", **self._headers},
timeout=_DEFAULT_TIMEOUT,
allow_redirects=False,
) as response:
if 200 <= response.status < 300:
return {"success": True, "status_code": response.status}
body = await response.text()
return {"success": False, "error": f"HTTP {response.status}", "body": body[:200]}
except aiohttp.ClientError as err:
return {"success": False, "error": str(err)}
@@ -218,6 +218,19 @@ async def get_supported_locales(
return locales or ["en"] return locales or ["en"]
@router.get("/external-url")
async def get_external_url(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
"""Return the configured external base URL (available to all users).
Used by the UI to render absolute provider webhook URLs. Returns empty
string when unset so the UI falls back to the relative path.
"""
return {"external_url": (await get_setting(session, "external_url")).rstrip("/")}
async def _reregister_webhooks( async def _reregister_webhooks(
session: AsyncSession, base_url: str, secret: str session: AsyncSession, base_url: str, secret: str
) -> None: ) -> None:
@@ -0,0 +1,46 @@
"""Dispatcher result aggregation: per-receiver detail must survive."""
from __future__ import annotations
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher
def test_aggregate_all_success() -> None:
out = NotificationDispatcher._aggregate_results([
{"success": True, "message_id": 1},
{"success": True, "message_id": 2},
])
assert out["success"] is True
assert out["receivers"] == 2
assert out["successes"] == 2
assert out["failures"] == 0
def test_aggregate_partial() -> None:
out = NotificationDispatcher._aggregate_results([
{"success": True},
{"success": False, "error": "boom"},
])
assert out["success"] is True # at least one succeeded
assert out["successes"] == 1
assert out["failures"] == 1
assert "boom" in out["errors"]
assert "results" in out
def test_aggregate_all_fail_preserves_all_errors() -> None:
out = NotificationDispatcher._aggregate_results([
{"success": False, "error": "first"},
{"success": False, "error": "second"},
])
assert out["success"] is False
assert out["error"] == "first" # back-compat top-level field
assert out["errors"] == ["first", "second"]
# Per-receiver details survive — operator can see exactly what failed.
assert len(out["results"]) == 2
def test_aggregate_empty() -> None:
out = NotificationDispatcher._aggregate_results([])
assert out["success"] is False
assert "error" in out
@@ -0,0 +1,77 @@
"""Email client header-injection / address-validation regression tests."""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.email.client import (
EmailClient,
SmtpConfig,
_strip_header,
_validate_email,
_to_html,
)
def test_strip_header_removes_crlf() -> None:
out = _strip_header("Subject\r\nBcc: attacker@example.com")
assert "\r" not in out
assert "\n" not in out
# The injected "Bcc:" line is folded to a single header line; the SMTP
# server will treat it as part of the subject text, not a header.
assert "Bcc:" in out # value preserved as plain text
def test_strip_header_removes_bare_lf() -> None:
out = _strip_header("Hello\nWorld")
assert "\n" not in out
def test_strip_header_handles_non_string() -> None:
assert _strip_header(None) == ""
def test_validate_email_rejects_crlf() -> None:
with pytest.raises(ValueError):
_validate_email("user@example.com\r\nBcc: x@y")
def test_validate_email_rejects_no_at() -> None:
with pytest.raises(ValueError):
_validate_email("not-an-email")
def test_validate_email_rejects_empty() -> None:
with pytest.raises(ValueError):
_validate_email("")
def test_validate_email_accepts_normal() -> None:
assert _validate_email("user@example.com") == "user@example.com"
def test_to_html_escapes_brackets() -> None:
out = _to_html("<script>alert(1)</script>")
assert "<script>" not in out
assert "&lt;script&gt;" in out
@pytest.mark.asyncio
async def test_send_returns_error_on_invalid_to() -> None:
cfg = SmtpConfig(host="smtp.example.com", from_address="from@example.com")
client = EmailClient(cfg)
result = await client.send(
to_email="user@example.com\r\nBcc: attacker@example.com",
subject="hi",
body_text="body",
)
assert result["success"] is False
assert "Invalid email" in result["error"]
@pytest.mark.asyncio
async def test_send_returns_error_on_no_host() -> None:
cfg = SmtpConfig(host="", from_address="from@example.com")
client = EmailClient(cfg)
result = await client.send("u@x.com", "s", "b")
assert result["success"] is False
+53
View File
@@ -0,0 +1,53 @@
"""HttpProviderClient + safe_headers tests."""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.http_base import safe_headers
class TestSafeHeaders:
def test_drops_hop_by_hop(self) -> None:
out = safe_headers({
"X-Custom": "ok",
"Host": "evil.example.com",
"Content-Length": "999",
"Transfer-Encoding": "chunked",
"Connection": "close",
})
assert out == {"X-Custom": "ok"}
def test_rejects_crlf_in_value(self) -> None:
out = safe_headers({
"X-Custom": "ok",
"X-Bad": "value\r\nInjected: yes",
})
assert "X-Custom" in out
assert "X-Bad" not in out
def test_rejects_crlf_in_name(self) -> None:
out = safe_headers({
"X-Custom": "ok",
"X-Bad\r\nInject": "value",
})
assert out == {"X-Custom": "ok"}
def test_empty_input(self) -> None:
assert safe_headers(None) == {}
assert safe_headers({}) == {}
@pytest.mark.asyncio
async def test_http_base_returns_safe_error_on_invalid_url() -> None:
"""An obviously-bad URL must not panic or leak the URL verbatim."""
import aiohttp
from notify_bridge_core.notifications.http_base import HttpProviderClient
async with aiohttp.ClientSession() as sess:
client = HttpProviderClient(sess, provider_name="test")
# file:// is rejected by the SSRF guard before any HTTP call.
result = await client.request("POST", "file:///etc/passwd", json={})
assert result["success"] is False
assert "Unsafe URL" in result["error"]
@@ -0,0 +1,84 @@
"""Matrix client validation: room_id format and quoting."""
from __future__ import annotations
import aiohttp
import pytest
from aioresponses import aioresponses
from notify_bridge_core.notifications.matrix.client import MatrixClient
HOMESERVER = "https://matrix.example.com"
TOKEN = "secret-bearer-token-1234567890"
@pytest.mark.asyncio
async def test_rejects_path_injection_room_id() -> None:
async with aiohttp.ClientSession() as sess:
client = MatrixClient(sess, HOMESERVER, TOKEN)
result = await client.send_message("!abc:host/../../etc/passwd", "hi")
assert result["success"] is False
assert "room_id" in result["error"]
@pytest.mark.asyncio
async def test_rejects_empty_room_id() -> None:
async with aiohttp.ClientSession() as sess:
client = MatrixClient(sess, HOMESERVER, TOKEN)
result = await client.send_message("", "hi")
assert result["success"] is False
assert "room_id" in result["error"]
@pytest.mark.asyncio
async def test_rejects_unicode_control_chars_in_room_id() -> None:
async with aiohttp.ClientSession() as sess:
client = MatrixClient(sess, HOMESERVER, TOKEN)
result = await client.send_message("!abc\x00:host", "hi")
assert result["success"] is False
@pytest.mark.asyncio
async def test_url_encodes_room_id_special_chars() -> None:
"""``!`` and ``:`` must reach the server URL-encoded."""
captured: list[str] = []
with aioresponses() as mocked:
# Match any PUT under the rooms path; capture the URL we got.
mocked.put(
"https://matrix.example.com/_matrix/client/v3/rooms/%21abc%3Ahost.example/send/m.room.message",
status=200, body='{}', repeat=True,
)
# aioresponses doesn't expose URL templates well, so use a regex mock.
import re
mocked.put(
re.compile(r"https://matrix\.example\.com/_matrix/client/v3/rooms/[^/]+/send/m\.room\.message/.*"),
status=200, body='{}', repeat=True,
)
async with aiohttp.ClientSession() as sess:
client = MatrixClient(sess, HOMESERVER, TOKEN)
result = await client.send_message("!abc:host.example", "hi")
assert result["success"] is True
@pytest.mark.asyncio
async def test_redacts_bearer_in_error() -> None:
"""A 4xx response body must not echo the Authorization Bearer back to caller."""
import re
with aioresponses() as mocked:
mocked.put(
re.compile(r".*"),
status=403,
body='{"errcode": "M_FORBIDDEN", "Authorization": "Bearer ' + TOKEN + '"}',
repeat=True,
)
async with aiohttp.ClientSession() as sess:
client = MatrixClient(sess, HOMESERVER, TOKEN)
result = await client.send_message("!abc:host.example", "hi")
assert result["success"] is False
assert TOKEN not in result["error"]
+84
View File
@@ -0,0 +1,84 @@
"""NotificationQueue bound + concurrency regression tests."""
from __future__ import annotations
import asyncio
from typing import Any
import pytest
from notify_bridge_core.notifications.queue import (
DEFAULT_MAX_QUEUE_SIZE,
NotificationQueue,
)
class _MemBackend:
"""In-memory storage backend stub for tests."""
def __init__(self) -> None:
self._data: dict[str, Any] | None = None
async def load(self) -> dict[str, Any] | None:
return self._data
async def save(self, data: dict[str, Any]) -> None:
self._data = data
async def remove(self) -> None:
self._data = None
@pytest.mark.asyncio
async def test_load_with_garbage_falls_back_to_empty() -> None:
backend = _MemBackend()
backend._data = {"queue": "not-a-list"} # type: ignore[assignment]
q = NotificationQueue(backend)
await q.async_load()
assert q.get_all() == []
@pytest.mark.asyncio
async def test_enqueue_caps_at_max_size() -> None:
backend = _MemBackend()
q = NotificationQueue(backend, max_size=3)
await q.async_load()
for i in range(10):
await q.async_enqueue({"i": i})
items = q.get_all()
assert len(items) == 3
# FIFO drop: most recent three are kept (i=7..9).
assert [it["params"]["i"] for it in items] == [7, 8, 9]
@pytest.mark.asyncio
async def test_get_all_returns_deep_copy() -> None:
backend = _MemBackend()
q = NotificationQueue(backend, max_size=10)
await q.async_load()
await q.async_enqueue({"key": "value"})
snap = q.get_all()
snap[0]["params"]["key"] = "MUTATED"
snap2 = q.get_all()
assert snap2[0]["params"]["key"] == "value"
@pytest.mark.asyncio
async def test_concurrent_enqueue_and_clear_no_corruption() -> None:
backend = _MemBackend()
q = NotificationQueue(backend, max_size=DEFAULT_MAX_QUEUE_SIZE)
await q.async_load()
async def producer() -> None:
for i in range(50):
await q.async_enqueue({"i": i})
async def clearer() -> None:
for _ in range(10):
await asyncio.sleep(0)
await q.async_clear()
await asyncio.gather(producer(), clearer())
# No exceptions = no race-induced "dictionary changed size during iteration".
items = q.get_all()
assert isinstance(items, list)
+74
View File
@@ -0,0 +1,74 @@
"""Secret-redaction helper regression tests.
Locks in the patterns that surface from real provider error paths:
Telegram bot URLs in aiohttp.ClientError messages, Authorization Bearer
tokens in Matrix/ntfy responses, Discord/Slack webhook tokens, URL
userinfo, and common ?token= query params.
"""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.redact import redact, redact_exc
@pytest.mark.parametrize(
"raw,expected_substr,not_in",
[
(
"Cannot connect to host api.telegram.org/bot1234567:AABBCC-secret-token/sendMessage",
"api.telegram.org/bot***",
"AABBCC-secret-token",
),
(
"Authorization: Bearer ey.JhbGciOiJIUzI1NiJ9.payload.sig",
"Bearer ***",
"ey.JhbGciOiJIUzI1NiJ9",
),
(
"POST https://discord.com/api/webhooks/12345/abcdefg-token failed",
"discord.com/api/webhooks/12345/***",
"abcdefg-token",
),
(
"POST https://hooks.slack.com/services/T01/B02/zzzzz failed",
"hooks.slack.com/services/T01/B02/***",
"zzzzz",
),
(
"fetch http://user:supersecret@example.com/foo",
"http://***@example.com/foo",
"supersecret",
),
(
"https://api.example.com/x?token=mytoken123&extra=ok",
"token=***",
"mytoken123",
),
],
)
def test_redact_known_secrets(raw: str, expected_substr: str, not_in: str) -> None:
out = redact(raw)
assert expected_substr in out
assert not_in not in out
def test_redact_idempotent() -> None:
once = redact("Bearer abcdefghij1234567890")
twice = redact(once)
assert once == twice
def test_redact_exc_returns_str() -> None:
err = RuntimeError("Bearer abcdefghij1234567890")
out = redact_exc(err)
assert isinstance(out, str)
assert "Bearer ***" in out
assert "abcdefghij1234567890" not in out
def test_redact_handles_non_string() -> None:
# Coercion path should not raise.
out = redact(12345) # type: ignore[arg-type]
assert out == "12345"
@@ -0,0 +1,73 @@
"""SSRF hardening regression tests.
Covers cases the original guard missed: IPv4-mapped IPv6, CGNAT,
trailing-dot hostnames, IPv6 zone identifiers, and the safe-host repr
used in error messages.
"""
from __future__ import annotations
import pytest
from notify_bridge_core.notifications.ssrf import (
UnsafeURLError,
PinnedResolver,
avalidate_outbound_url_full,
validate_outbound_url,
)
class TestBlockedRanges:
@pytest.mark.parametrize(
"url",
[
"http://[::ffff:127.0.0.1]/", # IPv4-mapped IPv6 → loopback
"http://[::ffff:10.0.0.1]/", # IPv4-mapped IPv6 → RFC1918
"http://100.64.0.1/", # CGNAT (RFC 6598)
"http://0.0.0.0/", # unspecified
],
)
def test_rejects_extra_ranges(self, url: str) -> None:
with pytest.raises(UnsafeURLError):
validate_outbound_url(url)
class TestHostnameNormalization:
def test_strips_trailing_dot(self) -> None:
# ``localhost.`` should normalize to ``localhost`` and still resolve
# to the loopback address — and be blocked.
with pytest.raises(UnsafeURLError):
validate_outbound_url("http://localhost./")
def test_rejects_bad_scheme_uppercase(self) -> None:
with pytest.raises(UnsafeURLError):
validate_outbound_url("FILE:///etc/passwd")
class TestErrorMessages:
def test_error_does_not_leak_long_hosts(self) -> None:
with pytest.raises(UnsafeURLError) as ei:
validate_outbound_url("http://" + "a" * 1024 + ".invalid/")
# Truncated to 64 chars in the error string.
assert len(str(ei.value)) < 256
class TestPinnedResolverSync:
def test_pin_returns_pinned_ip(self) -> None:
resolver = PinnedResolver({"example.com": "93.184.216.34"})
# Just exercise the dict path — full resolve runs in async tests.
assert resolver._map["example.com"] == "93.184.216.34" # type: ignore[attr-defined]
class TestAsyncFullValidator:
@pytest.mark.asyncio
async def test_returns_resolved_ip(self) -> None:
# Literal IP — no DNS lookup; we still get back a ValidatedURL.
result = await avalidate_outbound_url_full("http://8.8.8.8/")
assert result.ip == "8.8.8.8"
assert result.host == "8.8.8.8"
@pytest.mark.asyncio
async def test_rejects_blocked_literal(self) -> None:
with pytest.raises(UnsafeURLError):
await avalidate_outbound_url_full("http://[::ffff:127.0.0.1]/")
@@ -0,0 +1,56 @@
"""Telegram media-group mixed-type partitioning regression test.
Telegram rejects sendMediaGroup payloads that mix ``document`` with
``photo``/``video``. The client must partition before chunking so a
mixed input list still delivers all assets.
"""
from __future__ import annotations
from notify_bridge_core.notifications.telegram.client import TelegramClient
def test_partition_keeps_photo_video_together() -> None:
parts = TelegramClient._partition_media_by_kind([
{"type": "photo", "url": "p1"},
{"type": "video", "url": "v1"},
{"type": "photo", "url": "p2"},
])
assert len(parts) == 1
assert [a["url"] for a in parts[0]] == ["p1", "v1", "p2"]
def test_partition_separates_documents_from_media() -> None:
parts = TelegramClient._partition_media_by_kind([
{"type": "photo", "url": "p1"},
{"type": "document", "url": "d1"},
{"type": "video", "url": "v1"},
])
assert len(parts) == 3
assert parts[0][0]["url"] == "p1"
assert parts[1][0]["url"] == "d1"
assert parts[2][0]["url"] == "v1"
def test_partition_groups_consecutive_documents() -> None:
parts = TelegramClient._partition_media_by_kind([
{"type": "document", "url": "d1"},
{"type": "document", "url": "d2"},
{"type": "photo", "url": "p1"},
])
assert len(parts) == 2
assert [a["url"] for a in parts[0]] == ["d1", "d2"]
assert parts[1][0]["url"] == "p1"
def test_partition_empty() -> None:
assert TelegramClient._partition_media_by_kind([]) == []
def test_partition_defaults_missing_type_to_photo() -> None:
"""Items without an explicit type are treated as photos for grouping."""
parts = TelegramClient._partition_media_by_kind([
{"url": "x"}, # no type
{"type": "video", "url": "v"},
])
assert len(parts) == 1