feat: harden notification stack and switch logging selectors to icon grid
Notifications: - Add shared http_base, redact, and SSRF hardening modules - Refactor dispatcher, queue, receiver and per-provider clients (telegram, discord, email, matrix, ntfy, slack, webhook) to use the shared base, with bounded queue and redacted error logs - Tests for ssrf, redact, http_base, queue bounds, dispatcher aggregation, telegram media partition, email and matrix clients Frontend: - Settings: log level / log format selectors now use IconGridSelect with per-option icons and i18n descriptions - Minor providers page and entity-cache store updates Tooling: - Document code-review-graph MCP usage in CLAUDE.md - Ignore .code-review-graph/, register .mcp.json
This commit is contained in:
@@ -56,3 +56,5 @@ frontend/.svelte-kit/
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
# Added by code-review-graph
|
||||
.code-review-graph/
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"code-review-graph": {
|
||||
"command": "uvx",
|
||||
"args": [
|
||||
"code-review-graph",
|
||||
"serve"
|
||||
],
|
||||
"type": "stdio"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -43,3 +43,42 @@ Detailed context is split into focused documents under `.claude/docs/`. Read the
|
||||
- Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`)
|
||||
- Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`)
|
||||
- Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py`
|
||||
|
||||
<!-- code-review-graph MCP tools -->
|
||||
## MCP Tools: code-review-graph
|
||||
|
||||
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||
you structural context (callers, dependents, test coverage) that file
|
||||
scanning cannot.
|
||||
|
||||
### When to use graph tools FIRST
|
||||
|
||||
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||
|
||||
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||
|
||||
### Key Tools
|
||||
|
||||
| Tool | Use when |
|
||||
|------|----------|
|
||||
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||
| `get_impact_radius` | Understanding blast radius of a change |
|
||||
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||
| `refactor_tool` | Planning renames, finding dead code |
|
||||
|
||||
### Workflow
|
||||
|
||||
1. The graph auto-updates on file changes (via hooks).
|
||||
2. Use `detect_changes` for code review.
|
||||
3. Use `get_affected_flows` to understand impact.
|
||||
4. Use `query_graph` with `pattern="tests_for"` to check coverage.
|
||||
|
||||
@@ -73,6 +73,22 @@ export const localeItems = (): GridItem[] => [
|
||||
{ value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') },
|
||||
];
|
||||
|
||||
// --- Log level ---
|
||||
|
||||
/**
 * Grid options for the log-level selector.
 *
 * Each level is shown under its own name (value and label are identical);
 * the description is resolved through `t` on every call.
 */
export const logLevelItems = (): GridItem[] => {
  // [level, icon, translated description] — kept in display order.
  const rows: [string, string, string][] = [
    ['DEBUG', 'mdiBugOutline', t('gridDesc.logLevelDebug')],
    ['INFO', 'mdiInformationOutline', t('gridDesc.logLevelInfo')],
    ['WARNING', 'mdiAlertOutline', t('gridDesc.logLevelWarning')],
    ['ERROR', 'mdiAlertOctagonOutline', t('gridDesc.logLevelError')],
  ];
  return rows.map(([level, icon, desc]) => ({ value: level, icon, label: level, desc }));
};
|
||||
|
||||
// --- Log format ---
|
||||
|
||||
/**
 * Grid options for the log-format selector.
 *
 * Each format is shown under its own name (value and label are identical);
 * the description is resolved through `t` on every call.
 */
export const logFormatItems = (): GridItem[] => {
  // [format, icon, translated description] — kept in display order.
  const rows: [string, string, string][] = [
    ['text', 'mdiFormatText', t('gridDesc.logFormatText')],
    ['json', 'mdiCodeJson', t('gridDesc.logFormatJson')],
  ];
  return rows.map(([format, icon, desc]) => ({ value: format, icon, label: format, desc }));
};
|
||||
|
||||
// --- Response mode ---
|
||||
|
||||
export const responseModeItems = (tFn: typeof t): GridItem[] => [
|
||||
|
||||
@@ -192,7 +192,8 @@
|
||||
"apiToken": "API Token",
|
||||
"apiTokenHint": "Optional. Needed for connection testing and repository listing.",
|
||||
"webhookUrl": "Webhook URL",
|
||||
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).",
|
||||
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings. The full URL is shown when an external base URL is configured in Settings; otherwise it is relative to your bridge host.",
|
||||
"webhookUrlCopyTitle": "Click to copy",
|
||||
"nutHost": "NUT Server Host",
|
||||
"nutHostPlaceholder": "192.168.1.100 or ups.local",
|
||||
"nutPort": "NUT Server Port",
|
||||
@@ -1131,6 +1132,12 @@
|
||||
"memorySourceNative": "Use Immich native memories API",
|
||||
"localeEn": "English interface",
|
||||
"localeRu": "Russian interface",
|
||||
"logLevelDebug": "Verbose — show every step",
|
||||
"logLevelInfo": "Default — high-level events",
|
||||
"logLevelWarning": "Warnings and errors only",
|
||||
"logLevelError": "Errors only — quietest",
|
||||
"logFormatText": "Human-readable plain text",
|
||||
"logFormatJson": "One JSON object per line",
|
||||
"modeMedia": "Send actual photo/video files",
|
||||
"modeText": "Send file names and links only",
|
||||
"allEvents": "Show all event types",
|
||||
|
||||
@@ -192,7 +192,8 @@
|
||||
"apiToken": "API токен",
|
||||
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
|
||||
"webhookUrl": "URL вебхука",
|
||||
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).",
|
||||
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea. Полный URL показывается, если в настройках задан внешний адрес; иначе путь указан относительно хоста bridge.",
|
||||
"webhookUrlCopyTitle": "Нажмите, чтобы скопировать",
|
||||
"nutHost": "Хост NUT-сервера",
|
||||
"nutHostPlaceholder": "192.168.1.100 или ups.local",
|
||||
"nutPort": "Порт NUT-сервера",
|
||||
@@ -1131,6 +1132,12 @@
|
||||
"memorySourceNative": "Использовать API воспоминаний Immich",
|
||||
"localeEn": "Английский интерфейс",
|
||||
"localeRu": "Русский интерфейс",
|
||||
"logLevelDebug": "Подробный — каждый шаг",
|
||||
"logLevelInfo": "По умолчанию — ключевые события",
|
||||
"logLevelWarning": "Только предупреждения и ошибки",
|
||||
"logLevelError": "Только ошибки — самый тихий",
|
||||
"logFormatText": "Читаемый человеком текст",
|
||||
"logFormatJson": "Один JSON-объект на строку",
|
||||
"modeMedia": "Отправка файлов фото/видео",
|
||||
"modeText": "Только имена файлов и ссылки",
|
||||
"allEvents": "Показать все типы событий",
|
||||
|
||||
@@ -112,6 +112,34 @@ export const capabilitiesCache = (() => {
|
||||
};
|
||||
})();
|
||||
|
||||
/**
 * Configured external base URL — used to render absolute webhook URLs.
 * Available to all authenticated users. Empty string when unset.
 *
 * - `value`: last fetched URL with trailing slashes stripped ('' until the
 *   first successful fetch or when nothing is configured).
 * - `invalidate()`: marks the cached value as stale so the next `fetch()`
 *   hits the API again. The current `value` is kept until then.
 * - `fetch(force)`: returns the cached value while it is younger than the TTL
 *   (unless `force` is true); concurrent callers share one in-flight request.
 *   Rejects with whatever `api` rejects with; the cache is left untouched then.
 */
export const externalUrlCache = (() => {
  let data = $state<string>('');
  let fetchedAt = $state(0);
  let inflight: Promise<string> | null = null;
  // Bumped by invalidate(). A request that started under an older generation
  // must not mark the cache as fresh, otherwise an invalidate() issued while
  // the request was in flight would be silently undone for a whole TTL.
  let generation = 0;
  const TTL = 300_000;

  return {
    get value() { return data; },
    invalidate() {
      generation += 1;
      fetchedAt = 0;
    },
    async fetch(force = false): Promise<string> {
      if (!force && fetchedAt > 0 && Date.now() - fetchedAt < TTL) return data;
      if (inflight) return inflight;

      const startedGeneration = generation;
      inflight = (async () => {
        try {
          const res = await api<{ external_url: string }>('/settings/external-url');
          // Strip trailing slashes so callers can append an absolute path directly.
          data = (res?.external_url || '').replace(/\/+$/, '');
          // Still publish the value, but only stamp it as fresh when no
          // invalidate() happened while we were waiting for the response.
          fetchedAt = generation === startedGeneration ? Date.now() : 0;
          return data;
        } finally {
          inflight = null;
        }
      })();
      return inflight;
    },
  };
})();
|
||||
|
||||
/** Supported template locales — fetched from app settings. */
|
||||
export const supportedLocalesCache = (() => {
|
||||
let data = $state<string[]>(['en', 'ru']);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import { slide } from 'svelte/transition';
|
||||
import { api, getBlockedBy, type BlockedByDetail } from '$lib/api';
|
||||
import { t } from '$lib/i18n';
|
||||
import { providersCache } from '$lib/stores/caches.svelte';
|
||||
import { providersCache, externalUrlCache } from '$lib/stores/caches.svelte';
|
||||
import PageHeader from '$lib/components/PageHeader.svelte';
|
||||
import Card from '$lib/components/Card.svelte';
|
||||
import Loading from '$lib/components/Loading.svelte';
|
||||
@@ -21,7 +21,7 @@
|
||||
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
||||
import { topbarAction } from '$lib/stores/topbar-action.svelte';
|
||||
import { onDestroy } from 'svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
|
||||
import { highlightFromUrl } from '$lib/highlight';
|
||||
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
@@ -45,6 +45,30 @@
|
||||
let confirmDelete = $state<ServiceProvider | null>(null);
|
||||
|
||||
let descriptor = $derived(getDescriptor(form.type));
|
||||
let externalUrl = $derived(externalUrlCache.value);
|
||||
|
||||
/**
 * Build the webhook URL displayed for a provider.
 *
 * Substitutes the first `{token}` placeholder in `pattern` with `token`
 * (a missing token becomes an empty string) and, when an external base URL
 * is currently known, prefixes the result with it. Otherwise the bare path
 * is returned.
 *
 * @param pattern Descriptor URL pattern containing a `{token}` placeholder.
 * @param token   Provider webhook token to insert.
 * @returns The absolute URL when `externalUrl` is non-empty, else the path.
 */
function buildWebhookUrl(pattern: string, token: string): string {
  // Use a replacer function so the token is inserted verbatim: with a plain
  // string replacement, sequences such as `$&` or `$1` inside the token would
  // be interpreted as replacement patterns and corrupt the URL.
  const path = pattern.replace('{token}', () => token ?? '');
  return externalUrl ? `${externalUrl}${path}` : path;
}
|
||||
|
||||
/**
 * Copy a webhook URL to the clipboard and report the outcome in a snackbar.
 *
 * Uses the async Clipboard API when the browser exposes it; otherwise (or
 * when the async write is rejected) falls back to a hidden <textarea> plus
 * `document.execCommand('copy')`. The "copied" snackbar is shown only after
 * a copy actually succeeded; a failure is reported through `snackError`.
 *
 * @param e   Triggering DOM event; its default action and propagation are suppressed.
 * @param url Text to place on the clipboard.
 */
function copyWebhookUrl(e: Event, url: string) {
  e.preventDefault();
  e.stopPropagation();

  const confirmCopied = () => snackInfo(`${t('snack.copied')}: ${url}`);

  // Legacy path: select the text inside an off-screen textarea and ask the
  // browser to copy it. Throws when the browser refuses the command.
  const legacyCopy = () => {
    const ta = document.createElement('textarea');
    ta.value = url;
    ta.style.position = 'fixed';
    ta.style.opacity = '0';
    document.body.appendChild(ta);
    try {
      ta.select();
      if (!document.execCommand('copy')) {
        throw new Error('Copy command was rejected by the browser');
      }
    } finally {
      // Always detach the helper element, even when select/copy throws.
      document.body.removeChild(ta);
    }
  };

  const copyViaLegacy = () => {
    try {
      legacyCopy();
      confirmCopied();
    } catch (err) {
      snackError(err instanceof Error ? err.message : String(err));
    }
  };

  if (navigator.clipboard?.writeText) {
    // writeText() rejects e.g. on denied permission or an insecure context;
    // handle that instead of leaving an unhandled rejection and a false
    // "copied" confirmation.
    void navigator.clipboard.writeText(url).then(confirmCopied, copyViaLegacy);
  } else {
    copyViaLegacy();
  }
}
|
||||
|
||||
// Auto-update name when provider type changes (unless user manually edited)
|
||||
$effect(() => {
|
||||
@@ -76,6 +100,7 @@
|
||||
onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); },
|
||||
});
|
||||
load();
|
||||
externalUrlCache.fetch().catch(() => { /* fall back to relative URLs */ });
|
||||
});
|
||||
onDestroy(() => topbarAction.clear());
|
||||
async function load() {
|
||||
@@ -246,9 +271,15 @@
|
||||
</div>
|
||||
{/each}
|
||||
{#if descriptor?.webhookUrlPattern && editing}
|
||||
{@const editingWebhookUrl = buildWebhookUrl(descriptor.webhookUrlPattern, providers.find(p => p.id === editing)?.webhook_token ?? '')}
|
||||
<div class="bg-[var(--color-muted)] rounded-md p-3">
|
||||
<div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div>
|
||||
<code class="text-xs select-all break-all">{descriptor.webhookUrlPattern.replace('{token}', providers.find(p => p.id === editing)?.webhook_token ?? '')}</code>
|
||||
<button type="button"
|
||||
onclick={(e) => copyWebhookUrl(e, editingWebhookUrl)}
|
||||
title={t('providers.webhookUrlCopyTitle')}
|
||||
class="text-xs break-all text-left hover:text-[var(--color-primary)] cursor-pointer font-mono w-full">
|
||||
<code class="bg-transparent">{editingWebhookUrl}</code>
|
||||
</button>
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
|
||||
</div>
|
||||
{/if}
|
||||
@@ -295,7 +326,14 @@
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p>
|
||||
{/if}
|
||||
{#if provDesc?.webhookUrlPattern}
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">{t('providers.webhookUrl')}: <span class="select-all">{provDesc.webhookUrlPattern.replace('{token}', provider.webhook_token)}</span></p>
|
||||
{@const webhookUrl = buildWebhookUrl(provDesc.webhookUrlPattern, provider.webhook_token)}
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">
|
||||
{t('providers.webhookUrl')}:
|
||||
<button type="button"
|
||||
onclick={(e) => copyWebhookUrl(e, webhookUrl)}
|
||||
title={t('providers.webhookUrlCopyTitle')}
|
||||
class="hover:text-[var(--color-primary)] cursor-pointer break-all text-left">{webhookUrl}</button>
|
||||
</p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,10 @@
|
||||
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
||||
import LocaleSelector from '$lib/components/LocaleSelector.svelte';
|
||||
import TimezoneSelector from '$lib/components/TimezoneSelector.svelte';
|
||||
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
|
||||
import { logLevelItems, logFormatItems } from '$lib/grid-items';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import { externalUrlCache } from '$lib/stores/caches.svelte';
|
||||
|
||||
interface CacheBucketStats {
|
||||
count: number;
|
||||
@@ -76,6 +79,7 @@
|
||||
saving = true; error = '';
|
||||
try {
|
||||
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
|
||||
externalUrlCache.invalidate();
|
||||
snackSuccess(t('settings.saved'));
|
||||
} catch (err: any) { error = err.message; snackError(err.message); }
|
||||
saving = false;
|
||||
@@ -221,21 +225,11 @@
|
||||
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
||||
<select bind:value={settings.log_level}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="DEBUG">DEBUG</option>
|
||||
<option value="INFO">INFO</option>
|
||||
<option value="WARNING">WARNING</option>
|
||||
<option value="ERROR">ERROR</option>
|
||||
</select>
|
||||
<IconGridSelect items={logLevelItems()} bind:value={settings.log_level} columns={2} />
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
||||
<select bind:value={settings.log_format}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="text">text</option>
|
||||
<option value="json">json</option>
|
||||
</select>
|
||||
<IconGridSelect items={logFormatItems()} bind:value={settings.log_format} columns={2} />
|
||||
</div>
|
||||
<div class="sm:col-span-2">
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
||||
|
||||
@@ -4,21 +4,24 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..http_base import HttpProviderClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Discord webhook content limit
|
||||
MAX_CONTENT_LENGTH = 2000
|
||||
# Discord API constraints (per webhook docs).
|
||||
MAX_CONTENT_LENGTH: Final = 2000
|
||||
MAX_USERNAME_LENGTH: Final = 80
|
||||
|
||||
|
||||
class DiscordClient:
|
||||
class DiscordClient(HttpProviderClient):
|
||||
"""Sends messages via Discord webhook URLs."""
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||
self._session = session
|
||||
super().__init__(session, provider_name="discord")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@@ -33,6 +36,8 @@ class DiscordClient:
|
||||
"""
|
||||
if not webhook_url:
|
||||
return {"success": False, "error": "Missing webhook_url"}
|
||||
if username and len(username) > MAX_USERNAME_LENGTH:
|
||||
return {"success": False, "error": f"username exceeds {MAX_USERNAME_LENGTH} chars"}
|
||||
|
||||
chunks = _split_message(message, MAX_CONTENT_LENGTH)
|
||||
for chunk in chunks:
|
||||
@@ -42,71 +47,34 @@ class DiscordClient:
|
||||
if avatar_url:
|
||||
payload["avatar_url"] = avatar_url
|
||||
|
||||
result = await self._post(webhook_url, payload)
|
||||
if not result["success"]:
|
||||
result = await self.request("POST", webhook_url, json=payload)
|
||||
if not result.get("success"):
|
||||
return result
|
||||
|
||||
# Small delay between chunks to respect rate limits
|
||||
if len(chunks) > 1:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return {"success": True}
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
_MAX_RETRY_AFTER = 60.0
|
||||
|
||||
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
|
||||
"""POST with bounded 429 retry.
|
||||
|
||||
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
|
||||
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
|
||||
pin the dispatch task indefinitely.
|
||||
"""
|
||||
for attempt in range(self._MAX_RETRIES + 1):
|
||||
try:
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
if resp.status == 429 and attempt < self._MAX_RETRIES:
|
||||
try:
|
||||
retry_after = float(resp.headers.get("Retry-After", "2"))
|
||||
except (TypeError, ValueError):
|
||||
retry_after = 2.0
|
||||
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
|
||||
_LOGGER.warning(
|
||||
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
|
||||
retry_after, attempt + 1, self._MAX_RETRIES,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"HTTP {resp.status}: {body[:200]}",
|
||||
}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||
|
||||
|
||||
def _split_message(text: str, limit: int) -> list[str]:
|
||||
"""Split message into chunks respecting the character limit."""
|
||||
"""Split message into chunks respecting the character limit.
|
||||
|
||||
Drops chunks that contain only whitespace — Discord rejects those.
|
||||
"""
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
chunks = []
|
||||
chunks: list[str] = []
|
||||
while text:
|
||||
if len(text) <= limit:
|
||||
chunks.append(text)
|
||||
break
|
||||
# Try to split at newline
|
||||
split_at = text.rfind("\n", 0, limit)
|
||||
if split_at <= 0:
|
||||
split_at = limit
|
||||
chunks.append(text[:split_at])
|
||||
text = text[split_at:].lstrip("\n")
|
||||
return chunks
|
||||
piece = text
|
||||
text = ""
|
||||
else:
|
||||
split_at = text.rfind("\n", 0, limit)
|
||||
if split_at <= 0:
|
||||
split_at = limit
|
||||
piece = text[:split_at]
|
||||
text = text[split_at:].lstrip("\n")
|
||||
if piece.strip():
|
||||
chunks.append(piece)
|
||||
return chunks or [text]
|
||||
|
||||
@@ -7,7 +7,7 @@ import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Final
|
||||
|
||||
import aiohttp
|
||||
|
||||
@@ -15,37 +15,20 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.templates.context import build_template_context
|
||||
from notify_bridge_core.templates.renderer import render_template
|
||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
# Cap on how many asset downloads run concurrently inside
|
||||
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
|
||||
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
|
||||
# max_asset_size``, which matters on small-RAM Docker hosts when a batch
|
||||
# contains many large videos.
|
||||
_PRELOAD_CONCURRENCY = 6
|
||||
|
||||
|
||||
def _new_session() -> aiohttp.ClientSession:
|
||||
"""Per-dispatch aiohttp session with a sane default timeout.
|
||||
|
||||
We still open a short-lived session per dispatch (connection reuse across
|
||||
dispatches lives in the server-side shared session), but we always attach
|
||||
a total timeout so a hung peer cannot wedge the task forever.
|
||||
"""
|
||||
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
|
||||
|
||||
from .http_base import safe_headers
|
||||
from .receiver import (
|
||||
DiscordReceiver,
|
||||
EmailReceiver,
|
||||
MatrixReceiver,
|
||||
NtfyReceiver,
|
||||
Receiver,
|
||||
SlackReceiver,
|
||||
TelegramReceiver,
|
||||
WebhookReceiver,
|
||||
EmailReceiver,
|
||||
DiscordReceiver,
|
||||
SlackReceiver,
|
||||
NtfyReceiver,
|
||||
MatrixReceiver,
|
||||
)
|
||||
from .redact import redact_exc
|
||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
from .telegram.cache import TelegramFileCache
|
||||
from .telegram.client import TelegramClient
|
||||
from .telegram.media import (
|
||||
@@ -58,7 +41,33 @@ from .webhook.client import WebhookClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TEMPLATE = '{{ event_type }}: "{{ collection_name }}"'
|
||||
DEFAULT_TEMPLATE: Final = '{{ event_type }}: "{{ collection_name }}"'
|
||||
|
||||
_HTTP_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
|
||||
# Cap on how many asset downloads run concurrently inside
|
||||
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
|
||||
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
|
||||
# max_asset_size``.
|
||||
_PRELOAD_CONCURRENCY: Final = 6
|
||||
|
||||
# Cap on how many targets the dispatcher fans out to at once. With dozens
|
||||
# of targets and a single hung peer, unbounded ``gather`` can pin the
|
||||
# dispatch task. The cap also protects against credential-reuse rate
|
||||
# limits on shared providers.
|
||||
_DISPATCH_CONCURRENCY: Final = 16
|
||||
|
||||
# Cap on parallel per-receiver sends within a single target.
|
||||
_RECEIVER_CONCURRENCY: Final = 8
|
||||
|
||||
# Per-target soft timeout — at the top of the dispatch tree so a single
|
||||
# misbehaving target can't hold the whole batch open. Individual provider
|
||||
# clients carry their own per-request timeouts on top of this.
|
||||
_TARGET_TIMEOUT_S: Final = 120.0
|
||||
|
||||
|
||||
def _new_session() -> aiohttp.ClientSession:
    """Create a fresh aiohttp client session configured with ``_HTTP_TIMEOUT``."""
    session = aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
    return session
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -66,17 +75,23 @@ class TargetConfig:
|
||||
"""Configuration for a notification target."""
|
||||
|
||||
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
|
||||
config: dict[str, Any] # target-level config (bot_token, settings, etc.)
|
||||
template_slots: dict[str, dict[str, str]] | None = None # event_type -> {locale -> template}
|
||||
locale: str = "en" # default locale for template resolution
|
||||
config: dict[str, Any]
|
||||
template_slots: dict[str, dict[str, str]] | None = None
|
||||
locale: str = "en"
|
||||
date_format: str = "%d.%m.%Y, %H:%M UTC"
|
||||
date_only_format: str = "%d.%m.%Y"
|
||||
provider_api_key: str | None = None # API key for downloading assets from provider
|
||||
provider_internal_url: str | None = None # Internal provider URL for API key scoping
|
||||
provider_external_url: str | None = None # External domain for API key scoping
|
||||
provider_api_key: str | None = None
|
||||
provider_internal_url: str | None = None
|
||||
provider_external_url: str | None = None
|
||||
receivers: list[Receiver] = field(default_factory=list)
|
||||
|
||||
|
||||
_SendMethod = Callable[
|
||||
["NotificationDispatcher", TargetConfig, str, ServiceEvent],
|
||||
Awaitable[dict[str, Any]],
|
||||
]
|
||||
|
||||
|
||||
class NotificationDispatcher:
|
||||
"""Dispatches ServiceEvent notifications to configured targets."""
|
||||
|
||||
@@ -90,18 +105,11 @@ class NotificationDispatcher:
|
||||
self._url_cache = url_cache
|
||||
self._asset_cache = asset_cache
|
||||
# Optional shared session owned by the caller; when supplied we reuse
|
||||
# its connection pool instead of opening a fresh per-dispatch session
|
||||
# (saves a TLS handshake per outbound call).
|
||||
# its connection pool instead of opening a fresh per-dispatch session.
|
||||
self._shared_session = session
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
|
||||
"""Yield an aiohttp session, reusing the shared one if provided.
|
||||
|
||||
When a shared session was passed in ``__init__`` we yield it without
|
||||
closing (the caller owns its lifetime). Otherwise we open a
|
||||
short-lived session with our default timeout and close it on exit.
|
||||
"""
|
||||
if self._shared_session is not None and not self._shared_session.closed:
|
||||
yield self._shared_session
|
||||
return
|
||||
@@ -115,11 +123,9 @@ class NotificationDispatcher:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Send event notification to all targets.
|
||||
|
||||
Returns list of results (one per target).
|
||||
Returns one result per target. Per-target failures are isolated;
|
||||
a single bad target cannot poison the batch.
|
||||
"""
|
||||
# Bind a dispatch_id so every log line emitted by the target sends
|
||||
# (including deep in TelegramClient) can be correlated to the same
|
||||
# upstream event.
|
||||
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
||||
|
||||
with bind_log_context(dispatch_id=new_id):
|
||||
@@ -128,20 +134,36 @@ class NotificationDispatcher:
|
||||
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
||||
getattr(event, "collection_name", None), len(targets),
|
||||
)
|
||||
|
||||
sem = asyncio.Semaphore(_DISPATCH_CONCURRENCY)
|
||||
|
||||
async def run_one(t: TargetConfig) -> dict[str, Any]:
|
||||
async with sem:
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._send_to_target(event, t),
|
||||
timeout=_TARGET_TIMEOUT_S,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Target dispatch timed out after {_TARGET_TIMEOUT_S}s",
|
||||
}
|
||||
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
*[run_one(t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
results: list[dict[str, Any]] = []
|
||||
failures = 0
|
||||
for target, raw in zip(targets, raw_results):
|
||||
if isinstance(raw, Exception):
|
||||
failures += 1
|
||||
_LOGGER.error(
|
||||
"Dispatch to target type=%s failed: %s",
|
||||
target.type, raw, exc_info=raw,
|
||||
target.type, redact_exc(raw),
|
||||
)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
results.append({"success": False, "error": redact_exc(raw)})
|
||||
else:
|
||||
if isinstance(raw, dict) and not raw.get("success"):
|
||||
failures += 1
|
||||
@@ -155,7 +177,6 @@ class NotificationDispatcher:
|
||||
def _resolve_template(
|
||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||
) -> str:
|
||||
"""Resolve template string for an event, with locale fallback."""
|
||||
template_str = DEFAULT_TEMPLATE
|
||||
if target.template_slots:
|
||||
locale_map = target.template_slots.get(event.event_type.value)
|
||||
@@ -166,7 +187,6 @@ class NotificationDispatcher:
|
||||
def _render_message(
|
||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||
) -> str:
|
||||
"""Resolve template and render message for a given locale."""
|
||||
template_str = self._resolve_template(event, target, locale)
|
||||
ctx = build_template_context(
|
||||
event, target_type=target.type,
|
||||
@@ -179,7 +199,6 @@ class NotificationDispatcher:
|
||||
self, receiver: Receiver, default_message: str,
|
||||
event: ServiceEvent, target: TargetConfig,
|
||||
) -> str:
|
||||
"""Return per-receiver message, re-rendering if receiver has a different locale."""
|
||||
if receiver.locale and receiver.locale != target.locale:
|
||||
return self._render_message(event, target, receiver.locale)
|
||||
return default_message
|
||||
@@ -187,21 +206,16 @@ class NotificationDispatcher:
|
||||
async def _send_to_target(
|
||||
self, event: ServiceEvent, target: TargetConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Send event to a single target (potentially multiple receivers)."""
|
||||
"""Dispatch to a single target via the registered handler."""
|
||||
default_message = self._render_message(event, target, target.locale)
|
||||
send_method = _PROVIDER_HANDLERS.get(target.type)
|
||||
if send_method is None:
|
||||
return {"success": False, "error": f"Unknown target type: {target.type}"}
|
||||
return await send_method(self, target, default_message, event)
|
||||
|
||||
send_method = {
|
||||
"telegram": self._send_telegram,
|
||||
"webhook": self._send_webhook,
|
||||
"email": self._send_email,
|
||||
"discord": self._send_discord,
|
||||
"slack": self._send_slack,
|
||||
"ntfy": self._send_ntfy,
|
||||
"matrix": self._send_matrix,
|
||||
}.get(target.type)
|
||||
if send_method:
|
||||
return await send_method(target, default_message, event)
|
||||
return {"success": False, "error": f"Unknown target type: {target.type}"}
|
||||
# ------------------------------------------------------------------
|
||||
# Asset preload (Telegram-specific)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _preload_asset_data(
|
||||
self,
|
||||
@@ -210,36 +224,13 @@ class NotificationDispatcher:
|
||||
session: aiohttp.ClientSession,
|
||||
max_size: int | None,
|
||||
) -> None:
|
||||
"""Download each non-cached asset's bytes once and attach to the entry.
|
||||
|
||||
Three benefits:
|
||||
* ``TelegramClient`` sees ``entry["data"]`` and skips its own download,
|
||||
so we don't fetch each URL twice.
|
||||
* We know the exact upload size, which lets the oversize warning in
|
||||
the rendered text compare against real bytes (for Immich videos,
|
||||
the transcoded ``/video/playback``), not the original ``file_size``.
|
||||
* Assets already in the Telegram file_id cache are skipped, and their
|
||||
stored size (if any) is used to populate ``playback_size`` — so
|
||||
templates see consistent sizes for repeat sends without re-download.
|
||||
|
||||
Entries whose download fails or exceeds ``max_size`` are left without
|
||||
``data``; ``TelegramClient`` will then fall back to its own download
|
||||
path and apply the same checks — no regression, just no preload win.
|
||||
|
||||
Concurrency is bounded by ``_PRELOAD_CONCURRENCY`` so peak memory
|
||||
stays predictable: at most N assets worth of bytes held in RAM at
|
||||
once, regardless of ``max_media_to_send``. Total wall-clock is
|
||||
unchanged for small batches and only marginally slower for large
|
||||
ones (most assets fit in a single RTT and SSL negotiation cost
|
||||
dominates, so 6-way parallelism is sufficient).
|
||||
"""
|
||||
"""Download each non-cached asset's bytes once, with SSRF guard."""
|
||||
if not assets:
|
||||
return
|
||||
|
||||
sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY)
|
||||
|
||||
async def _fetch(entry: dict[str, Any], media: Any) -> None:
|
||||
# Cache hit → skip download; populate playback_size from stored size.
|
||||
async def fetch(entry: dict[str, Any], media: Any) -> None:
|
||||
cache, key = self._cache_for_entry(entry)
|
||||
if cache and key:
|
||||
cached = cache.get(key)
|
||||
@@ -251,28 +242,40 @@ class NotificationDispatcher:
|
||||
|
||||
url = entry["url"]
|
||||
headers = entry.get("headers") or {}
|
||||
try:
|
||||
# Defense-in-depth: validate even though TelegramClient
|
||||
# also validates. The dispatcher is what triggers the
|
||||
# download, so the guard belongs here too.
|
||||
await avalidate_outbound_url(url)
|
||||
except UnsafeURLError as err:
|
||||
_LOGGER.warning(
|
||||
"Asset preload skipped: unsafe URL (%s)", redact_exc(err),
|
||||
)
|
||||
return
|
||||
async with sem:
|
||||
try:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
return
|
||||
data = await resp.read()
|
||||
except aiohttp.ClientError:
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
|
||||
return
|
||||
if max_size is not None and len(data) > max_size:
|
||||
return
|
||||
entry["data"] = data
|
||||
media.extra["playback_size"] = len(data)
|
||||
|
||||
await asyncio.gather(*(_fetch(e, m) for e, m in zip(assets, media_assets)))
|
||||
raw = await asyncio.gather(
|
||||
*(fetch(e, m) for e, m in zip(assets, media_assets)),
|
||||
return_exceptions=True,
|
||||
)
|
||||
for r in raw:
|
||||
if isinstance(r, Exception):
|
||||
_LOGGER.warning("Asset preload raised: %s", redact_exc(r))
|
||||
|
||||
def _cache_for_entry(
|
||||
self, entry: dict[str, Any],
|
||||
) -> tuple[TelegramFileCache | None, str | None]:
|
||||
"""Resolve (cache, key) for an asset entry — mirrors TelegramClient logic.
|
||||
|
||||
Returns (None, None) if no cache is configured or no key can be derived.
|
||||
"""
|
||||
cache_key = entry.get("cache_key")
|
||||
if cache_key:
|
||||
cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache
|
||||
@@ -287,6 +290,10 @@ class NotificationDispatcher:
|
||||
return self._url_cache, url
|
||||
return None, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Per-provider handlers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_telegram(
|
||||
self, target: TargetConfig, default_message: str, event: ServiceEvent
|
||||
) -> dict[str, Any]:
|
||||
@@ -296,27 +303,25 @@ class NotificationDispatcher:
|
||||
max_media = target.config.get("max_media_to_send", 50)
|
||||
max_group = target.config.get("max_media_per_group", 10)
|
||||
chunk_delay = target.config.get("media_delay", 500)
|
||||
max_size = target.config.get("max_asset_size")
|
||||
if max_size:
|
||||
max_size = max_size * 1024 * 1024 # MB to bytes
|
||||
max_size_mb = target.config.get("max_asset_size")
|
||||
max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None
|
||||
send_large_as_docs = target.config.get("send_large_photos_as_documents", False)
|
||||
|
||||
if not bot_token:
|
||||
return {"success": False, "error": "Missing bot_token"}
|
||||
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
# Prepare assets list once (shared across receivers)
|
||||
# Prefer internal URL for fetching (LAN speed vs public internet)
|
||||
internal_url = (target.provider_internal_url or "").rstrip("/")
|
||||
external_url = (target.provider_external_url or "").rstrip("/")
|
||||
assets = []
|
||||
media_assets: list[Any] = [] # aligned with `assets` for preload
|
||||
assets: list[dict[str, Any]] = []
|
||||
media_assets: list[Any] = []
|
||||
for asset in event.added_assets[:max_media]:
|
||||
url = asset.preview_url or asset.thumbnail_url or asset.full_url
|
||||
if not url:
|
||||
continue
|
||||
asset_entry = build_telegram_asset_entry(
|
||||
url=url or "",
|
||||
url=url,
|
||||
media_type=asset.type.value,
|
||||
api_key=target.provider_api_key,
|
||||
internal_url=internal_url,
|
||||
@@ -327,26 +332,15 @@ class NotificationDispatcher:
|
||||
assets.append(asset_entry)
|
||||
media_assets.append(asset)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
# Preload all asset bytes once so (a) TelegramClient can skip its
|
||||
# own download and (b) we know exact upload sizes in time for the
|
||||
# oversize warning in the rendered text.
|
||||
await self._preload_asset_data(assets, media_assets, session, max_size)
|
||||
default_message = self._render_message(event, target, target.locale)
|
||||
await self._preload_asset_data(assets, media_assets, session, max_size_bytes)
|
||||
|
||||
# Asset cache (when in thumbhash mode) invalidates entries when the
|
||||
# asset's visual content changes. The resolver maps asset id → its
|
||||
# current thumbhash. Providers that expose thumbhash put it in
|
||||
# ``asset.extra["thumbhash"]`` (currently Immich).
|
||||
thumbhash_map = {
|
||||
asset.id: asset.extra.get("thumbhash")
|
||||
for asset in event.added_assets
|
||||
if asset.extra.get("thumbhash")
|
||||
}
|
||||
thumbhash_resolver = (
|
||||
(lambda key: thumbhash_map.get(key)) if thumbhash_map else None
|
||||
)
|
||||
thumbhash_resolver = thumbhash_map.get if thumbhash_map else None
|
||||
|
||||
client = TelegramClient(
|
||||
session, bot_token,
|
||||
@@ -355,39 +349,51 @@ class NotificationDispatcher:
|
||||
thumbhash_resolver=thumbhash_resolver,
|
||||
)
|
||||
|
||||
for receiver in target.receivers:
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
|
||||
results.append({"success": False, "error": "Invalid telegram receiver"})
|
||||
continue
|
||||
|
||||
return {"success": False, "error": "Invalid telegram receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
|
||||
text_result = await client.send_message(
|
||||
chat_id=receiver.chat_id,
|
||||
text=message,
|
||||
disable_web_page_preview=bool(disable_preview),
|
||||
)
|
||||
if not text_result.get("success"):
|
||||
_LOGGER.warning("Failed to send to chat %s: %s", receiver.chat_id, text_result.get("error"))
|
||||
results.append(text_result)
|
||||
continue
|
||||
_LOGGER.warning(
|
||||
"Failed to send to chat %s: %s",
|
||||
receiver.chat_id, text_result.get("error"),
|
||||
)
|
||||
return text_result
|
||||
|
||||
if assets:
|
||||
reply_to = text_result.get("message_id")
|
||||
media_result = await client.send_notification(
|
||||
chat_id=receiver.chat_id,
|
||||
assets=assets,
|
||||
reply_to_message_id=reply_to,
|
||||
reply_to_message_id=text_result.get("message_id"),
|
||||
max_group_size=max_group,
|
||||
chunk_delay=chunk_delay,
|
||||
max_asset_data_size=max_size,
|
||||
max_asset_data_size=max_size_bytes,
|
||||
send_large_photos_as_documents=send_large_as_docs,
|
||||
chat_action=chat_action or None,
|
||||
)
|
||||
if not media_result.get("success"):
|
||||
_LOGGER.warning("Text sent OK but media failed for chat %s: %s", receiver.chat_id, media_result.get("error"))
|
||||
_LOGGER.warning(
|
||||
"Text sent OK but media failed for chat %s: %s",
|
||||
receiver.chat_id, media_result.get("error"),
|
||||
)
|
||||
# Preserve both outcomes — text succeeded, media
|
||||
# didn't. Operators losing media-failure detail
|
||||
# in the result dict made root-cause analysis
|
||||
# impossible.
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": text_result.get("message_id"),
|
||||
"media_error": media_result.get("error"),
|
||||
"media_failed_at_chunk": media_result.get("failed_at_chunk"),
|
||||
}
|
||||
return text_result
|
||||
|
||||
results.append(text_result)
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
@@ -397,17 +403,10 @@ class NotificationDispatcher:
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
for receiver in target.receivers:
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
||||
results.append({"success": False, "error": "Invalid webhook receiver"})
|
||||
continue
|
||||
try:
|
||||
await avalidate_outbound_url(receiver.url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid webhook receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
payload = {
|
||||
"message": message,
|
||||
@@ -417,8 +416,10 @@ class NotificationDispatcher:
|
||||
"collection_id": event.collection_id,
|
||||
"timestamp": event.timestamp.isoformat(),
|
||||
}
|
||||
client = WebhookClient(session, receiver.url, receiver.headers)
|
||||
results.append(await client.send(payload))
|
||||
client = WebhookClient(session, receiver.url, safe_headers(receiver.headers))
|
||||
return await client.send(payload)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
@@ -431,7 +432,7 @@ class NotificationDispatcher:
|
||||
if not smtp_cfg.get("host"):
|
||||
return {"success": False, "error": "SMTP not configured"}
|
||||
|
||||
client = EmailClient(SmtpConfig(
|
||||
email_client = EmailClient(SmtpConfig(
|
||||
host=smtp_cfg["host"],
|
||||
port=int(smtp_cfg.get("port", 587)),
|
||||
username=smtp_cfg.get("username", ""),
|
||||
@@ -439,27 +440,28 @@ class NotificationDispatcher:
|
||||
from_address=smtp_cfg.get("from_address", ""),
|
||||
from_name=smtp_cfg.get("from_name", "Notify Bridge"),
|
||||
use_tls=smtp_cfg.get("use_tls", True),
|
||||
tls_mode=smtp_cfg.get("tls_mode", "auto"),
|
||||
))
|
||||
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}"
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for receiver in target.receivers:
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, EmailReceiver) or not receiver.email:
|
||||
results.append({"success": False, "error": "Invalid email receiver"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid email receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
result = await client.send(
|
||||
# body_html=None lets EmailClient build a safely-escaped HTML
|
||||
# alternative from body_text instead of trusting user content.
|
||||
return await email_client.send(
|
||||
to_email=receiver.email,
|
||||
subject=subject,
|
||||
body_text=message,
|
||||
body_html=message,
|
||||
body_html=None,
|
||||
to_name=receiver.name,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
return self._aggregate_results(results)
|
||||
|
||||
async def _send_discord(
|
||||
@@ -471,20 +473,16 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
client = DiscordClient(session)
|
||||
for receiver in target.receivers:
|
||||
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid discord receiver"})
|
||||
continue
|
||||
try:
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid discord receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
results.append(await client.send(receiver.webhook_url, message, username=username))
|
||||
return await client.send(receiver.webhook_url, message, username=username)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
@@ -497,20 +495,16 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
client = SlackClient(session)
|
||||
for receiver in target.receivers:
|
||||
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid slack receiver"})
|
||||
continue
|
||||
try:
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid slack receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
results.append(await client.send(receiver.webhook_url, message, username=username))
|
||||
return await client.send(receiver.webhook_url, message, username=username)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
@@ -526,22 +520,23 @@ class NotificationDispatcher:
|
||||
try:
|
||||
await avalidate_outbound_url(server_url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
||||
return {"success": False, "error": f"Unsafe ntfy server_url: {redact_exc(err)}"}
|
||||
|
||||
title = f"{event.event_type.value}: {event.collection_name}"
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
client = NtfyClient(session)
|
||||
for receiver in target.receivers:
|
||||
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
||||
results.append({"success": False, "error": "Invalid ntfy receiver"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid ntfy receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
results.append(await client.send(
|
||||
return await client.send(
|
||||
server_url, receiver.topic, message,
|
||||
title=title, priority=receiver.priority, auth_token=auth_token,
|
||||
))
|
||||
)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
@@ -557,33 +552,108 @@ class NotificationDispatcher:
|
||||
try:
|
||||
await avalidate_outbound_url(homeserver)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {redact_exc(err)}"}
|
||||
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with self._session_ctx() as session:
|
||||
client = MatrixClient(session, homeserver, access_token)
|
||||
for receiver in target.receivers:
|
||||
|
||||
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
||||
results.append({"success": False, "error": "Invalid matrix receiver"})
|
||||
continue
|
||||
return {"success": False, "error": "Invalid matrix receiver"}
|
||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||
results.append(await client.send_message(
|
||||
receiver.room_id, message, html_message=message,
|
||||
))
|
||||
# html_message=None: only the plain-text message is passed;
|
||||
# it is no longer aliased as the HTML body.
|
||||
# If templates emit HTML in the future, generate a
|
||||
# separate HTML body upstream rather than aliasing here.
|
||||
return await client.send_message(
|
||||
receiver.room_id, message, html_message=None,
|
||||
)
|
||||
|
||||
results = await self._fan_out(target.receivers, send_one)
|
||||
|
||||
return self._aggregate_results(results)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Aggregation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
async def _fan_out(
    receivers: list[Receiver],
    send_one: Callable[[Receiver], Awaitable[dict[str, Any]]],
) -> list[dict[str, Any]]:
    """Call ``send_one`` for every receiver, at most ``_RECEIVER_CONCURRENCY`` at a time.

    An exception raised by ``send_one`` is logged (redacted) and turned
    into a ``{"success": False, "error": ...}`` entry, so one failing
    receiver never aborts the sends to the others.

    Returns one result dict per receiver, in the order of ``receivers``.
    """
    limiter = asyncio.Semaphore(_RECEIVER_CONCURRENCY)

    async def _deliver(one_receiver: Receiver) -> dict[str, Any]:
        # Explicit acquire/release: the slot is held for the whole send
        # and always handed back, whether the send returns or raises.
        await limiter.acquire()
        try:
            return await send_one(one_receiver)
        except Exception as exc:  # noqa: BLE001
            _LOGGER.error("Receiver send raised: %s", redact_exc(exc))
            return {"success": False, "error": redact_exc(exc)}
        finally:
            limiter.release()

    pending = [_deliver(one_receiver) for one_receiver in receivers]
    return await asyncio.gather(*pending)
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Aggregate broadcast results into a single result dict."""
|
||||
"""Aggregate per-receiver results into a single target-level result.
|
||||
|
||||
Preserves the per-receiver detail under ``receivers`` so a caller
|
||||
can see exactly which receivers failed, instead of getting only
|
||||
the first error.
|
||||
"""
|
||||
if not results:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
successes = sum(1 for r in results if r.get("success"))
|
||||
if successes == len(results) and results:
|
||||
return {"success": True, "receivers": len(results)}
|
||||
elif successes > 0:
|
||||
return {"success": True, "receivers": len(results), "partial_failures": len(results) - successes}
|
||||
elif results:
|
||||
return results[0]
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
failures = len(results) - successes
|
||||
out: dict[str, Any] = {
|
||||
"success": successes > 0,
|
||||
"receivers": len(results),
|
||||
"successes": successes,
|
||||
"failures": failures,
|
||||
"results": results,
|
||||
}
|
||||
if failures:
|
||||
out["errors"] = [
|
||||
r.get("error") for r in results if not r.get("success")
|
||||
]
|
||||
if successes == 0:
|
||||
# Surface the first error at the top level for back-compat
|
||||
# with callers that only check ``error``.
|
||||
out["error"] = results[0].get("error", "All receivers failed")
|
||||
return out
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Provider registry — replaces the if/elif chain so adding a provider
|
||||
# means just registering it here, not editing dispatch logic.
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
# Maps a provider name to the dispatcher method that delivers to it.
_PROVIDER_HANDLERS: dict[str, _SendMethod] = dict(
    telegram=NotificationDispatcher._send_telegram,
    webhook=NotificationDispatcher._send_webhook,
    email=NotificationDispatcher._send_email,
    discord=NotificationDispatcher._send_discord,
    slack=NotificationDispatcher._send_slack,
    ntfy=NotificationDispatcher._send_ntfy,
    matrix=NotificationDispatcher._send_matrix,
)
|
||||
|
||||
|
||||
def register_provider(name: str, handler: _SendMethod) -> None:
    """Add (or replace) a provider handler in the dispatcher registry.

    This is the extension point for out-of-tree providers: no fork of
    the dispatcher is needed. Registering a name that already exists
    overwrites the previous handler.

    Args:
        name: Provider name used as the registry key.
        handler: Coroutine function with the
            ``async (dispatcher, target, default_message, event) -> dict``
            shape.
    """
    _PROVIDER_HANDLERS.update({name: handler})
|
||||
|
||||
@@ -2,14 +2,32 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
import ssl
|
||||
from dataclasses import dataclass
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Any
|
||||
from email.headerregistry import Address
|
||||
from email.message import EmailMessage
|
||||
from typing import Any, Final, Literal
|
||||
|
||||
try: # Optional dependency — fail at first send rather than at import.
|
||||
import aiosmtplib
|
||||
from aiosmtplib import SMTPException
|
||||
except ImportError: # pragma: no cover
|
||||
aiosmtplib = None # type: ignore[assignment]
|
||||
|
||||
class SMTPException(Exception): # type: ignore[no-redef]
|
||||
pass
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_TIMEOUT_S: Final = 30.0
|
||||
_TlsMode = Literal["auto", "implicit", "starttls", "none"]
|
||||
# RFC 5322 lite: catches the obvious-bad addresses ("foo bar", "no-at",
|
||||
# embedded CRLF) without pretending to fully validate addresses.
|
||||
_EMAIL_RE: Final = re.compile(r"^[^\s@\r\n,;<>]+@[^\s@\r\n,;<>]+\.[^\s@\r\n,;<>]+$")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmtpConfig:
|
||||
@@ -22,6 +40,55 @@ class SmtpConfig:
|
||||
from_address: str = ""
|
||||
from_name: str = "Notify Bridge"
|
||||
use_tls: bool = True
|
||||
# Explicit TLS mode. ``auto`` (back-compat) infers from ``use_tls`` and
|
||||
# ``port``: 465 → implicit; 587 with use_tls=False → starttls; 25 → none.
|
||||
tls_mode: _TlsMode = "auto"
|
||||
timeout_s: float = _DEFAULT_TIMEOUT_S
|
||||
|
||||
|
||||
def _strip_header(value: str) -> str:
|
||||
"""Reject CRLF and bare CR/LF in header-bound strings.
|
||||
|
||||
SMTP header injection turns user-controlled subject/name strings into
|
||||
arbitrary headers (``\\r\\nBcc: attacker@x``). The Python stdlib
|
||||
accepts CRLF when followed by SP/HT (header folding), so explicit
|
||||
sanitization is required even though :class:`EmailMessage` does some
|
||||
validation of its own.
|
||||
"""
|
||||
return re.sub(r"[\r\n]+", " ", str(value or "")).strip()
|
||||
|
||||
|
||||
def _validate_email(addr: str) -> str:
    """Return the sanitized address, or raise if it is unusable.

    The input first goes through ``_strip_header``; the cleaned value must
    then be non-empty and match ``_EMAIL_RE``.

    Raises:
        ValueError: if the cleaned address is empty or does not match.
    """
    cleaned = _strip_header(addr)
    if cleaned == "":
        raise ValueError("email address is empty")
    if _EMAIL_RE.match(cleaned) is None:
        raise ValueError("email address is invalid")
    return cleaned
|
||||
|
||||
|
||||
def _resolve_tls(cfg: SmtpConfig) -> tuple[bool, bool]:
|
||||
"""Resolve ``(use_tls, start_tls)`` flags from the config.
|
||||
|
||||
``tls_mode`` overrides ``use_tls``/port heuristics when provided.
|
||||
"""
|
||||
mode = cfg.tls_mode
|
||||
if mode == "implicit":
|
||||
return True, False
|
||||
if mode == "starttls":
|
||||
return False, True
|
||||
if mode == "none":
|
||||
return False, False
|
||||
# auto — preserve the historical "use_tls bool + port heuristic" behavior
|
||||
# but make the path explicit.
|
||||
if cfg.use_tls:
|
||||
return True, False
|
||||
return False, cfg.port != 25
|
||||
|
||||
|
||||
def _to_html(text: str) -> str:
|
||||
"""Convert plain text to a minimal HTML body, escaped for safety."""
|
||||
return "<html><body><pre>" + html.escape(text) + "</pre></body></html>"
|
||||
|
||||
|
||||
class EmailClient:
|
||||
@@ -30,30 +97,39 @@ class EmailClient:
|
||||
def __init__(self, smtp_config: SmtpConfig) -> None:
|
||||
self._config = smtp_config
|
||||
|
||||
@staticmethod
|
||||
def _ssl_context() -> ssl.SSLContext:
|
||||
# Explicit context so the TLS posture is auditable; aiosmtplib
|
||||
# defaults look correct today but past regressions (and downstream
|
||||
# repackaging) make implicit reliance fragile.
|
||||
return ssl.create_default_context()
|
||||
|
||||
async def verify_connection(self) -> dict[str, Any]:
    """Test SMTP connection and authentication without sending an email.

    Connects using the TLS flags resolved from the config, logs in when
    both a username and a password are configured, then quits.

    Returns:
        ``{"success": True}`` when the handshake (and login, if any)
        succeeded; otherwise ``{"success": False, "error": "..."}``.
        SMTP and OS-level errors are reported in the result dict rather
        than raised.
    """
    if aiosmtplib is None:
        return {"success": False, "error": "aiosmtplib not installed"}

    cfg = self._config
    if not cfg.host:
        return {"success": False, "error": "SMTP host not configured"}

    use_tls, start_tls = _resolve_tls(cfg)
    try:
        smtp = aiosmtplib.SMTP(
            hostname=cfg.host,
            port=cfg.port,
            use_tls=use_tls,
            start_tls=start_tls,
            tls_context=self._ssl_context(),
            timeout=cfg.timeout_s,
            validate_certs=True,
        )
        try:
            await smtp.connect()
            if cfg.username and cfg.password:
                await smtp.login(cfg.username, cfg.password)
            await smtp.quit()
        except BaseException:
            # A rejected login (the most common verification failure)
            # would otherwise leave the connection open until garbage
            # collection; release the socket before propagating.
            smtp.close()
            raise
        return {"success": True}
    except (SMTPException, OSError) as e:
        _LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e)
        return {"success": False, "error": str(e)}
|
||||
|
||||
@@ -65,27 +141,52 @@ class EmailClient:
|
||||
body_html: str | None = None,
|
||||
to_name: str = "",
|
||||
) -> dict[str, Any]:
|
||||
"""Send an email. Returns {"success": True} or {"success": False, "error": "..."}."""
|
||||
try:
|
||||
import aiosmtplib
|
||||
except ImportError:
|
||||
"""Send an email.
|
||||
|
||||
Returns ``{"success": True}`` or ``{"success": False, "error": "..."}``.
|
||||
|
||||
``body_html`` is treated as already-safe markup. Pass ``None`` to
|
||||
derive a safe HTML alternative from ``body_text`` automatically.
|
||||
"""
|
||||
if aiosmtplib is None:
|
||||
return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"}
|
||||
|
||||
cfg = self._config
|
||||
|
||||
if not cfg.host or not cfg.from_address:
|
||||
return {"success": False, "error": "SMTP not configured (missing host or from_address)"}
|
||||
|
||||
# Build email message
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = f"{cfg.from_name} <{cfg.from_address}>" if cfg.from_name else cfg.from_address
|
||||
msg["To"] = f"{to_name} <{to_email}>" if to_name else to_email
|
||||
msg["Subject"] = subject
|
||||
try:
|
||||
to_addr = _validate_email(to_email)
|
||||
from_addr = _validate_email(cfg.from_address)
|
||||
except ValueError as exc:
|
||||
return {"success": False, "error": f"Invalid email address: {exc}"}
|
||||
|
||||
msg.attach(MIMEText(body_text, "plain", "utf-8"))
|
||||
if body_html:
|
||||
msg.attach(MIMEText(body_html, "html", "utf-8"))
|
||||
# EmailMessage with structured Address objects rejects CRLF and
|
||||
# framework-folds long headers safely. We still strip first because
|
||||
# EmailMessage's display-name slot is a pure string.
|
||||
msg = EmailMessage()
|
||||
from_display = _strip_header(cfg.from_name) or ""
|
||||
to_display = _strip_header(to_name) or ""
|
||||
try:
|
||||
from_user, _, from_domain = from_addr.partition("@")
|
||||
to_user, _, to_domain = to_addr.partition("@")
|
||||
msg["From"] = Address(from_display, from_user, from_domain) if from_display else from_addr
|
||||
msg["To"] = Address(to_display, to_user, to_domain) if to_display else to_addr
|
||||
except ValueError as exc:
|
||||
return {"success": False, "error": f"Invalid email address: {exc}"}
|
||||
msg["Subject"] = _strip_header(subject)
|
||||
|
||||
msg.set_content(body_text or "", subtype="plain", charset="utf-8")
|
||||
# If the caller provided HTML explicitly, honor it; otherwise build a
|
||||
# safe escaped version so a stray "<" in the rendered template can't
|
||||
# break the markup.
|
||||
msg.add_alternative(
|
||||
body_html if body_html is not None else _to_html(body_text or ""),
|
||||
subtype="html",
|
||||
charset="utf-8",
|
||||
)
|
||||
|
||||
use_tls, start_tls = _resolve_tls(cfg)
|
||||
try:
|
||||
await aiosmtplib.send(
|
||||
msg,
|
||||
@@ -93,11 +194,14 @@ class EmailClient:
|
||||
port=cfg.port,
|
||||
username=cfg.username or None,
|
||||
password=cfg.password or None,
|
||||
use_tls=cfg.use_tls,
|
||||
start_tls=not cfg.use_tls and cfg.port != 25,
|
||||
use_tls=use_tls,
|
||||
start_tls=start_tls,
|
||||
tls_context=self._ssl_context(),
|
||||
timeout=cfg.timeout_s,
|
||||
validate_certs=True,
|
||||
)
|
||||
_LOGGER.info("Email sent to %s", to_email)
|
||||
_LOGGER.info("Email sent to %s", to_addr)
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
_LOGGER.error("Failed to send email to %s: %s", to_email, e)
|
||||
except (SMTPException, OSError) as e:
|
||||
_LOGGER.error("Failed to send email to %s: %s", to_addr, e)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Shared HTTP infrastructure for notification provider clients.
|
||||
|
||||
Slack/Discord/ntfy/Matrix/Webhook all follow the same pattern: build a
|
||||
JSON payload, POST/PUT it, decode 200-range as success, decode 4xx/5xx
|
||||
into a stable error dict, and retry transient 429/503 responses with a
|
||||
capped ``Retry-After``. ``HttpProviderClient`` centralizes that pattern
|
||||
so every provider gets the same SSRF guard, timeouts, secret-redacted
|
||||
errors, and bounded retry policy by construction — adding a new
|
||||
provider doesn't get to forget any one of them.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Final, Mapping
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .redact import redact, redact_exc
|
||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)

# Default request timeout: 30 s overall, of which at most 10 s to connect.
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(connect=10, total=30)
# Class-level default for ``HttpProviderClient._max_retries``.
_MAX_RETRIES: Final = 3
# Upper bound, in seconds, applied to a server-provided ``Retry-After``.
_MAX_RETRY_AFTER_S: Final = 60.0
# HTTP statuses treated as transient and therefore retried.
_RETRY_STATUSES: Final = frozenset((429, 503))
|
||||
|
||||
# Hop-by-hop / framing headers a caller must not be able to override via
|
||||
# user-supplied target config. Letting them through enables request
|
||||
# smuggling, host-header bypasses of WAFs, and cache poisoning.
|
||||
_FORBIDDEN_HEADERS: Final = frozenset({
|
||||
"host",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"te",
|
||||
"upgrade",
|
||||
"expect",
|
||||
"proxy-authorization",
|
||||
"proxy-connection",
|
||||
})
|
||||
|
||||
|
||||
def safe_headers(headers: Mapping[str, str] | None) -> dict[str, str]:
|
||||
"""Return a copy of ``headers`` with hop-by-hop/forbidden names dropped
|
||||
and CRLF-bearing values rejected.
|
||||
|
||||
A target config that lets a user inject ``"X-Foo": "bar\\r\\nHost: evil"``
|
||||
can perform request smuggling depending on aiohttp's framing. We strip
|
||||
those values rather than letting them reach the wire.
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
safe: dict[str, str] = {}
|
||||
for raw_name, raw_value in headers.items():
|
||||
name = str(raw_name).strip()
|
||||
if not name or name.lower() in _FORBIDDEN_HEADERS:
|
||||
continue
|
||||
if any(c in name for c in "\r\n:"):
|
||||
continue
|
||||
value = str(raw_value)
|
||||
if "\r" in value or "\n" in value:
|
||||
continue
|
||||
safe[name] = value
|
||||
return safe
|
||||
|
||||
|
||||
def make_error(message: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]:
    """Assemble the failure dict every provider client returns.

    Args:
        message: Human-readable error text; passed through ``redact``.
        status: Optional HTTP status; stored under ``"status_code"`` when
            it is not ``None``.
        body: Optional response body; when non-empty it is passed through
            ``redact`` and then cut to 200 characters under ``"body"``.

    Returns:
        ``{"success": False, "error": ...}`` plus the optional keys above.
    """
    error_text = redact(message)
    optional: dict[str, Any] = {}
    if status is not None:
        optional["status_code"] = status
    if body:
        # Redact first, then truncate, so a secret straddling the cut-off
        # is still matched in full.
        optional["body"] = redact(body)[:200]
    return {"success": False, "error": error_text, **optional}
|
||||
|
||||
|
||||
def make_success(**extra: Any) -> dict[str, Any]:
    """Build a stable success dict shape used by every provider client.

    Args:
        **extra: Additional keys to include in the result (for example
            ``status_code`` or ``json``).

    Returns:
        A new dict containing ``extra`` with ``"success"`` forced to
        ``True``. A ``success`` keyword in ``extra`` cannot override the
        flag, so the result of this helper always reads as a success.
    """
    out: dict[str, Any] = {"success": True}
    out.update(extra)
    # Re-assert after the merge: a stray ``success=...`` keyword must not
    # turn a success result into something else.
    out["success"] = True
    return out
|
||||
|
||||
|
||||
def _retry_after_seconds(headers: Mapping[str, str], cap_s: float) -> float:
|
||||
raw = headers.get("Retry-After") or headers.get("retry-after") or "2"
|
||||
try:
|
||||
seconds = float(raw)
|
||||
except (TypeError, ValueError):
|
||||
seconds = 2.0
|
||||
return max(0.0, min(seconds, cap_s))
|
||||
|
||||
|
||||
class HttpProviderClient:
    """Base for JSON-over-HTTP notification providers.

    Subclasses call :meth:`request` instead of using ``self._session``
    directly. ``request`` runs the SSRF guard (skippable for known-safe
    upstreams via ``ssrf_validate=False``), enforces a per-request
    timeout, retries 429/503 with a capped ``Retry-After``, and turns
    transport/HTTP errors into the canonical ``{"success": False, ...}``
    shape with secrets redacted.
    """

    # Total number of attempts per request (first try included).
    _max_retries: int = _MAX_RETRIES
    # Settable per-instance so tests / hostile-upstream tuning can
    # tighten the cap. Reads of this attribute fall through to the
    # class default when no instance value has been set.
    _MAX_RETRY_AFTER: float = _MAX_RETRY_AFTER_S

    def __init__(
        self,
        session: aiohttp.ClientSession,
        *,
        timeout: aiohttp.ClientTimeout | None = None,
        provider_name: str = "http",
    ) -> None:
        """Store the shared session and per-client settings.

        Args:
            session: aiohttp session used for every request.
            timeout: Per-request timeout; the module default is used when
                omitted.
            provider_name: Label used as the prefix of retry log lines.
        """
        self._session = session
        self._timeout = timeout or _DEFAULT_TIMEOUT
        self._provider = provider_name

    async def request(
        self,
        method: str,
        url: str,
        *,
        json: Any = None,
        headers: Mapping[str, str] | None = None,
        ssrf_validate: bool = True,
        retry_statuses: frozenset[int] = _RETRY_STATUSES,
    ) -> dict[str, Any]:
        """Send a single request with retry + redaction.

        Args:
            method: HTTP method name.
            url: Target URL.
            json: Payload passed to aiohttp as the JSON body.
            headers: Extra headers; filtered through ``safe_headers``. A
                caller-supplied ``Content-Type`` (any casing) replaces the
                ``application/json`` default instead of being sent
                alongside it.
            ssrf_validate: When true, ``url`` must pass
                ``avalidate_outbound_url`` before anything is sent.
            retry_statuses: Status codes that trigger a delayed retry while
                attempts remain.

        Returns:
            On 2xx: ``{"success": True, "status_code": int, "json": payload}``
            where ``payload`` is the decoded JSON body, or the body text
            when it cannot be decoded as JSON. On an SSRF rejection, a
            non-2xx status, or a transport/timeout error: the canonical
            error dict from ``make_error``.
        """
        if ssrf_validate:
            try:
                await avalidate_outbound_url(url)
            except UnsafeURLError as err:
                return make_error(f"Unsafe URL: {redact_exc(err)}")

        sanitized = safe_headers(headers)
        outbound_headers: dict[str, str] = {}
        # Header names are case-insensitive on the wire; only add the
        # default when the caller did not provide their own in any casing,
        # otherwise two Content-Type headers would be sent.
        if not any(name.lower() == "content-type" for name in sanitized):
            outbound_headers["Content-Type"] = "application/json"
        outbound_headers.update(sanitized)

        # Always make at least one attempt, even if ``_max_retries`` was
        # tuned down to zero or below.
        max_attempts = max(1, self._max_retries)

        for attempt in range(1, max_attempts + 1):
            try:
                async with self._session.request(
                    method,
                    url,
                    json=json,
                    headers=outbound_headers,
                    timeout=self._timeout,
                    allow_redirects=False,
                ) as resp:
                    if resp.status in retry_statuses and attempt < max_attempts:
                        delay = _retry_after_seconds(resp.headers, self._MAX_RETRY_AFTER)
                        _LOGGER.warning(
                            "%s %s %s: HTTP %d, retrying after %.2fs (attempt %d/%d)",
                            self._provider, method, redact(url), resp.status,
                            delay, attempt, max_attempts,
                        )
                        await resp.read()  # drain body so connection can return to pool
                        await asyncio.sleep(delay)
                        continue
                    if 200 <= resp.status < 300:
                        try:
                            payload: Any = await resp.json(content_type=None)
                        except (aiohttp.ContentTypeError, ValueError):
                            # Not JSON (or not decodable): fall back to the
                            # text, replacing undecodable bytes so this
                            # path cannot raise a decode error itself.
                            payload = await resp.text(errors="replace")
                        return make_success(status_code=resp.status, json=payload)
                    body = await resp.text(errors="replace")
                    return make_error(
                        f"HTTP {resp.status}",
                        status=resp.status,
                        body=body,
                    )
            except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
                # asyncio.CancelledError inherits from BaseException on
                # 3.8+, so it is not caught here — good: cancellation
                # must propagate.
                if attempt < max_attempts and isinstance(err, asyncio.TimeoutError):
                    _LOGGER.warning(
                        "%s %s %s: timeout, retrying (attempt %d/%d)",
                        self._provider, method, redact(url),
                        attempt, max_attempts,
                    )
                    # Exponential backoff (1s, 2s, 4s, ...) capped at 5s.
                    await asyncio.sleep(min(2 ** (attempt - 1), 5))
                    continue
                return make_error(redact_exc(err))

        # Defensive fallback: each attempt above either returns or moves on
        # to a later attempt, and the last attempt always returns.
        return make_error("Rate limited (retries exhausted)")
|
||||
@@ -2,22 +2,36 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Final
|
||||
from urllib.parse import quote
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..http_base import _MAX_RETRY_AFTER_S, safe_headers
|
||||
from ..redact import redact, redact_exc
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Monotonically increasing transaction counter for idempotent sends
|
||||
_txn_counter = int(time.time() * 1000)
|
||||
# Matrix room IDs are ``!opaque:server.name`` per the spec. We also allow
|
||||
# the ``#alias:server`` form because some callers may pass aliases. The
|
||||
# pattern's purpose is to reject obvious path-injection (``/``, ``..``,
|
||||
# control chars, query/fragment chars) before the value reaches a URL.
|
||||
_ROOM_ID_RE: Final = re.compile(r"^[!#][^\x00-\x1f\s/?#]{1,255}:[A-Za-z0-9.\-:]{1,255}$")
|
||||
|
||||
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
_MAX_RETRIES: Final = 3
|
||||
|
||||
|
||||
def _next_txn_id() -> str:
|
||||
global _txn_counter
|
||||
_txn_counter += 1
|
||||
return str(_txn_counter)
|
||||
def _validate_room_id(room_id: str) -> str:
    """Check that ``room_id`` is well-formed before it is put in a URL.

    Args:
        room_id: Candidate room ID or alias.

    Returns:
        ``room_id`` unchanged when it is accepted.

    Raises:
        ValueError: If ``room_id`` is empty, is not a string, or does not
            match ``_ROOM_ID_RE`` in its entirety.
    """
    if not room_id:
        raise ValueError("room_id is empty")
    # ``fullmatch`` rather than ``match``: with ``match`` the pattern's
    # trailing ``$`` also matches just before a final newline, so a value
    # ending in "\n" would be accepted. The isinstance check keeps a
    # non-string from surfacing as TypeError instead of ValueError.
    if not isinstance(room_id, str) or not _ROOM_ID_RE.fullmatch(room_id):
        raise ValueError("room_id format is invalid")
    return room_id
|
||||
|
||||
|
||||
class MatrixClient:
|
||||
@@ -33,49 +47,67 @@ class MatrixClient:
|
||||
self._homeserver = homeserver_url.rstrip("/")
|
||||
self._token = access_token
|
||||
|
||||
@staticmethod
|
||||
def _txn_id() -> str:
|
||||
# uuid4 hex is collision-resistant across processes/restarts;
|
||||
# eliminates the previous module-level counter race.
|
||||
return uuid.uuid4().hex
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
room_id: str,
|
||||
message: str,
|
||||
html_message: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a text message to a Matrix room.
|
||||
"""Send a text message to a Matrix room."""
|
||||
try:
|
||||
room_id = _validate_room_id(room_id)
|
||||
except ValueError as exc:
|
||||
return {"success": False, "error": f"Invalid room_id: {exc}"}
|
||||
|
||||
Args:
|
||||
room_id: Internal room ID (e.g. !abc:matrix.org)
|
||||
message: Plain text body
|
||||
html_message: Optional HTML-formatted body
|
||||
"""
|
||||
if not room_id:
|
||||
return {"success": False, "error": "Missing room_id"}
|
||||
encoded_room = quote(room_id, safe="")
|
||||
url = (
|
||||
f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}"
|
||||
f"/send/m.room.message/{self._txn_id()}"
|
||||
)
|
||||
|
||||
txn_id = _next_txn_id()
|
||||
# URL-encode the room_id (! and : need encoding)
|
||||
encoded_room = room_id.replace("!", "%21").replace(":", "%3A")
|
||||
url = f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}"
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"msgtype": "m.text",
|
||||
"body": message,
|
||||
}
|
||||
body: dict[str, Any] = {"msgtype": "m.text", "body": message}
|
||||
if html_message:
|
||||
body["format"] = "org.matrix.custom.html"
|
||||
body["formatted_body"] = html_message
|
||||
|
||||
headers = {
|
||||
headers = safe_headers({
|
||||
"Authorization": f"Bearer {self._token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
})
|
||||
|
||||
try:
|
||||
async with self._session.put(
|
||||
url, json=body, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
resp_body = await resp.text()
|
||||
if resp.status == 429:
|
||||
_LOGGER.warning("Matrix rate limited: %s", resp_body[:200])
|
||||
return {"success": False, "error": f"HTTP {resp.status}: {resp_body[:200]}"}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
for attempt in range(1, _MAX_RETRIES + 1):
|
||||
try:
|
||||
async with self._session.put(
|
||||
url, json=body, headers=headers,
|
||||
timeout=_DEFAULT_TIMEOUT, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
resp_body = await resp.text()
|
||||
if resp.status == 429 and attempt < _MAX_RETRIES:
|
||||
try:
|
||||
wait_s = float(resp.headers.get("Retry-After", "2"))
|
||||
except (TypeError, ValueError):
|
||||
wait_s = 2.0
|
||||
wait_s = max(0.0, min(wait_s, _MAX_RETRY_AFTER_S))
|
||||
_LOGGER.warning(
|
||||
"Matrix rate limited, retrying after %.2fs (attempt %d/%d)",
|
||||
wait_s, attempt, _MAX_RETRIES,
|
||||
)
|
||||
await asyncio.sleep(wait_s)
|
||||
continue
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"HTTP {resp.status}: {redact(resp_body)[:200]}",
|
||||
"status_code": resp.status,
|
||||
}
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||
return {"success": False, "error": redact_exc(e)}
|
||||
|
||||
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||
|
||||
@@ -3,18 +3,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..http_base import HttpProviderClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PRIORITY_MIN: Final = 1
|
||||
_PRIORITY_MAX: Final = 5
|
||||
_DEFAULT_PRIORITY: Final = 3
|
||||
_MAX_TAGS: Final = 10
|
||||
_MAX_TAG_LEN: Final = 64
|
||||
|
||||
class NtfyClient:
|
||||
|
||||
def _strip_crlf(value: str) -> str:
|
||||
"""Remove CR/LF — ntfy's JSON path is safe today, but the same fields
|
||||
are used by the header API; defensive sanitization here means a future
|
||||
refactor can't accidentally re-introduce header injection."""
|
||||
return value.replace("\r", " ").replace("\n", " ")
|
||||
|
||||
|
||||
class NtfyClient(HttpProviderClient):
|
||||
"""Sends push notifications via ntfy server."""
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||
self._session = session
|
||||
super().__init__(session, provider_name="ntfy")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@@ -22,41 +37,48 @@ class NtfyClient:
|
||||
topic: str,
|
||||
message: str,
|
||||
title: str | None = None,
|
||||
priority: int = 3,
|
||||
priority: int = _DEFAULT_PRIORITY,
|
||||
tags: list[str] | None = None,
|
||||
click_url: str | None = None,
|
||||
auth_token: str | None = None,
|
||||
markdown: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a push notification to an ntfy topic."""
|
||||
if not server_url or not topic:
|
||||
return {"success": False, "error": "Missing server_url or topic"}
|
||||
|
||||
url = f"{server_url.rstrip('/')}"
|
||||
topic = _strip_crlf(topic).strip()
|
||||
if not topic:
|
||||
return {"success": False, "error": "Topic is empty after sanitization"}
|
||||
|
||||
try:
|
||||
priority_int = int(priority) if priority is not None else _DEFAULT_PRIORITY
|
||||
except (TypeError, ValueError):
|
||||
priority_int = _DEFAULT_PRIORITY
|
||||
priority_int = max(_PRIORITY_MIN, min(priority_int, _PRIORITY_MAX))
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"topic": topic,
|
||||
"message": message,
|
||||
"markdown": True,
|
||||
"markdown": bool(markdown),
|
||||
}
|
||||
if title:
|
||||
payload["title"] = title
|
||||
if priority != 3:
|
||||
payload["priority"] = priority
|
||||
payload["title"] = _strip_crlf(title)
|
||||
if priority_int != _DEFAULT_PRIORITY:
|
||||
payload["priority"] = priority_int
|
||||
if tags:
|
||||
payload["tags"] = tags
|
||||
cleaned = [
|
||||
_strip_crlf(str(t))[:_MAX_TAG_LEN]
|
||||
for t in tags[:_MAX_TAGS]
|
||||
if t
|
||||
]
|
||||
if cleaned:
|
||||
payload["tags"] = cleaned
|
||||
if click_url:
|
||||
payload["click"] = click_url
|
||||
payload["click"] = _strip_crlf(click_url)
|
||||
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
headers: dict[str, str] = {}
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, json=payload, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
return await self.request("POST", server_url.rstrip("/"), json=payload, headers=headers)
|
||||
|
||||
@@ -2,47 +2,88 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
from notify_bridge_core.storage import StorageBackend
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Bound on queue length. Without a cap, a misconfigured quiet-hour
|
||||
# window plus high event throughput grows the persisted file unboundedly
|
||||
# and every enqueue rewrites the whole file (O(n²) total writes). When
|
||||
# the cap is hit we drop the oldest entry (FIFO) so the most recent
|
||||
# events still reach the recipient when the window opens.
|
||||
DEFAULT_MAX_QUEUE_SIZE: Final = 1000
|
||||
|
||||
|
||||
class NotificationQueue:
|
||||
"""Persistent queue for notifications deferred during quiet hours."""
|
||||
|
||||
def __init__(self, backend: StorageBackend) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
backend: StorageBackend,
|
||||
*,
|
||||
max_size: int = DEFAULT_MAX_QUEUE_SIZE,
|
||||
) -> None:
|
||||
self._backend = backend
|
||||
self._data: dict[str, Any] | None = None
|
||||
self._max_size = max_size
|
||||
# Coordinates load / enqueue / clear / remove so a write-while-load
|
||||
# race can't leave the in-memory copy out of sync with disk and so
|
||||
# bulk operations don't interleave their reads-then-writes.
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _ensure_schema(data: Any) -> dict[str, Any]:
|
||||
if not isinstance(data, dict) or not isinstance(data.get("queue"), list):
|
||||
return {"queue": []}
|
||||
return data
|
||||
|
||||
async def async_load(self) -> None:
|
||||
self._data = await self._backend.load() or {"queue": []}
|
||||
async with self._lock:
|
||||
raw = await self._backend.load()
|
||||
self._data = self._ensure_schema(raw)
|
||||
|
||||
async def async_enqueue(self, notification_params: dict[str, Any]) -> None:
|
||||
if self._data is None:
|
||||
self._data = {"queue": []}
|
||||
self._data["queue"].append({
|
||||
"params": notification_params,
|
||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
await self._backend.save(self._data)
|
||||
async with self._lock:
|
||||
if self._data is None:
|
||||
self._data = {"queue": []}
|
||||
queue: list[dict[str, Any]] = self._data["queue"]
|
||||
queue.append({
|
||||
"params": notification_params,
|
||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
if self._max_size > 0 and len(queue) > self._max_size:
|
||||
# Drop oldest (FIFO) so a new event can still land.
|
||||
drop = len(queue) - self._max_size
|
||||
_LOGGER.warning(
|
||||
"NotificationQueue: dropping %d oldest entries (cap=%d)",
|
||||
drop, self._max_size,
|
||||
)
|
||||
del queue[:drop]
|
||||
await self._backend.save(self._data)
|
||||
|
||||
def get_all(self) -> list[dict[str, Any]]:
|
||||
if not self._data:
|
||||
return []
|
||||
return list(self._data.get("queue", []))
|
||||
# Deep copy so callers can iterate / mutate without corrupting the
|
||||
# in-memory queue. The cost is bounded by ``max_size``.
|
||||
return copy.deepcopy(list(self._data.get("queue", [])))
|
||||
|
||||
def has_pending(self) -> bool:
|
||||
return bool(self._data and self._data.get("queue"))
|
||||
|
||||
async def async_clear(self) -> None:
|
||||
if self._data:
|
||||
self._data["queue"] = []
|
||||
await self._backend.save(self._data)
|
||||
async with self._lock:
|
||||
if self._data:
|
||||
self._data["queue"] = []
|
||||
await self._backend.save(self._data)
|
||||
|
||||
async def async_remove(self) -> None:
|
||||
await self._backend.remove()
|
||||
self._data = None
|
||||
async with self._lock:
|
||||
await self._backend.remove()
|
||||
self._data = None
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -70,51 +70,64 @@ class MatrixReceiver(Receiver):
|
||||
room_id: str = ""
|
||||
|
||||
|
||||
_ReceiverFactory = Callable[[str, dict[str, Any]], Receiver]
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
_RECEIVER_FACTORIES: dict[str, _ReceiverFactory] = {
|
||||
"telegram": lambda locale, config: TelegramReceiver(
|
||||
locale=locale, config=config, chat_id=str(config.get("chat_id", "")),
|
||||
),
|
||||
"webhook": lambda locale, config: WebhookReceiver(
|
||||
locale=locale, config=config,
|
||||
url=str(config.get("url", "")),
|
||||
headers=dict(config.get("headers", {}) or {}),
|
||||
),
|
||||
"email": lambda locale, config: EmailReceiver(
|
||||
locale=locale, config=config,
|
||||
email=str(config.get("email", "")),
|
||||
name=str(config.get("name", "")),
|
||||
),
|
||||
"discord": lambda locale, config: DiscordReceiver(
|
||||
locale=locale, config=config,
|
||||
webhook_url=str(config.get("webhook_url", "")),
|
||||
),
|
||||
"slack": lambda locale, config: SlackReceiver(
|
||||
locale=locale, config=config,
|
||||
webhook_url=str(config.get("webhook_url", "")),
|
||||
),
|
||||
"ntfy": lambda locale, config: NtfyReceiver(
|
||||
locale=locale, config=config,
|
||||
topic=str(config.get("topic", "")),
|
||||
priority=_coerce_int(config.get("priority"), 3),
|
||||
),
|
||||
"matrix": lambda locale, config: MatrixReceiver(
|
||||
locale=locale, config=config,
|
||||
room_id=str(config.get("room_id", "")),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def register_receiver_factory(target_type: str, factory: _ReceiverFactory) -> None:
    """Add (or replace) the receiver factory used for ``target_type``.

    Intended for out-of-tree target types: the factory is stored in the
    module-level ``_RECEIVER_FACTORIES`` registry under ``target_type``.
    """
    _RECEIVER_FACTORIES.update({target_type: factory})
|
||||
|
||||
|
||||
def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver:
|
||||
"""Factory: build typed Receiver from target type and config dict."""
|
||||
if target_type == "telegram":
|
||||
return TelegramReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
chat_id=str(config.get("chat_id", "")),
|
||||
)
|
||||
if target_type == "webhook":
|
||||
return WebhookReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
url=config.get("url", ""),
|
||||
headers=config.get("headers", {}),
|
||||
)
|
||||
if target_type == "email":
|
||||
return EmailReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
email=config.get("email", ""),
|
||||
name=config.get("name", ""),
|
||||
)
|
||||
if target_type == "discord":
|
||||
return DiscordReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
webhook_url=config.get("webhook_url", ""),
|
||||
)
|
||||
if target_type == "slack":
|
||||
return SlackReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
webhook_url=config.get("webhook_url", ""),
|
||||
)
|
||||
if target_type == "ntfy":
|
||||
return NtfyReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
topic=config.get("topic", ""),
|
||||
priority=config.get("priority", 3),
|
||||
)
|
||||
if target_type == "matrix":
|
||||
return MatrixReceiver(
|
||||
locale=locale,
|
||||
config=config,
|
||||
room_id=config.get("room_id", ""),
|
||||
)
|
||||
return Receiver(locale=locale, config=config)
|
||||
"""Factory: build typed Receiver from target type and config dict.
|
||||
|
||||
Falls back to a base ``Receiver`` for unknown target types so callers
|
||||
that handle types defensively still receive a usable object — but the
|
||||
dispatcher rejects them with ``"Unknown target type"`` so a typo can't
|
||||
silently route to nowhere.
|
||||
"""
|
||||
factory = _RECEIVER_FACTORIES.get(target_type)
|
||||
if factory is None:
|
||||
return Receiver(locale=locale, config=config)
|
||||
return factory(locale, config)
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Secret-redaction helpers for log lines and error strings.
|
||||
|
||||
Notification clients embed secrets in URLs (Telegram bot tokens) and
|
||||
Authorization headers (Matrix access tokens, ntfy bearer tokens). When
|
||||
those secrets surface in ``aiohttp.ClientError.__str__``, response
|
||||
bodies, or operator-visible error fields, they leak into logs and into
|
||||
the per-target result dict that callers may forward upstream. ``redact``
|
||||
returns a defanged copy safe for both contexts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
# Telegram Bot API URL: api.telegram.org/bot<digits>:<token>/<method>
_TELEGRAM_BOT_TOKEN_RE: Final = re.compile(
    r"(api\.telegram\.org/bot)\d+:[A-Za-z0-9_-]+",
    flags=re.IGNORECASE,
)
# Bearer credential as it appears in an Authorization header value
_BEARER_RE: Final = re.compile(
    r"(Bearer\s+)[A-Za-z0-9._\-+/=]+",
    flags=re.IGNORECASE,
)
# Discord webhook URL: /api/webhooks/<id>/<token>
_DISCORD_WEBHOOK_RE: Final = re.compile(
    r"(discord(?:app)?\.com/api/webhooks/\d+/)[A-Za-z0-9_-]+",
    flags=re.IGNORECASE,
)
# Slack incoming-webhook URL: /services/T.../B.../<token>
_SLACK_WEBHOOK_RE: Final = re.compile(
    r"(hooks\.slack\.com/services/[A-Z0-9]+/[A-Z0-9]+/)[A-Za-z0-9]+",
    flags=re.IGNORECASE,
)
# Credentials embedded in a URL: scheme://user:password@host
_URL_USERINFO_RE: Final = re.compile(
    r"([a-z][a-z0-9+\-.]*://)[^/@\s]+:[^/@\s]+@",
    flags=re.IGNORECASE,
)
# Query-string parameters whose names commonly carry a secret
_QUERY_TOKEN_RE: Final = re.compile(
    r"([?&](?:token|access_token|api_key|key|secret|password)=)[^&\s]+",
    flags=re.IGNORECASE,
)


def redact(text: str) -> str:
    """Mask known secret patterns in ``text`` with ``***``.

    The patterns are applied one after another in a fixed order: Telegram
    bot token, Discord webhook token, Slack webhook token, Bearer
    credential, URL userinfo, then secret-looking query parameters. The
    non-secret prefix captured by each pattern is kept.

    A value that is not a ``str`` is first converted with ``str()``, so an
    exception instance can be passed directly. The result is always a
    ``str``.
    """
    cleaned = text if isinstance(text, str) else str(text)
    substitutions = (
        (_TELEGRAM_BOT_TOKEN_RE, r"\1***"),
        (_DISCORD_WEBHOOK_RE, r"\1***"),
        (_SLACK_WEBHOOK_RE, r"\1***"),
        (_BEARER_RE, r"\1***"),
        (_URL_USERINFO_RE, r"\1***@"),
        (_QUERY_TOKEN_RE, r"\1***"),
    )
    for pattern, replacement in substitutions:
        cleaned = pattern.sub(replacement, cleaned)
    return cleaned
|
||||
|
||||
|
||||
def redact_exc(err: BaseException) -> str:
    """Return the message of ``err`` with secrets masked.

    Shorthand for stringifying an exception and passing the text through
    ``redact``; handy when filling an error field.
    """
    message = str(err)
    return redact(message)
|
||||
@@ -7,14 +7,16 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..http_base import HttpProviderClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackClient:
|
||||
class SlackClient(HttpProviderClient):
|
||||
"""Sends messages via Slack incoming webhook URLs."""
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||
self._session = session
|
||||
super().__init__(session, provider_name="slack")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
@@ -33,19 +35,4 @@ class SlackClient:
|
||||
if icon_emoji:
|
||||
payload["icon_emoji"] = icon_emoji
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
webhook_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
_LOGGER.warning("Slack rate limited")
|
||||
return {"success": False, "error": "Rate limited by Slack"}
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
return await self.request("POST", webhook_url, json=payload)
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
"""Outbound URL validation to mitigate SSRF attacks.
|
||||
|
||||
User-controlled URLs (provider `url`, webhook target `url`, shared-link
|
||||
base URLs, image downloads) must be validated before any HTTP request is
|
||||
issued. This module rejects schemes other than http/https and blocks
|
||||
destinations that resolve to private, loopback, link-local, or unspecified
|
||||
address ranges.
|
||||
User-controlled URLs (provider ``url``, webhook target ``url``,
|
||||
shared-link base URLs, image downloads) must be validated before any
|
||||
HTTP request is issued. This module rejects schemes other than
|
||||
http/https and blocks destinations that resolve to private, loopback,
|
||||
link-local, unspecified, CGNAT (100.64.0.0/10), or IPv4-mapped IPv6
|
||||
ranges.
|
||||
|
||||
DNS rebinding mitigation
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
``avalidate_outbound_url`` returns the original URL on success, but
|
||||
discards the resolved IP it actually validated. Callers that pass
|
||||
the validated URL straight into ``aiohttp`` are vulnerable to a
|
||||
DNS-rebinding attack: the validator's ``getaddrinfo`` returns a public
|
||||
IP; aiohttp's connect-time resolution returns ``127.0.0.1``. To close
|
||||
that gap, use :func:`build_ssrf_safe_session` (or
|
||||
:class:`PinnedResolver`) so the resolved IP from the validation step is
|
||||
the one aiohttp connects to.
|
||||
|
||||
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
|
||||
development against localhost services.
|
||||
@@ -17,12 +29,20 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
||||
_ALLOWED_SCHEMES = {"http", "https"}
|
||||
_ALLOWED_SCHEMES = frozenset({"http", "https"})
|
||||
|
||||
# Carrier-grade NAT range. Not in stdlib's ``is_private``; an attacker
|
||||
# pointing a domain at a CGNAT IP could reach the operator's ISP-side
|
||||
# routing infrastructure. RFC 6598.
|
||||
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
||||
_LOGGER.warning(
|
||||
@@ -36,7 +56,29 @@ class UnsafeURLError(ValueError):
|
||||
"""Raised when a URL targets a disallowed network destination."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ValidatedURL:
    """Immutable record describing an outbound URL that passed validation.

    Fields, in constructor order:

    * ``url`` -- the URL string exactly as it was supplied.
    * ``host`` -- the hostname taken from the URL, lower-cased and
      IDN-encoded.
    * ``ip`` -- the resolved address (as a string) that passed the
      block-range check. Hand it to :class:`PinnedResolver` so the
      connection goes to this exact address, which defeats DNS rebinding.
    """

    url: str
    host: str
    ip: str
|
||||
|
||||
|
||||
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
# An IPv4-mapped IPv6 like ``::ffff:127.0.0.1`` is NOT considered
|
||||
# ``is_private`` etc. by stdlib — the v4 view holds those flags. So
|
||||
# we unwrap before checking.
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
|
||||
ip = ip.ipv4_mapped
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
@@ -44,22 +86,54 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
or (isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK)
|
||||
)
|
||||
|
||||
|
||||
def _safe_host_repr(host: str) -> str:
|
||||
"""Return ``host`` shortened/escaped for safe inclusion in error text."""
|
||||
h = host[:64].replace("\r", "").replace("\n", "")
|
||||
return h
|
||||
|
||||
|
||||
def _normalize_host(parsed_host: str) -> str:
|
||||
"""Normalize a hostname: lowercase, strip trailing dot, IDN-encode."""
|
||||
host = parsed_host.lower()
|
||||
if host.endswith("."):
|
||||
host = host[:-1]
|
||||
# Strip IPv6 zone id ("fe80::1%eth0") — must not reach the resolver.
|
||||
if "%" in host:
|
||||
host = host.split("%", 1)[0]
|
||||
# IDN-encode unicode hostnames so we don't downgrade to confusables
|
||||
# in any later log/output and so getaddrinfo gets the ascii form.
|
||||
try:
|
||||
if any(ord(c) > 127 for c in host):
|
||||
host = host.encode("idna").decode("ascii")
|
||||
except UnicodeError:
|
||||
# Caller will fail on resolution; leave as-is so the error path
|
||||
# surfaces a "DNS resolution failed" rather than a stack trace.
|
||||
pass
|
||||
return host
|
||||
|
||||
|
||||
def _check_scheme_host(url: str) -> tuple[str, str]:
|
||||
if not isinstance(url, str) or not url:
|
||||
raise UnsafeURLError("URL is empty")
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in _ALLOWED_SCHEMES:
|
||||
raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed")
|
||||
scheme = parsed.scheme.lower()
|
||||
if scheme not in _ALLOWED_SCHEMES:
|
||||
raise UnsafeURLError(f"Scheme '{scheme[:16]}' not allowed")
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise UnsafeURLError("URL has no host")
|
||||
return parsed.scheme, host
|
||||
return scheme, _normalize_host(host)
|
||||
|
||||
|
||||
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
||||
def _select_addresses(
|
||||
host: str, infos: list[tuple],
|
||||
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||
"""Return parsed, non-blocked IPs from ``getaddrinfo`` results."""
|
||||
addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
|
||||
for info in infos:
|
||||
sockaddr = info[4]
|
||||
try:
|
||||
@@ -67,64 +141,143 @@ def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
||||
except ValueError:
|
||||
continue
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
||||
raise UnsafeURLError(
|
||||
f"Host {_safe_host_repr(host)} resolves to blocked address {ip}"
|
||||
)
|
||||
addrs.append(ip)
|
||||
if not addrs:
|
||||
raise UnsafeURLError(f"Host {_safe_host_repr(host)} has no usable address")
|
||||
return addrs
|
||||
|
||||
|
||||
def validate_outbound_url(url: str) -> str:
|
||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||
|
||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||
private addresses are permitted but the scheme check still applies.
|
||||
|
||||
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||
:func:`avalidate_outbound_url` from async code paths.
|
||||
.. deprecated::
|
||||
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||
:func:`avalidate_outbound_url` from async code paths so the
|
||||
event loop isn't blocked, and use :func:`build_ssrf_safe_session`
|
||||
to defeat DNS rebinding.
|
||||
"""
|
||||
_, host = _check_scheme_host(url)
|
||||
|
||||
if _ALLOW_PRIVATE:
|
||||
return url
|
||||
|
||||
# Literal IP host
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
||||
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
|
||||
return url
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
infos = socket.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||
_check_resolved_addresses(host, infos)
|
||||
except (socket.gaierror, UnicodeError, OSError) as exc:
|
||||
# ``UnicodeError`` covers IDNA failures (labels >63 chars, malformed
|
||||
# unicode) which getaddrinfo surfaces as encoding errors rather than
|
||||
# gaierror. ``OSError`` covers transient resolver failures on some
|
||||
# platforms.
|
||||
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
|
||||
_select_addresses(host, infos)
|
||||
return url
|
||||
|
||||
|
||||
async def avalidate_outbound_url(url: str) -> str:
|
||||
"""Async variant that resolves DNS via the running loop's resolver.
|
||||
"""Async variant — returns the URL on success.
|
||||
|
||||
Use this from ``async def`` code paths to avoid blocking the event
|
||||
loop on DNS lookups.
|
||||
For DNS-rebinding-safe usage, prefer :func:`avalidate_outbound_url_full`
|
||||
which also returns the resolved IP for connect-time pinning.
|
||||
"""
|
||||
result = await avalidate_outbound_url_full(url)
|
||||
return result.url
|
||||
|
||||
|
||||
async def avalidate_outbound_url_full(url: str) -> ValidatedURL:
|
||||
"""Validate ``url`` and return a :class:`ValidatedURL` on success.
|
||||
|
||||
The returned ``ip`` field is the IP that passed the block-range
|
||||
check. Pair with :class:`PinnedResolver` so aiohttp connects to that
|
||||
exact IP and a malicious DNS server can't swap in a private address
|
||||
after validation.
|
||||
"""
|
||||
_, host = _check_scheme_host(url)
|
||||
|
||||
if _ALLOW_PRIVATE:
|
||||
return url
|
||||
# In dev mode we still resolve to give a usable IP, but we don't
|
||||
# gate on the result.
|
||||
try:
|
||||
ip = str(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
infos = await loop.getaddrinfo(host, None)
|
||||
ip = infos[0][4][0] if infos else host
|
||||
except (socket.gaierror, OSError):
|
||||
ip = host
|
||||
return ValidatedURL(url=url, host=host, ip=ip)
|
||||
|
||||
# Literal IP host
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
||||
return url
|
||||
ip_obj = ipaddress.ip_address(host)
|
||||
if _is_blocked_ip(ip_obj):
|
||||
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
|
||||
return ValidatedURL(url=url, host=host, ip=str(ip_obj))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
infos = await loop.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||
_check_resolved_addresses(host, infos)
|
||||
return url
|
||||
except (socket.gaierror, UnicodeError, OSError) as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
|
||||
addrs = _select_addresses(host, infos)
|
||||
return ValidatedURL(url=url, host=host, ip=str(addrs[0]))
|
||||
|
||||
|
||||
class PinnedResolver(aiohttp.abc.AbstractResolver):
    """aiohttp resolver that answers pinned hostnames with a fixed IP.

    Used to pin the resolved IP from :func:`avalidate_outbound_url_full`
    so aiohttp's connect-time resolution can't be tricked by DNS
    rebinding into using a different IP than the one we validated.

    Hostnames are matched case-insensitively. Any host that is not pinned
    (or whose pinned value is not a literal IP address) is delegated to a
    lazily created :class:`aiohttp.ThreadedResolver`, so a single resolver
    instance can be reused across multiple validated URLs.
    """

    def __init__(self, mapping: dict[str, str] | None = None) -> None:
        """Seed the resolver with an optional ``{hostname: ip}`` mapping."""
        # Lower-case the keys: ``pin`` and ``resolve`` both look hosts up by
        # ``host.lower()``, so a mixed-case key stored verbatim would never
        # match and the host would silently go through the fallback resolver.
        self._map: dict[str, str] = {
            host.lower(): ip for host, ip in (mapping or {}).items()
        }
        self._fallback: aiohttp.abc.AbstractResolver | None = None

    def pin(self, host: str, ip: str) -> None:
        """Pin ``host`` (case-insensitive) to ``ip``, replacing any earlier pin."""
        self._map[host.lower()] = ip

    async def resolve(
        self, host: str, port: int = 0, family: int = socket.AF_INET,
    ) -> list[dict]:
        """Return the pinned address for ``host`` or defer to the fallback.

        For a pinned host the result is a single record whose ``family`` is
        derived from the pinned IP's version; the ``family`` argument is only
        forwarded to the fallback resolver.
        """
        ip = self._map.get(host.lower())
        if ip is not None:
            try:
                ip_obj = ipaddress.ip_address(ip)
            except ValueError:
                # Pinned value is not a literal IP; fall through to the
                # fallback resolver below.
                ip_obj = None
            if ip_obj is not None:
                fam = socket.AF_INET6 if ip_obj.version == 6 else socket.AF_INET
                return [{
                    "hostname": host,
                    "host": ip,
                    "port": port,
                    "family": fam,
                    "proto": 0,
                    "flags": socket.AI_NUMERICHOST,
                }]
        if self._fallback is None:
            self._fallback = aiohttp.ThreadedResolver()
        return await self._fallback.resolve(host, port, family)

    async def close(self) -> None:
        """Close the fallback resolver if one was ever created."""
        if self._fallback is not None:
            await self._fallback.close()
|
||||
|
||||
@@ -2,16 +2,29 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Final
|
||||
|
||||
from notify_bridge_core.storage import StorageBackend
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TELEGRAM_CACHE_TTL = 48 * 60 * 60
|
||||
DEFAULT_MAX_ENTRIES = 5000
|
||||
DEFAULT_TELEGRAM_CACHE_TTL: Final = 48 * 60 * 60
|
||||
DEFAULT_MAX_ENTRIES: Final = 5000
|
||||
|
||||
|
||||
def _parse_iso(value: str | None) -> datetime | None:
|
||||
"""Parse an ISO-8601 timestamp tolerantly. Returns ``None`` on failure."""
|
||||
if not value or not isinstance(value, str):
|
||||
return None
|
||||
try:
|
||||
# Python <3.11 doesn't accept "Z"; normalize to +00:00.
|
||||
v = value.replace("Z", "+00:00") if value.endswith("Z") else value
|
||||
return datetime.fromisoformat(v)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
class TelegramFileCache:
|
||||
@@ -25,7 +38,17 @@ class TelegramFileCache:
|
||||
Intended for content-addressable assets (e.g. Immich) where re-uploads
|
||||
should be triggered by visual change, not elapsed time.
|
||||
|
||||
``max_entries`` always applies as an LRU size cap (by ``cached_at``).
|
||||
``max_entries`` always applies as a FIFO size cap (oldest-cached first).
|
||||
|
||||
Concurrency
|
||||
~~~~~~~~~~~
|
||||
All mutators take an internal ``asyncio.Lock`` so concurrent
|
||||
media-group sends can't interleave a read-time invalidation with a
|
||||
bulk write and corrupt the underlying dict (``RuntimeError:
|
||||
dictionary changed size during iteration``) or lose just-written
|
||||
entries. Reads do not take the lock — they are O(1) dict lookups —
|
||||
but ``get`` uses a snapshot reference so it cannot mutate the data
|
||||
structure under another task.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -40,35 +63,40 @@ class TelegramFileCache:
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._use_thumbhash = use_thumbhash
|
||||
self._max_entries = max_entries
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def async_load(self) -> None:
|
||||
self._data = await self._backend.load() or {"files": {}}
|
||||
await self._cleanup_expired()
|
||||
async with self._lock:
|
||||
self._data = await self._backend.load() or {"files": {}}
|
||||
await self._cleanup_expired_locked()
|
||||
|
||||
async def _cleanup_expired(self) -> None:
|
||||
async def _cleanup_expired_locked(self) -> None:
|
||||
"""Caller must hold ``self._lock``."""
|
||||
if not self._data or "files" not in self._data:
|
||||
return
|
||||
files = self._data["files"]
|
||||
files: dict[str, dict[str, Any]] = self._data["files"]
|
||||
changed = False
|
||||
|
||||
# TTL sweep — only when TTL validation is active (i.e. no thumbhash
|
||||
# mode and a positive TTL). In thumbhash mode we rely entirely on
|
||||
# content validation; in "TTL disabled" mode (ttl_seconds <= 0) we
|
||||
# cache forever, subject only to the size cap.
|
||||
if not self._use_thumbhash and self._ttl_seconds > 0:
|
||||
now = datetime.now(timezone.utc)
|
||||
expired = [
|
||||
url for url, entry in files.items()
|
||||
if entry.get("cached_at") and
|
||||
(now - datetime.fromisoformat(entry["cached_at"])).total_seconds() > self._ttl_seconds
|
||||
]
|
||||
expired: list[str] = []
|
||||
for url, entry in list(files.items()):
|
||||
cached_at = _parse_iso(entry.get("cached_at"))
|
||||
if cached_at is None:
|
||||
continue
|
||||
if cached_at.tzinfo is None:
|
||||
cached_at = cached_at.replace(tzinfo=timezone.utc)
|
||||
if (now - cached_at).total_seconds() > self._ttl_seconds:
|
||||
expired.append(url)
|
||||
for key in expired:
|
||||
del files[key]
|
||||
changed = True
|
||||
|
||||
# Size cap (FIFO by ``cached_at``) — always enforced. Evicts oldest-cached entries first.
|
||||
if self._max_entries > 0 and len(files) > self._max_entries:
|
||||
sorted_keys = sorted(files, key=lambda k: files[k].get("cached_at", ""))
|
||||
sorted_keys = sorted(
|
||||
files,
|
||||
key=lambda k: _parse_iso(files[k].get("cached_at")) or datetime.min.replace(tzinfo=timezone.utc),
|
||||
)
|
||||
for key in sorted_keys[: len(files) - self._max_entries]:
|
||||
del files[key]
|
||||
changed = True
|
||||
@@ -80,7 +108,10 @@ class TelegramFileCache:
|
||||
if not self._data or "files" not in self._data:
|
||||
return None
|
||||
|
||||
entry = self._data["files"].get(key)
|
||||
# Take a local reference so a concurrent ``async_set`` rebuilding
|
||||
# the dict cannot pull the rug out mid-read.
|
||||
files = self._data["files"]
|
||||
entry = files.get(key)
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
@@ -88,19 +119,23 @@ class TelegramFileCache:
|
||||
if thumbhash is not None:
|
||||
stored = entry.get("thumbhash")
|
||||
if stored and stored != thumbhash:
|
||||
del self._data["files"][key]
|
||||
# Mark stale — actual deletion happens lock-protected
|
||||
# in the next mutation. Returning None is sufficient
|
||||
# for the caller to skip the cache hit.
|
||||
return None
|
||||
elif self._ttl_seconds > 0:
|
||||
cached_at_str = entry.get("cached_at")
|
||||
if cached_at_str:
|
||||
age = (datetime.now(timezone.utc) - datetime.fromisoformat(cached_at_str)).total_seconds()
|
||||
cached_at = _parse_iso(entry.get("cached_at"))
|
||||
if cached_at is not None:
|
||||
if cached_at.tzinfo is None:
|
||||
cached_at = cached_at.replace(tzinfo=timezone.utc)
|
||||
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
|
||||
if age > self._ttl_seconds:
|
||||
return None
|
||||
|
||||
return {
|
||||
"file_id": entry.get("file_id"),
|
||||
"type": entry.get("type"),
|
||||
"size": entry.get("size"), # bytes of what was uploaded; None for legacy entries
|
||||
"size": entry.get("size"),
|
||||
}
|
||||
|
||||
async def async_set(
|
||||
@@ -111,21 +146,22 @@ class TelegramFileCache:
|
||||
thumbhash: str | None = None,
|
||||
size: int | None = None,
|
||||
) -> None:
|
||||
if self._data is None:
|
||||
self._data = {"files": {}}
|
||||
async with self._lock:
|
||||
if self._data is None:
|
||||
self._data = {"files": {}}
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
"file_id": file_id,
|
||||
"type": media_type,
|
||||
"cached_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
if thumbhash is not None:
|
||||
entry["thumbhash"] = thumbhash
|
||||
if size is not None:
|
||||
entry["size"] = size
|
||||
entry: dict[str, Any] = {
|
||||
"file_id": file_id,
|
||||
"type": media_type,
|
||||
"cached_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
if thumbhash is not None:
|
||||
entry["thumbhash"] = thumbhash
|
||||
if size is not None:
|
||||
entry["size"] = size
|
||||
|
||||
self._data["files"][key] = entry
|
||||
await self._backend.save(self._data)
|
||||
self._data["files"][key] = entry
|
||||
await self._backend.save(self._data)
|
||||
|
||||
async def async_set_many(
|
||||
self,
|
||||
@@ -139,32 +175,34 @@ class TelegramFileCache:
|
||||
"""
|
||||
if not entries:
|
||||
return
|
||||
if self._data is None:
|
||||
self._data = {"files": {}}
|
||||
async with self._lock:
|
||||
if self._data is None:
|
||||
self._data = {"files": {}}
|
||||
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
for item in entries:
|
||||
if len(item) == 5:
|
||||
key, file_id, media_type, thumbhash, size = item
|
||||
else:
|
||||
key, file_id, media_type, thumbhash = item
|
||||
size = None
|
||||
entry: dict[str, Any] = {
|
||||
"file_id": file_id,
|
||||
"type": media_type,
|
||||
"cached_at": now_iso,
|
||||
}
|
||||
if thumbhash is not None:
|
||||
entry["thumbhash"] = thumbhash
|
||||
if size is not None:
|
||||
entry["size"] = size
|
||||
self._data["files"][key] = entry
|
||||
now_iso = datetime.now(timezone.utc).isoformat()
|
||||
for item in entries:
|
||||
if len(item) == 5:
|
||||
key, file_id, media_type, thumbhash, size = item
|
||||
else:
|
||||
key, file_id, media_type, thumbhash = item
|
||||
size = None
|
||||
entry: dict[str, Any] = {
|
||||
"file_id": file_id,
|
||||
"type": media_type,
|
||||
"cached_at": now_iso,
|
||||
}
|
||||
if thumbhash is not None:
|
||||
entry["thumbhash"] = thumbhash
|
||||
if size is not None:
|
||||
entry["size"] = size
|
||||
self._data["files"][key] = entry
|
||||
|
||||
await self._backend.save(self._data)
|
||||
await self._backend.save(self._data)
|
||||
|
||||
async def async_remove(self) -> None:
|
||||
await self._backend.remove()
|
||||
self._data = None
|
||||
async with self._lock:
|
||||
await self._backend.remove()
|
||||
self._data = None
|
||||
|
||||
def stats(self) -> dict[str, Any]:
|
||||
"""Return summary stats about the current cache contents.
|
||||
@@ -172,25 +210,33 @@ class TelegramFileCache:
|
||||
Includes the number of cached entries, total tracked size in bytes
|
||||
(only counts entries with a recorded ``size``), and the oldest /
|
||||
newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty).
|
||||
Timestamps are compared as parsed ``datetime`` objects so mixed
|
||||
timezone formats (``Z`` vs ``+00:00``) order correctly.
|
||||
"""
|
||||
files = self._data.get("files", {}) if self._data else {}
|
||||
count = len(files)
|
||||
total_size = 0
|
||||
oldest: str | None = None
|
||||
newest: str | None = None
|
||||
oldest_dt: datetime | None = None
|
||||
newest_dt: datetime | None = None
|
||||
oldest_str: str | None = None
|
||||
newest_str: str | None = None
|
||||
for entry in files.values():
|
||||
size = entry.get("size")
|
||||
if isinstance(size, int):
|
||||
total_size += size
|
||||
cached_at = entry.get("cached_at")
|
||||
if cached_at:
|
||||
if oldest is None or cached_at < oldest:
|
||||
oldest = cached_at
|
||||
if newest is None or cached_at > newest:
|
||||
newest = cached_at
|
||||
dt = _parse_iso(cached_at)
|
||||
if dt is None or not cached_at:
|
||||
continue
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
if oldest_dt is None or dt < oldest_dt:
|
||||
oldest_dt, oldest_str = dt, cached_at
|
||||
if newest_dt is None or dt > newest_dt:
|
||||
newest_dt, newest_str = dt, cached_at
|
||||
return {
|
||||
"count": count,
|
||||
"total_size_bytes": total_size,
|
||||
"oldest": oldest,
|
||||
"newest": newest,
|
||||
"oldest": oldest_str,
|
||||
"newest": newest_str,
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,20 +2,35 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Final
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Telegram constants
|
||||
TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot"
|
||||
TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB
|
||||
TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB
|
||||
TELEGRAM_MAX_DIMENSION_SUM: Final = 10000
|
||||
# Telegram message-text limit (sendMessage) and caption limit
|
||||
# (sendPhoto/sendVideo/sendDocument/first item of sendMediaGroup).
|
||||
TELEGRAM_MAX_TEXT_LENGTH: Final = 4096
|
||||
TELEGRAM_MAX_CAPTION_LENGTH: Final = 1024
|
||||
|
||||
# Generic UUID pattern for asset IDs
|
||||
_ASSET_ID_PATTERN = re.compile(r"^[a-f0-9-]{36}$")
|
||||
# Strict canonical-UUID pattern (8-4-4-4-12) for asset IDs. The previous
|
||||
# loose ``[a-f0-9-]{36}`` matched 36 hyphens / arbitrary digit groupings,
|
||||
# which could collide across providers when used as a cache key.
|
||||
_ASSET_ID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Cache key: "host:uuid" or bare "uuid"
|
||||
_ASSET_CACHE_KEY_PATTERN = re.compile(r"^(?:[^:]+:)?[a-f0-9-]{36}$")
|
||||
_ASSET_CACHE_KEY_PATTERN = re.compile(
|
||||
r"^(?:[^:]+:)?[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# URL patterns to extract asset IDs (generic enough for Immich-style URLs)
|
||||
_ASSET_ID_URL_PATTERNS = [
|
||||
@@ -162,5 +177,10 @@ def check_photo_limits(
|
||||
return False, None, width, height
|
||||
except ImportError:
|
||||
return False, None, None, None
|
||||
except Exception:
|
||||
except (OSError, ValueError, MemoryError) as exc:
|
||||
# PIL surfaces ``UnidentifiedImageError`` (subclass of OSError),
|
||||
# truncated-image / decompression-bomb errors here. Log so a
|
||||
# corrupt asset isn't silently passed to Telegram and rejected
|
||||
# downstream with a less actionable error.
|
||||
_LOGGER.warning("check_photo_limits: failed to inspect image (%d bytes): %s", len(data), exc)
|
||||
return False, None, None, None
|
||||
|
||||
@@ -7,37 +7,29 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
from ..http_base import HttpProviderClient
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
class WebhookClient(HttpProviderClient):
|
||||
"""Send JSON payloads to a webhook URL.
|
||||
|
||||
class WebhookClient:
|
||||
"""Send JSON payloads to a webhook URL."""
|
||||
The URL is SSRF-validated on every send (defense-in-depth: re-validating
|
||||
catches DNS rebinding between calls and a misconfigured target). Headers
|
||||
pass through :func:`safe_headers` so a target config can't inject
|
||||
framing/hop-by-hop headers like ``Host`` or ``Transfer-Encoding``.
|
||||
"""
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, url: str, headers: dict[str, str] | None = None) -> None:
|
||||
self._session = session
|
||||
def __init__(
|
||||
self,
|
||||
session: aiohttp.ClientSession,
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
super().__init__(session, provider_name="webhook")
|
||||
self._url = url
|
||||
self._headers = headers or {}
|
||||
self._extra_headers = headers or {}
|
||||
|
||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
try:
|
||||
await avalidate_outbound_url(self._url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe URL: {err}"}
|
||||
try:
|
||||
async with self._session.post(
|
||||
self._url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", **self._headers},
|
||||
timeout=_DEFAULT_TIMEOUT,
|
||||
allow_redirects=False,
|
||||
) as response:
|
||||
if 200 <= response.status < 300:
|
||||
return {"success": True, "status_code": response.status}
|
||||
body = await response.text()
|
||||
return {"success": False, "error": f"HTTP {response.status}", "body": body[:200]}
|
||||
except aiohttp.ClientError as err:
|
||||
return {"success": False, "error": str(err)}
|
||||
return await self.request("POST", self._url, json=payload, headers=self._extra_headers)
|
||||
|
||||
@@ -218,6 +218,19 @@ async def get_supported_locales(
|
||||
return locales or ["en"]
|
||||
|
||||
|
||||
@router.get("/external-url")
async def get_external_url(
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Expose the configured external base URL to any authenticated user.

    The UI uses it to render absolute provider webhook URLs. Trailing
    slashes are stripped from the stored value; an empty string tells the
    UI to fall back to the relative path.
    """
    configured = await get_setting(session, "external_url")
    return {"external_url": configured.rstrip("/")}
|
||||
|
||||
|
||||
async def _reregister_webhooks(
|
||||
session: AsyncSession, base_url: str, secret: str
|
||||
) -> None:
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Dispatcher result aggregation: per-receiver detail must survive."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher
|
||||
|
||||
|
||||
def test_aggregate_all_success() -> None:
    """Two successful receivers aggregate to an overall success."""
    per_receiver = [
        {"success": True, "message_id": 1},
        {"success": True, "message_id": 2},
    ]
    summary = NotificationDispatcher._aggregate_results(per_receiver)
    assert summary["success"] is True
    assert summary["receivers"] == 2
    assert summary["successes"] == 2
    assert summary["failures"] == 0
|
||||
|
||||
|
||||
def test_aggregate_partial() -> None:
    """One success plus one failure still counts as an overall success."""
    per_receiver = [
        {"success": True},
        {"success": False, "error": "boom"},
    ]
    summary = NotificationDispatcher._aggregate_results(per_receiver)
    # A single successful receiver is enough for the top-level flag.
    assert summary["success"] is True
    assert summary["successes"] == 1
    assert summary["failures"] == 1
    assert "boom" in summary["errors"]
    assert "results" in summary
|
||||
|
||||
|
||||
def test_aggregate_all_fail_preserves_all_errors() -> None:
    """When every receiver fails, each error and each result is retained."""
    per_receiver = [
        {"success": False, "error": "first"},
        {"success": False, "error": "second"},
    ]
    summary = NotificationDispatcher._aggregate_results(per_receiver)
    assert summary["success"] is False
    # Back-compat: the top-level ``error`` field carries the first failure.
    assert summary["error"] == "first"
    assert summary["errors"] == ["first", "second"]
    # The per-receiver detail must survive so an operator can see exactly
    # what failed.
    assert len(summary["results"]) == 2
|
||||
|
||||
|
||||
def test_aggregate_empty() -> None:
    """No receiver results at all is reported as a failure with an error."""
    summary = NotificationDispatcher._aggregate_results([])
    assert summary["success"] is False
    assert "error" in summary
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Email client header-injection / address-validation regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.email.client import (
|
||||
EmailClient,
|
||||
SmtpConfig,
|
||||
_strip_header,
|
||||
_validate_email,
|
||||
_to_html,
|
||||
)
|
||||
|
||||
|
||||
def test_strip_header_removes_crlf() -> None:
    """CR/LF are removed so an injected header cannot start a new line."""
    cleaned = _strip_header("Subject\r\nBcc: attacker@example.com")
    for line_break in ("\r", "\n"):
        assert line_break not in cleaned
    # The injected "Bcc:" text stays in the value, but on the same header
    # line, so the SMTP server sees it as subject text rather than a header.
    assert "Bcc:" in cleaned
|
||||
|
||||
|
||||
def test_strip_header_removes_bare_lf() -> None:
    """A lone LF (no CR) is stripped as well."""
    cleaned = _strip_header("Hello\nWorld")
    assert "\n" not in cleaned
|
||||
|
||||
|
||||
def test_strip_header_handles_non_string() -> None:
    """``None`` is coerced to an empty header value."""
    cleaned = _strip_header(None)
    assert cleaned == ""
|
||||
|
||||
|
||||
def test_validate_email_rejects_crlf() -> None:
    """An address carrying a CRLF header-injection payload is rejected."""
    hostile = "user@example.com\r\nBcc: x@y"
    with pytest.raises(ValueError):
        _validate_email(hostile)
|
||||
|
||||
|
||||
def test_validate_email_rejects_no_at() -> None:
    """A string without ``@`` is rejected."""
    candidate = "not-an-email"
    with pytest.raises(ValueError):
        _validate_email(candidate)
|
||||
|
||||
|
||||
def test_validate_email_rejects_empty() -> None:
    """The empty string is rejected."""
    candidate = ""
    with pytest.raises(ValueError):
        _validate_email(candidate)
|
||||
|
||||
|
||||
def test_validate_email_accepts_normal() -> None:
    """A plain address is returned unchanged."""
    address = "user@example.com"
    assert _validate_email(address) == "user@example.com"
|
||||
|
||||
|
||||
def test_to_html_escapes_brackets() -> None:
    """Angle brackets in the body are HTML-escaped, not emitted as markup."""
    out = _to_html("<script>alert(1)</script>")
    # The raw tag must be gone ...
    assert "<script>" not in out
    # ... and its entity-escaped form must be present. The second check used
    # to repeat the raw "<script>" literal, which contradicted the line above
    # and could never pass.
    assert "&lt;script&gt;" in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_returns_error_on_invalid_to() -> None:
    """``send`` reports an invalid-recipient error instead of raising."""
    smtp = SmtpConfig(host="smtp.example.com", from_address="from@example.com")
    mailer = EmailClient(smtp)
    outcome = await mailer.send(
        to_email="user@example.com\r\nBcc: attacker@example.com",
        subject="hi",
        body_text="body",
    )
    assert outcome["success"] is False
    assert "Invalid email" in outcome["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_returns_error_on_no_host() -> None:
    """``send`` fails cleanly when the SMTP host is empty."""
    smtp = SmtpConfig(host="", from_address="from@example.com")
    mailer = EmailClient(smtp)
    outcome = await mailer.send("u@x.com", "s", "b")
    assert outcome["success"] is False
|
||||
@@ -0,0 +1,53 @@
|
||||
"""HttpProviderClient + safe_headers tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.http_base import safe_headers
|
||||
|
||||
|
||||
class TestSafeHeaders:
    """Behaviour of ``safe_headers`` on hostile or empty header maps."""

    def test_drops_hop_by_hop(self) -> None:
        """Framing / hop-by-hop headers are removed; custom ones are kept."""
        supplied = {
            "X-Custom": "ok",
            "Host": "evil.example.com",
            "Content-Length": "999",
            "Transfer-Encoding": "chunked",
            "Connection": "close",
        }
        filtered = safe_headers(supplied)
        assert filtered == {"X-Custom": "ok"}

    def test_rejects_crlf_in_value(self) -> None:
        """A header whose value contains CRLF is dropped."""
        supplied = {
            "X-Custom": "ok",
            "X-Bad": "value\r\nInjected: yes",
        }
        filtered = safe_headers(supplied)
        assert "X-Custom" in filtered
        assert "X-Bad" not in filtered

    def test_rejects_crlf_in_name(self) -> None:
        """A header whose name contains CRLF is dropped."""
        supplied = {
            "X-Custom": "ok",
            "X-Bad\r\nInject": "value",
        }
        filtered = safe_headers(supplied)
        assert filtered == {"X-Custom": "ok"}

    def test_empty_input(self) -> None:
        """``None`` and ``{}`` both yield an empty dict."""
        for empty in (None, {}):
            assert safe_headers(empty) == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_http_base_returns_safe_error_on_invalid_url() -> None:
    """An obviously-bad URL must not panic or leak the URL verbatim."""
    import aiohttp

    from notify_bridge_core.notifications.http_base import HttpProviderClient

    async with aiohttp.ClientSession() as http_session:
        provider = HttpProviderClient(http_session, provider_name="test")
        # The SSRF guard rejects file:// before any HTTP call is made.
        outcome = await provider.request("POST", "file:///etc/passwd", json={})
        assert outcome["success"] is False
        assert "Unsafe URL" in outcome["error"]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Matrix client validation: room_id format and quoting."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
from aioresponses import aioresponses
|
||||
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
|
||||
|
||||
HOMESERVER = "https://matrix.example.com"
|
||||
TOKEN = "secret-bearer-token-1234567890"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rejects_path_injection_room_id() -> None:
    """A room id containing path-traversal segments is refused."""
    hostile_room = "!abc:host/../../etc/passwd"
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message(hostile_room, "hi")
        assert outcome["success"] is False
        assert "room_id" in outcome["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rejects_empty_room_id() -> None:
    """An empty room id is refused with an error that names ``room_id``."""
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message("", "hi")
        assert outcome["success"] is False
        assert "room_id" in outcome["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rejects_unicode_control_chars_in_room_id() -> None:
    """A room id with an embedded NUL control character is refused."""
    room_with_nul = "!abc\x00:host"
    async with aiohttp.ClientSession() as http_session:
        matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
        outcome = await matrix.send_message(room_with_nul, "hi")
        assert outcome["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_url_encodes_room_id_special_chars() -> None:
    """``!`` and ``:`` must reach the server URL-encoded.

    NOTE(review): the only assertion is that the send succeeds against the
    mocked endpoints below; the request URL itself is not inspected.
    """
    import re

    with aioresponses() as mocked:
        # Exact, percent-encoded form of the room path.
        mocked.put(
            "https://matrix.example.com/_matrix/client/v3/rooms/%21abc%3Ahost.example/send/m.room.message",
            status=200, body='{}', repeat=True,
        )
        # aioresponses doesn't expose URL templates well, so also register a
        # regex mock that accepts any room segment followed by a further
        # path component.
        mocked.put(
            re.compile(r"https://matrix\.example\.com/_matrix/client/v3/rooms/[^/]+/send/m\.room\.message/.*"),
            status=200, body='{}', repeat=True,
        )

        async with aiohttp.ClientSession() as sess:
            client = MatrixClient(sess, HOMESERVER, TOKEN)
            result = await client.send_message("!abc:host.example", "hi")

    assert result["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redacts_bearer_in_error() -> None:
    """A 4xx response body must not echo the Authorization Bearer back to caller."""
    import re

    # Server reply that reflects the caller's bearer token in its body.
    leaky_body = '{"errcode": "M_FORBIDDEN", "Authorization": "Bearer ' + TOKEN + '"}'
    with aioresponses() as mocked:
        mocked.put(
            re.compile(r".*"),
            status=403,
            body=leaky_body,
            repeat=True,
        )

        async with aiohttp.ClientSession() as http_session:
            matrix = MatrixClient(http_session, HOMESERVER, TOKEN)
            outcome = await matrix.send_message("!abc:host.example", "hi")

    assert outcome["success"] is False
    assert TOKEN not in outcome["error"]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""NotificationQueue bound + concurrency regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.queue import (
|
||||
DEFAULT_MAX_QUEUE_SIZE,
|
||||
NotificationQueue,
|
||||
)
|
||||
|
||||
|
||||
class _MemBackend:
|
||||
"""In-memory storage backend stub for tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] | None = None
|
||||
|
||||
async def load(self) -> dict[str, Any] | None:
|
||||
return self._data
|
||||
|
||||
async def save(self, data: dict[str, Any]) -> None:
|
||||
self._data = data
|
||||
|
||||
async def remove(self) -> None:
|
||||
self._data = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_load_with_garbage_falls_back_to_empty() -> None:
    """A persisted ``queue`` that is not a list loads as an empty queue."""
    store = _MemBackend()
    store._data = {"queue": "not-a-list"}  # type: ignore[assignment]
    queue = NotificationQueue(store)
    await queue.async_load()
    assert queue.get_all() == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_enqueue_caps_at_max_size() -> None:
    """Enqueuing past ``max_size`` keeps only the newest ``max_size`` items."""
    store = _MemBackend()
    queue = NotificationQueue(store, max_size=3)
    await queue.async_load()
    for n in range(10):
        await queue.async_enqueue({"i": n})
    remaining = queue.get_all()
    assert len(remaining) == 3
    # FIFO drop: most recent three are kept (i=7..9).
    kept = [item["params"]["i"] for item in remaining]
    assert kept == [7, 8, 9]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_all_returns_deep_copy() -> None:
    """Mutating a ``get_all()`` snapshot must not change what the queue holds."""
    queue = NotificationQueue(_MemBackend(), max_size=10)
    await queue.async_load()
    await queue.async_enqueue({"key": "value"})

    # Tamper with a nested value inside the first snapshot...
    first_view = queue.get_all()
    first_view[0]["params"]["key"] = "MUTATED"

    # ...and confirm that a fresh snapshot still shows the original value.
    second_view = queue.get_all()
    assert second_view[0]["params"]["key"] == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_concurrent_enqueue_and_clear_no_corruption() -> None:
    """Interleaved enqueue and clear coroutines complete without raising."""
    queue = NotificationQueue(_MemBackend(), max_size=DEFAULT_MAX_QUEUE_SIZE)
    await queue.async_load()

    async def enqueue_many() -> None:
        for index in range(50):
            await queue.async_enqueue({"i": index})

    async def clear_repeatedly() -> None:
        remaining = 10
        while remaining:
            # Yield to the event loop first so the producer gets to interleave.
            await asyncio.sleep(0)
            await queue.async_clear()
            remaining -= 1

    await asyncio.gather(enqueue_many(), clear_repeatedly())

    # Getting here means neither coroutine raised (e.g. "dictionary changed
    # size during iteration"); the final contents are not asserted.
    snapshot = queue.get_all()
    assert isinstance(snapshot, list)
|
||||
@@ -0,0 +1,74 @@
|
||||
"""Secret-redaction helper regression tests.
|
||||
|
||||
Locks in the patterns that surface from real provider error paths:
|
||||
Telegram bot URLs in aiohttp.ClientError messages, Authorization Bearer
|
||||
tokens in Matrix/ntfy responses, Discord/Slack webhook tokens, URL
|
||||
userinfo, and common ?token= query params.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.redact import redact, redact_exc
|
||||
|
||||
|
||||
# Each case: (raw message, redacted fragment that must appear, secret that must be gone).
_REDACTION_CASES = [
    # Telegram bot token embedded in a URL path.
    (
        "Cannot connect to host api.telegram.org/bot1234567:AABBCC-secret-token/sendMessage",
        "api.telegram.org/bot***",
        "AABBCC-secret-token",
    ),
    # Authorization header carrying a bearer token.
    (
        "Authorization: Bearer ey.JhbGciOiJIUzI1NiJ9.payload.sig",
        "Bearer ***",
        "ey.JhbGciOiJIUzI1NiJ9",
    ),
    # Discord webhook token.
    (
        "POST https://discord.com/api/webhooks/12345/abcdefg-token failed",
        "discord.com/api/webhooks/12345/***",
        "abcdefg-token",
    ),
    # Slack webhook token.
    (
        "POST https://hooks.slack.com/services/T01/B02/zzzzz failed",
        "hooks.slack.com/services/T01/B02/***",
        "zzzzz",
    ),
    # URL userinfo (user:password@host).
    (
        "fetch http://user:supersecret@example.com/foo",
        "http://***@example.com/foo",
        "supersecret",
    ),
    # ``token`` query parameter.
    (
        "https://api.example.com/x?token=mytoken123&extra=ok",
        "token=***",
        "mytoken123",
    ),
]


@pytest.mark.parametrize(("raw", "expected_substr", "not_in"), _REDACTION_CASES)
def test_redact_known_secrets(raw: str, expected_substr: str, not_in: str) -> None:
    """``redact`` masks the secret and leaves the expected placeholder behind."""
    redacted = redact(raw)
    assert expected_substr in redacted
    assert not_in not in redacted
|
||||
|
||||
|
||||
def test_redact_idempotent() -> None:
    """Redacting already-redacted text leaves it unchanged."""
    first_pass = redact("Bearer abcdefghij1234567890")
    second_pass = redact(first_pass)
    assert first_pass == second_pass
|
||||
|
||||
|
||||
def test_redact_exc_returns_str() -> None:
    """``redact_exc`` turns an exception into a string with the bearer token masked."""
    failure = RuntimeError("Bearer abcdefghij1234567890")

    message = redact_exc(failure)

    assert isinstance(message, str)
    assert "Bearer ***" in message
    assert "abcdefghij1234567890" not in message
|
||||
|
||||
|
||||
def test_redact_handles_non_string() -> None:
    """A non-string argument is coerced to text instead of raising."""
    # Deliberately pass an int to exercise the coercion path.
    coerced = redact(12345)  # type: ignore[arg-type]
    assert coerced == "12345"
|
||||
@@ -0,0 +1,73 @@
|
||||
"""SSRF hardening regression tests.
|
||||
|
||||
Covers cases the original guard missed: IPv4-mapped IPv6, CGNAT,
|
||||
trailing-dot hostnames, IPv6 zone identifiers, and the safe-host repr
|
||||
used in error messages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import (
|
||||
UnsafeURLError,
|
||||
PinnedResolver,
|
||||
avalidate_outbound_url_full,
|
||||
validate_outbound_url,
|
||||
)
|
||||
|
||||
|
||||
class TestBlockedRanges:
    """Literal-IP URLs in the listed ranges must be rejected."""

    _EXTRA_BLOCKED_URLS = [
        "http://[::ffff:127.0.0.1]/",  # IPv4-mapped IPv6 → loopback
        "http://[::ffff:10.0.0.1]/",  # IPv4-mapped IPv6 → RFC1918
        "http://100.64.0.1/",  # CGNAT (RFC 6598)
        "http://0.0.0.0/",  # unspecified
    ]

    @pytest.mark.parametrize("url", _EXTRA_BLOCKED_URLS)
    def test_rejects_extra_ranges(self, url: str) -> None:
        """``validate_outbound_url`` raises ``UnsafeURLError`` for each URL."""
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(url)
|
||||
|
||||
|
||||
class TestHostnameNormalization:
    """Host and scheme spelling variants must not slip past the validator."""

    def test_strips_trailing_dot(self) -> None:
        """``localhost.`` (trailing dot) is rejected just like ``localhost``."""
        # The trailing-dot form should normalize to ``localhost``, still
        # resolve to the loopback address, and therefore be blocked.
        dotted_localhost = "http://localhost./"
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(dotted_localhost)

    def test_rejects_bad_scheme_uppercase(self) -> None:
        """An upper-case ``FILE:`` scheme is rejected."""
        upper_file_url = "FILE:///etc/passwd"
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(upper_file_url)
|
||||
|
||||
|
||||
class TestErrorMessages:
    """Error text raised for unsafe URLs must stay bounded in size."""

    def test_error_does_not_leak_long_hosts(self) -> None:
        """A 1024-character host label does not produce an oversized error message."""
        oversized_url = "http://" + "a" * 1024 + ".invalid/"

        with pytest.raises(UnsafeURLError) as caught:
            validate_outbound_url(oversized_url)

        # The host is expected to be truncated (to 64 chars) in the error
        # string; this check only bounds the overall message length.
        message = str(caught.value)
        assert len(message) < 256
|
||||
|
||||
|
||||
class TestPinnedResolverSync:
    """Synchronous checks on the ``PinnedResolver`` host → IP mapping."""

    def test_pin_returns_pinned_ip(self) -> None:
        """The mapping handed to the constructor is retrievable from ``_map``."""
        pinned = PinnedResolver({"example.com": "93.184.216.34"})

        # Only the dict path is exercised here — full resolve runs in async tests.
        stored_ip = pinned._map["example.com"]  # type: ignore[attr-defined]
        assert stored_ip == "93.184.216.34"
|
||||
|
||||
|
||||
class TestAsyncFullValidator:
    """Checks for the async validator, which returns the validated host and IP."""

    @pytest.mark.asyncio
    async def test_returns_resolved_ip(self) -> None:
        """A public literal IP validates and is reported as both host and IP."""
        # Literal IP — no DNS lookup; a ValidatedURL still comes back.
        validated = await avalidate_outbound_url_full("http://8.8.8.8/")

        assert validated.ip == "8.8.8.8"
        assert validated.host == "8.8.8.8"

    @pytest.mark.asyncio
    async def test_rejects_blocked_literal(self) -> None:
        """An IPv4-mapped IPv6 loopback literal is rejected."""
        mapped_loopback = "http://[::ffff:127.0.0.1]/"
        with pytest.raises(UnsafeURLError):
            await avalidate_outbound_url_full(mapped_loopback)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Telegram media-group mixed-type partitioning regression test.
|
||||
|
||||
Telegram rejects sendMediaGroup payloads that mix ``document`` with
|
||||
``photo``/``video``. The client must partition before chunking so a
|
||||
mixed input list still delivers all assets.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
|
||||
def test_partition_keeps_photo_video_together() -> None:
    """Photos and videos stay in a single group, in their input order."""
    visual_assets = [
        {"type": "photo", "url": "p1"},
        {"type": "video", "url": "v1"},
        {"type": "photo", "url": "p2"},
    ]

    groups = TelegramClient._partition_media_by_kind(visual_assets)

    assert len(groups) == 1
    only_group = groups[0]
    assert [asset["url"] for asset in only_group] == ["p1", "v1", "p2"]
|
||||
|
||||
|
||||
def test_partition_separates_documents_from_media() -> None:
    """A document between a photo and a video splits the input into three groups."""
    interleaved = [
        {"type": "photo", "url": "p1"},
        {"type": "document", "url": "d1"},
        {"type": "video", "url": "v1"},
    ]

    groups = TelegramClient._partition_media_by_kind(interleaved)

    assert len(groups) == 3
    # Each group leads with the asset that opened it, in input order.
    assert groups[0][0]["url"] == "p1"
    assert groups[1][0]["url"] == "d1"
    assert groups[2][0]["url"] == "v1"
|
||||
|
||||
|
||||
def test_partition_groups_consecutive_documents() -> None:
    """Adjacent documents share a group; the following photo starts a new one."""
    docs_then_photo = [
        {"type": "document", "url": "d1"},
        {"type": "document", "url": "d2"},
        {"type": "photo", "url": "p1"},
    ]

    groups = TelegramClient._partition_media_by_kind(docs_then_photo)

    assert len(groups) == 2
    document_group = groups[0]
    assert [asset["url"] for asset in document_group] == ["d1", "d2"]
    assert groups[1][0]["url"] == "p1"
|
||||
|
||||
|
||||
def test_partition_empty() -> None:
    """An empty asset list partitions into an empty list of groups."""
    groups = TelegramClient._partition_media_by_kind([])
    assert groups == []
|
||||
|
||||
|
||||
def test_partition_defaults_missing_type_to_photo() -> None:
    """An item with no ``type`` key groups together with a following video."""
    untyped_then_video = [
        {"url": "x"},  # no type
        {"type": "video", "url": "v"},
    ]

    groups = TelegramClient._partition_media_by_kind(untyped_then_video)

    assert len(groups) == 1
|
||||
Reference in New Issue
Block a user