feat: harden notification stack and switch logging selectors to icon grid
Notifications: - Add shared http_base, redact, and SSRF hardening modules - Refactor dispatcher, queue, receiver and per-provider clients (telegram, discord, email, matrix, ntfy, slack, webhook) to use the shared base, with bounded queue and redacted error logs - Tests for ssrf, redact, http_base, queue bounds, dispatcher aggregation, telegram media partition, email and matrix clients Frontend: - Settings: log level / log format selectors now use IconGridSelect with per-option icons and i18n descriptions - Minor providers page and entity-cache store updates Tooling: - Document code-review-graph MCP usage in CLAUDE.md - Ignore .code-review-graph/, register .mcp.json
This commit is contained in:
@@ -56,3 +56,5 @@ frontend/.svelte-kit/
|
|||||||
|
|
||||||
# Logs
|
# Logs
|
||||||
*.log
|
*.log
|
||||||
|
# Added by code-review-graph
|
||||||
|
.code-review-graph/
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"code-review-graph": {
|
||||||
|
"command": "uvx",
|
||||||
|
"args": [
|
||||||
|
"code-review-graph",
|
||||||
|
"serve"
|
||||||
|
],
|
||||||
|
"type": "stdio"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,3 +43,42 @@ Detailed context is split into focused documents under `.claude/docs/`. Read the
|
|||||||
- Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`)
|
- Notification preview sample: `packages/server/src/notify_bridge_server/services/sample_context.py` (`_SAMPLE_CONTEXT`)
|
||||||
- Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`)
|
- Command preview sample: `packages/server/src/notify_bridge_server/api/command_template_configs.py` (`sample_ctx` in `preview_raw`)
|
||||||
- Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py`
|
- Runtime validator whitelist: `packages/core/src/notify_bridge_core/templates/validator.py`
|
||||||
|
|
||||||
|
<!-- code-review-graph MCP tools -->
|
||||||
|
## MCP Tools: code-review-graph
|
||||||
|
|
||||||
|
**IMPORTANT: This project has a knowledge graph. ALWAYS use the
|
||||||
|
code-review-graph MCP tools BEFORE using Grep/Glob/Read to explore
|
||||||
|
the codebase.** The graph is faster, cheaper (fewer tokens), and gives
|
||||||
|
you structural context (callers, dependents, test coverage) that file
|
||||||
|
scanning cannot.
|
||||||
|
|
||||||
|
### When to use graph tools FIRST
|
||||||
|
|
||||||
|
- **Exploring code**: `semantic_search_nodes` or `query_graph` instead of Grep
|
||||||
|
- **Understanding impact**: `get_impact_radius` instead of manually tracing imports
|
||||||
|
- **Code review**: `detect_changes` + `get_review_context` instead of reading entire files
|
||||||
|
- **Finding relationships**: `query_graph` with callers_of/callees_of/imports_of/tests_for
|
||||||
|
- **Architecture questions**: `get_architecture_overview` + `list_communities`
|
||||||
|
|
||||||
|
Fall back to Grep/Glob/Read **only** when the graph doesn't cover what you need.
|
||||||
|
|
||||||
|
### Key Tools
|
||||||
|
|
||||||
|
| Tool | Use when |
|
||||||
|
|------|----------|
|
||||||
|
| `detect_changes` | Reviewing code changes — gives risk-scored analysis |
|
||||||
|
| `get_review_context` | Need source snippets for review — token-efficient |
|
||||||
|
| `get_impact_radius` | Understanding blast radius of a change |
|
||||||
|
| `get_affected_flows` | Finding which execution paths are impacted |
|
||||||
|
| `query_graph` | Tracing callers, callees, imports, tests, dependencies |
|
||||||
|
| `semantic_search_nodes` | Finding functions/classes by name or keyword |
|
||||||
|
| `get_architecture_overview` | Understanding high-level codebase structure |
|
||||||
|
| `refactor_tool` | Planning renames, finding dead code |
|
||||||
|
|
||||||
|
### Workflow
|
||||||
|
|
||||||
|
1. The graph auto-updates on file changes (via hooks).
|
||||||
|
2. Use `detect_changes` for code review.
|
||||||
|
3. Use `get_affected_flows` to understand impact.
|
||||||
|
4. Use `query_graph` pattern="tests_for" to check coverage.
|
||||||
|
|||||||
@@ -73,6 +73,22 @@ export const localeItems = (): GridItem[] => [
|
|||||||
{ value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') },
|
{ value: 'ru', icon: 'mdiAlphabeticalVariant', label: 'Русский', desc: t('gridDesc.localeRu') },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// --- Log level ---
|
||||||
|
|
||||||
|
export const logLevelItems = (): GridItem[] => [
|
||||||
|
{ value: 'DEBUG', icon: 'mdiBugOutline', label: 'DEBUG', desc: t('gridDesc.logLevelDebug') },
|
||||||
|
{ value: 'INFO', icon: 'mdiInformationOutline', label: 'INFO', desc: t('gridDesc.logLevelInfo') },
|
||||||
|
{ value: 'WARNING', icon: 'mdiAlertOutline', label: 'WARNING', desc: t('gridDesc.logLevelWarning') },
|
||||||
|
{ value: 'ERROR', icon: 'mdiAlertOctagonOutline', label: 'ERROR', desc: t('gridDesc.logLevelError') },
|
||||||
|
];
|
||||||
|
|
||||||
|
// --- Log format ---
|
||||||
|
|
||||||
|
export const logFormatItems = (): GridItem[] => [
|
||||||
|
{ value: 'text', icon: 'mdiFormatText', label: 'text', desc: t('gridDesc.logFormatText') },
|
||||||
|
{ value: 'json', icon: 'mdiCodeJson', label: 'json', desc: t('gridDesc.logFormatJson') },
|
||||||
|
];
|
||||||
|
|
||||||
// --- Response mode ---
|
// --- Response mode ---
|
||||||
|
|
||||||
export const responseModeItems = (tFn: typeof t): GridItem[] => [
|
export const responseModeItems = (tFn: typeof t): GridItem[] => [
|
||||||
|
|||||||
@@ -192,7 +192,8 @@
|
|||||||
"apiToken": "API Token",
|
"apiToken": "API Token",
|
||||||
"apiTokenHint": "Optional. Needed for connection testing and repository listing.",
|
"apiTokenHint": "Optional. Needed for connection testing and repository listing.",
|
||||||
"webhookUrl": "Webhook URL",
|
"webhookUrl": "Webhook URL",
|
||||||
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).",
|
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings. The full URL is shown when an external base URL is configured in Settings; otherwise it is relative to your bridge host.",
|
||||||
|
"webhookUrlCopyTitle": "Click to copy",
|
||||||
"nutHost": "NUT Server Host",
|
"nutHost": "NUT Server Host",
|
||||||
"nutHostPlaceholder": "192.168.1.100 or ups.local",
|
"nutHostPlaceholder": "192.168.1.100 or ups.local",
|
||||||
"nutPort": "NUT Server Port",
|
"nutPort": "NUT Server Port",
|
||||||
@@ -1131,6 +1132,12 @@
|
|||||||
"memorySourceNative": "Use Immich native memories API",
|
"memorySourceNative": "Use Immich native memories API",
|
||||||
"localeEn": "English interface",
|
"localeEn": "English interface",
|
||||||
"localeRu": "Russian interface",
|
"localeRu": "Russian interface",
|
||||||
|
"logLevelDebug": "Verbose — show every step",
|
||||||
|
"logLevelInfo": "Default — high-level events",
|
||||||
|
"logLevelWarning": "Warnings and errors only",
|
||||||
|
"logLevelError": "Errors only — quietest",
|
||||||
|
"logFormatText": "Human-readable plain text",
|
||||||
|
"logFormatJson": "One JSON object per line",
|
||||||
"modeMedia": "Send actual photo/video files",
|
"modeMedia": "Send actual photo/video files",
|
||||||
"modeText": "Send file names and links only",
|
"modeText": "Send file names and links only",
|
||||||
"allEvents": "Show all event types",
|
"allEvents": "Show all event types",
|
||||||
|
|||||||
@@ -192,7 +192,8 @@
|
|||||||
"apiToken": "API токен",
|
"apiToken": "API токен",
|
||||||
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
|
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
|
||||||
"webhookUrl": "URL вебхука",
|
"webhookUrl": "URL вебхука",
|
||||||
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).",
|
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea. Полный URL показывается, если в настройках задан внешний адрес; иначе путь указан относительно хоста bridge.",
|
||||||
|
"webhookUrlCopyTitle": "Нажмите, чтобы скопировать",
|
||||||
"nutHost": "Хост NUT-сервера",
|
"nutHost": "Хост NUT-сервера",
|
||||||
"nutHostPlaceholder": "192.168.1.100 или ups.local",
|
"nutHostPlaceholder": "192.168.1.100 или ups.local",
|
||||||
"nutPort": "Порт NUT-сервера",
|
"nutPort": "Порт NUT-сервера",
|
||||||
@@ -1131,6 +1132,12 @@
|
|||||||
"memorySourceNative": "Использовать API воспоминаний Immich",
|
"memorySourceNative": "Использовать API воспоминаний Immich",
|
||||||
"localeEn": "Английский интерфейс",
|
"localeEn": "Английский интерфейс",
|
||||||
"localeRu": "Русский интерфейс",
|
"localeRu": "Русский интерфейс",
|
||||||
|
"logLevelDebug": "Подробный — каждый шаг",
|
||||||
|
"logLevelInfo": "По умолчанию — ключевые события",
|
||||||
|
"logLevelWarning": "Только предупреждения и ошибки",
|
||||||
|
"logLevelError": "Только ошибки — самый тихий",
|
||||||
|
"logFormatText": "Читаемый человеком текст",
|
||||||
|
"logFormatJson": "Один JSON-объект на строку",
|
||||||
"modeMedia": "Отправка файлов фото/видео",
|
"modeMedia": "Отправка файлов фото/видео",
|
||||||
"modeText": "Только имена файлов и ссылки",
|
"modeText": "Только имена файлов и ссылки",
|
||||||
"allEvents": "Показать все типы событий",
|
"allEvents": "Показать все типы событий",
|
||||||
|
|||||||
@@ -112,6 +112,34 @@ export const capabilitiesCache = (() => {
|
|||||||
};
|
};
|
||||||
})();
|
})();
|
||||||
|
|
||||||
|
/** Configured external base URL — used to render absolute webhook URLs.
|
||||||
|
* Available to all authenticated users. Empty string when unset. */
|
||||||
|
export const externalUrlCache = (() => {
|
||||||
|
let data = $state<string>('');
|
||||||
|
let fetchedAt = $state(0);
|
||||||
|
let inflight: Promise<string> | null = null;
|
||||||
|
const TTL = 300_000;
|
||||||
|
return {
|
||||||
|
get value() { return data; },
|
||||||
|
invalidate() { fetchedAt = 0; },
|
||||||
|
async fetch(force = false): Promise<string> {
|
||||||
|
if (!force && fetchedAt > 0 && Date.now() - fetchedAt < TTL) return data;
|
||||||
|
if (inflight) return inflight;
|
||||||
|
inflight = (async () => {
|
||||||
|
try {
|
||||||
|
const res = await api<{ external_url: string }>('/settings/external-url');
|
||||||
|
data = (res?.external_url || '').replace(/\/+$/, '');
|
||||||
|
fetchedAt = Date.now();
|
||||||
|
return data;
|
||||||
|
} finally {
|
||||||
|
inflight = null;
|
||||||
|
}
|
||||||
|
})();
|
||||||
|
return inflight;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
})();
|
||||||
|
|
||||||
/** Supported template locales — fetched from app settings. */
|
/** Supported template locales — fetched from app settings. */
|
||||||
export const supportedLocalesCache = (() => {
|
export const supportedLocalesCache = (() => {
|
||||||
let data = $state<string[]>(['en', 'ru']);
|
let data = $state<string[]>(['en', 'ru']);
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import { slide } from 'svelte/transition';
|
import { slide } from 'svelte/transition';
|
||||||
import { api, getBlockedBy, type BlockedByDetail } from '$lib/api';
|
import { api, getBlockedBy, type BlockedByDetail } from '$lib/api';
|
||||||
import { t } from '$lib/i18n';
|
import { t } from '$lib/i18n';
|
||||||
import { providersCache } from '$lib/stores/caches.svelte';
|
import { providersCache, externalUrlCache } from '$lib/stores/caches.svelte';
|
||||||
import PageHeader from '$lib/components/PageHeader.svelte';
|
import PageHeader from '$lib/components/PageHeader.svelte';
|
||||||
import Card from '$lib/components/Card.svelte';
|
import Card from '$lib/components/Card.svelte';
|
||||||
import Loading from '$lib/components/Loading.svelte';
|
import Loading from '$lib/components/Loading.svelte';
|
||||||
@@ -21,7 +21,7 @@
|
|||||||
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
||||||
import { topbarAction } from '$lib/stores/topbar-action.svelte';
|
import { topbarAction } from '$lib/stores/topbar-action.svelte';
|
||||||
import { onDestroy } from 'svelte';
|
import { onDestroy } from 'svelte';
|
||||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
|
||||||
import { highlightFromUrl } from '$lib/highlight';
|
import { highlightFromUrl } from '$lib/highlight';
|
||||||
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
|
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
|
||||||
import Button from '$lib/components/Button.svelte';
|
import Button from '$lib/components/Button.svelte';
|
||||||
@@ -45,6 +45,30 @@
|
|||||||
let confirmDelete = $state<ServiceProvider | null>(null);
|
let confirmDelete = $state<ServiceProvider | null>(null);
|
||||||
|
|
||||||
let descriptor = $derived(getDescriptor(form.type));
|
let descriptor = $derived(getDescriptor(form.type));
|
||||||
|
let externalUrl = $derived(externalUrlCache.value);
|
||||||
|
|
||||||
|
function buildWebhookUrl(pattern: string, token: string): string {
|
||||||
|
const path = pattern.replace('{token}', token ?? '');
|
||||||
|
return externalUrl ? `${externalUrl}${path}` : path;
|
||||||
|
}
|
||||||
|
|
||||||
|
function copyWebhookUrl(e: Event, url: string) {
|
||||||
|
e.preventDefault();
|
||||||
|
e.stopPropagation();
|
||||||
|
if (navigator.clipboard?.writeText) {
|
||||||
|
navigator.clipboard.writeText(url);
|
||||||
|
} else {
|
||||||
|
const ta = document.createElement('textarea');
|
||||||
|
ta.value = url;
|
||||||
|
ta.style.position = 'fixed';
|
||||||
|
ta.style.opacity = '0';
|
||||||
|
document.body.appendChild(ta);
|
||||||
|
ta.select();
|
||||||
|
document.execCommand('copy');
|
||||||
|
document.body.removeChild(ta);
|
||||||
|
}
|
||||||
|
snackInfo(`${t('snack.copied')}: ${url}`);
|
||||||
|
}
|
||||||
|
|
||||||
// Auto-update name when provider type changes (unless user manually edited)
|
// Auto-update name when provider type changes (unless user manually edited)
|
||||||
$effect(() => {
|
$effect(() => {
|
||||||
@@ -76,6 +100,7 @@
|
|||||||
onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); },
|
onclick: () => { showForm ? (showForm = false, editing = null) : openNew(); },
|
||||||
});
|
});
|
||||||
load();
|
load();
|
||||||
|
externalUrlCache.fetch().catch(() => { /* fall back to relative URLs */ });
|
||||||
});
|
});
|
||||||
onDestroy(() => topbarAction.clear());
|
onDestroy(() => topbarAction.clear());
|
||||||
async function load() {
|
async function load() {
|
||||||
@@ -246,9 +271,15 @@
|
|||||||
</div>
|
</div>
|
||||||
{/each}
|
{/each}
|
||||||
{#if descriptor?.webhookUrlPattern && editing}
|
{#if descriptor?.webhookUrlPattern && editing}
|
||||||
|
{@const editingWebhookUrl = buildWebhookUrl(descriptor.webhookUrlPattern, providers.find(p => p.id === editing)?.webhook_token ?? '')}
|
||||||
<div class="bg-[var(--color-muted)] rounded-md p-3">
|
<div class="bg-[var(--color-muted)] rounded-md p-3">
|
||||||
<div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div>
|
<div class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</div>
|
||||||
<code class="text-xs select-all break-all">{descriptor.webhookUrlPattern.replace('{token}', providers.find(p => p.id === editing)?.webhook_token ?? '')}</code>
|
<button type="button"
|
||||||
|
onclick={(e) => copyWebhookUrl(e, editingWebhookUrl)}
|
||||||
|
title={t('providers.webhookUrlCopyTitle')}
|
||||||
|
class="text-xs break-all text-left hover:text-[var(--color-primary)] cursor-pointer font-mono w-full">
|
||||||
|
<code class="bg-transparent">{editingWebhookUrl}</code>
|
||||||
|
</button>
|
||||||
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
|
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
@@ -295,7 +326,14 @@
|
|||||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p>
|
<p class="text-xs text-[var(--color-muted-foreground)] font-mono">{provider.config.host}:{provider.config.port || 3493}</p>
|
||||||
{/if}
|
{/if}
|
||||||
{#if provDesc?.webhookUrlPattern}
|
{#if provDesc?.webhookUrlPattern}
|
||||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">{t('providers.webhookUrl')}: <span class="select-all">{provDesc.webhookUrlPattern.replace('{token}', provider.webhook_token)}</span></p>
|
{@const webhookUrl = buildWebhookUrl(provDesc.webhookUrlPattern, provider.webhook_token)}
|
||||||
|
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">
|
||||||
|
{t('providers.webhookUrl')}:
|
||||||
|
<button type="button"
|
||||||
|
onclick={(e) => copyWebhookUrl(e, webhookUrl)}
|
||||||
|
title={t('providers.webhookUrlCopyTitle')}
|
||||||
|
class="hover:text-[var(--color-primary)] cursor-pointer break-all text-left">{webhookUrl}</button>
|
||||||
|
</p>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -12,7 +12,10 @@
|
|||||||
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
||||||
import LocaleSelector from '$lib/components/LocaleSelector.svelte';
|
import LocaleSelector from '$lib/components/LocaleSelector.svelte';
|
||||||
import TimezoneSelector from '$lib/components/TimezoneSelector.svelte';
|
import TimezoneSelector from '$lib/components/TimezoneSelector.svelte';
|
||||||
|
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
|
||||||
|
import { logLevelItems, logFormatItems } from '$lib/grid-items';
|
||||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||||
|
import { externalUrlCache } from '$lib/stores/caches.svelte';
|
||||||
|
|
||||||
interface CacheBucketStats {
|
interface CacheBucketStats {
|
||||||
count: number;
|
count: number;
|
||||||
@@ -76,6 +79,7 @@
|
|||||||
saving = true; error = '';
|
saving = true; error = '';
|
||||||
try {
|
try {
|
||||||
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
|
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
|
||||||
|
externalUrlCache.invalidate();
|
||||||
snackSuccess(t('settings.saved'));
|
snackSuccess(t('settings.saved'));
|
||||||
} catch (err: any) { error = err.message; snackError(err.message); }
|
} catch (err: any) { error = err.message; snackError(err.message); }
|
||||||
saving = false;
|
saving = false;
|
||||||
@@ -221,21 +225,11 @@
|
|||||||
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||||
<div>
|
<div>
|
||||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
||||||
<select bind:value={settings.log_level}
|
<IconGridSelect items={logLevelItems()} bind:value={settings.log_level} columns={2} />
|
||||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
|
||||||
<option value="DEBUG">DEBUG</option>
|
|
||||||
<option value="INFO">INFO</option>
|
|
||||||
<option value="WARNING">WARNING</option>
|
|
||||||
<option value="ERROR">ERROR</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
||||||
<select bind:value={settings.log_format}
|
<IconGridSelect items={logFormatItems()} bind:value={settings.log_format} columns={2} />
|
||||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
|
||||||
<option value="text">text</option>
|
|
||||||
<option value="json">json</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="sm:col-span-2">
|
<div class="sm:col-span-2">
|
||||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
||||||
|
|||||||
@@ -4,21 +4,24 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from ..http_base import HttpProviderClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Discord webhook content limit
|
# Discord API constraints (per webhook docs).
|
||||||
MAX_CONTENT_LENGTH = 2000
|
MAX_CONTENT_LENGTH: Final = 2000
|
||||||
|
MAX_USERNAME_LENGTH: Final = 80
|
||||||
|
|
||||||
|
|
||||||
class DiscordClient:
|
class DiscordClient(HttpProviderClient):
|
||||||
"""Sends messages via Discord webhook URLs."""
|
"""Sends messages via Discord webhook URLs."""
|
||||||
|
|
||||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||||
self._session = session
|
super().__init__(session, provider_name="discord")
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
@@ -33,6 +36,8 @@ class DiscordClient:
|
|||||||
"""
|
"""
|
||||||
if not webhook_url:
|
if not webhook_url:
|
||||||
return {"success": False, "error": "Missing webhook_url"}
|
return {"success": False, "error": "Missing webhook_url"}
|
||||||
|
if username and len(username) > MAX_USERNAME_LENGTH:
|
||||||
|
return {"success": False, "error": f"username exceeds {MAX_USERNAME_LENGTH} chars"}
|
||||||
|
|
||||||
chunks = _split_message(message, MAX_CONTENT_LENGTH)
|
chunks = _split_message(message, MAX_CONTENT_LENGTH)
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
@@ -42,71 +47,34 @@ class DiscordClient:
|
|||||||
if avatar_url:
|
if avatar_url:
|
||||||
payload["avatar_url"] = avatar_url
|
payload["avatar_url"] = avatar_url
|
||||||
|
|
||||||
result = await self._post(webhook_url, payload)
|
result = await self.request("POST", webhook_url, json=payload)
|
||||||
if not result["success"]:
|
if not result.get("success"):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Small delay between chunks to respect rate limits
|
|
||||||
if len(chunks) > 1:
|
if len(chunks) > 1:
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
_MAX_RETRIES = 3
|
|
||||||
_MAX_RETRY_AFTER = 60.0
|
|
||||||
|
|
||||||
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
|
|
||||||
"""POST with bounded 429 retry.
|
|
||||||
|
|
||||||
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
|
|
||||||
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
|
|
||||||
pin the dispatch task indefinitely.
|
|
||||||
"""
|
|
||||||
for attempt in range(self._MAX_RETRIES + 1):
|
|
||||||
try:
|
|
||||||
async with self._session.post(
|
|
||||||
url,
|
|
||||||
json=payload,
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
allow_redirects=False,
|
|
||||||
) as resp:
|
|
||||||
if resp.status == 429 and attempt < self._MAX_RETRIES:
|
|
||||||
try:
|
|
||||||
retry_after = float(resp.headers.get("Retry-After", "2"))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
retry_after = 2.0
|
|
||||||
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
|
|
||||||
retry_after, attempt + 1, self._MAX_RETRIES,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(retry_after)
|
|
||||||
continue
|
|
||||||
if 200 <= resp.status < 300:
|
|
||||||
return {"success": True}
|
|
||||||
body = await resp.text()
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"HTTP {resp.status}: {body[:200]}",
|
|
||||||
}
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
|
||||||
|
|
||||||
|
|
||||||
def _split_message(text: str, limit: int) -> list[str]:
|
def _split_message(text: str, limit: int) -> list[str]:
|
||||||
"""Split message into chunks respecting the character limit."""
|
"""Split message into chunks respecting the character limit.
|
||||||
|
|
||||||
|
Drops chunks that contain only whitespace — Discord rejects those.
|
||||||
|
"""
|
||||||
if len(text) <= limit:
|
if len(text) <= limit:
|
||||||
return [text]
|
return [text]
|
||||||
chunks = []
|
chunks: list[str] = []
|
||||||
while text:
|
while text:
|
||||||
if len(text) <= limit:
|
if len(text) <= limit:
|
||||||
chunks.append(text)
|
piece = text
|
||||||
break
|
text = ""
|
||||||
# Try to split at newline
|
else:
|
||||||
split_at = text.rfind("\n", 0, limit)
|
split_at = text.rfind("\n", 0, limit)
|
||||||
if split_at <= 0:
|
if split_at <= 0:
|
||||||
split_at = limit
|
split_at = limit
|
||||||
chunks.append(text[:split_at])
|
piece = text[:split_at]
|
||||||
text = text[split_at:].lstrip("\n")
|
text = text[split_at:].lstrip("\n")
|
||||||
return chunks
|
if piece.strip():
|
||||||
|
chunks.append(piece)
|
||||||
|
return chunks or [text]
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import contextlib
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, AsyncIterator
|
from typing import Any, AsyncIterator, Awaitable, Callable, Final
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
@@ -15,37 +15,20 @@ from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
|
|||||||
from notify_bridge_core.models.events import ServiceEvent
|
from notify_bridge_core.models.events import ServiceEvent
|
||||||
from notify_bridge_core.templates.context import build_template_context
|
from notify_bridge_core.templates.context import build_template_context
|
||||||
from notify_bridge_core.templates.renderer import render_template
|
from notify_bridge_core.templates.renderer import render_template
|
||||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
|
||||||
|
|
||||||
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
|
||||||
|
|
||||||
# Cap on how many asset downloads run concurrently inside
|
|
||||||
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
|
|
||||||
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
|
|
||||||
# max_asset_size``, which matters on small-RAM Docker hosts when a batch
|
|
||||||
# contains many large videos.
|
|
||||||
_PRELOAD_CONCURRENCY = 6
|
|
||||||
|
|
||||||
|
|
||||||
def _new_session() -> aiohttp.ClientSession:
|
|
||||||
"""Per-dispatch aiohttp session with a sane default timeout.
|
|
||||||
|
|
||||||
We still open a short-lived session per dispatch (connection reuse across
|
|
||||||
dispatches lives in the server-side shared session), but we always attach
|
|
||||||
a total timeout so a hung peer cannot wedge the task forever.
|
|
||||||
"""
|
|
||||||
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
|
|
||||||
|
|
||||||
|
from .http_base import safe_headers
|
||||||
from .receiver import (
|
from .receiver import (
|
||||||
|
DiscordReceiver,
|
||||||
|
EmailReceiver,
|
||||||
|
MatrixReceiver,
|
||||||
|
NtfyReceiver,
|
||||||
Receiver,
|
Receiver,
|
||||||
|
SlackReceiver,
|
||||||
TelegramReceiver,
|
TelegramReceiver,
|
||||||
WebhookReceiver,
|
WebhookReceiver,
|
||||||
EmailReceiver,
|
|
||||||
DiscordReceiver,
|
|
||||||
SlackReceiver,
|
|
||||||
NtfyReceiver,
|
|
||||||
MatrixReceiver,
|
|
||||||
)
|
)
|
||||||
|
from .redact import redact_exc
|
||||||
|
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||||
from .telegram.cache import TelegramFileCache
|
from .telegram.cache import TelegramFileCache
|
||||||
from .telegram.client import TelegramClient
|
from .telegram.client import TelegramClient
|
||||||
from .telegram.media import (
|
from .telegram.media import (
|
||||||
@@ -58,7 +41,33 @@ from .webhook.client import WebhookClient
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_TEMPLATE = '{{ event_type }}: "{{ collection_name }}"'
|
DEFAULT_TEMPLATE: Final = '{{ event_type }}: "{{ collection_name }}"'
|
||||||
|
|
||||||
|
_HTTP_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
|
||||||
|
# Cap on how many asset downloads run concurrently inside
|
||||||
|
# ``_preload_asset_data``. Peak memory during a send is bounded to roughly
|
||||||
|
# ``_PRELOAD_CONCURRENCY * max_asset_size`` instead of ``max_media_to_send *
|
||||||
|
# max_asset_size``.
|
||||||
|
_PRELOAD_CONCURRENCY: Final = 6
|
||||||
|
|
||||||
|
# Cap on how many targets the dispatcher fans out to at once. With dozens
|
||||||
|
# of targets and a single hung peer, unbounded ``gather`` can pin the
|
||||||
|
# dispatch task. The cap also protects against credential-reuse rate
|
||||||
|
# limits on shared providers.
|
||||||
|
_DISPATCH_CONCURRENCY: Final = 16
|
||||||
|
|
||||||
|
# Cap on parallel per-receiver sends within a single target.
|
||||||
|
_RECEIVER_CONCURRENCY: Final = 8
|
||||||
|
|
||||||
|
# Per-target soft timeout — at the top of the dispatch tree so a single
|
||||||
|
# misbehaving target can't hold the whole batch open. Individual provider
|
||||||
|
# clients carry their own per-request timeouts on top of this.
|
||||||
|
_TARGET_TIMEOUT_S: Final = 120.0
|
||||||
|
|
||||||
|
|
||||||
|
def _new_session() -> aiohttp.ClientSession:
|
||||||
|
return aiohttp.ClientSession(timeout=_HTTP_TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -66,17 +75,23 @@ class TargetConfig:
|
|||||||
"""Configuration for a notification target."""
|
"""Configuration for a notification target."""
|
||||||
|
|
||||||
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
|
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
|
||||||
config: dict[str, Any] # target-level config (bot_token, settings, etc.)
|
config: dict[str, Any]
|
||||||
template_slots: dict[str, dict[str, str]] | None = None # event_type -> {locale -> template}
|
template_slots: dict[str, dict[str, str]] | None = None
|
||||||
locale: str = "en" # default locale for template resolution
|
locale: str = "en"
|
||||||
date_format: str = "%d.%m.%Y, %H:%M UTC"
|
date_format: str = "%d.%m.%Y, %H:%M UTC"
|
||||||
date_only_format: str = "%d.%m.%Y"
|
date_only_format: str = "%d.%m.%Y"
|
||||||
provider_api_key: str | None = None # API key for downloading assets from provider
|
provider_api_key: str | None = None
|
||||||
provider_internal_url: str | None = None # Internal provider URL for API key scoping
|
provider_internal_url: str | None = None
|
||||||
provider_external_url: str | None = None # External domain for API key scoping
|
provider_external_url: str | None = None
|
||||||
receivers: list[Receiver] = field(default_factory=list)
|
receivers: list[Receiver] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
_SendMethod = Callable[
|
||||||
|
["NotificationDispatcher", TargetConfig, str, ServiceEvent],
|
||||||
|
Awaitable[dict[str, Any]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class NotificationDispatcher:
|
class NotificationDispatcher:
|
||||||
"""Dispatches ServiceEvent notifications to configured targets."""
|
"""Dispatches ServiceEvent notifications to configured targets."""
|
||||||
|
|
||||||
@@ -90,18 +105,11 @@ class NotificationDispatcher:
|
|||||||
self._url_cache = url_cache
|
self._url_cache = url_cache
|
||||||
self._asset_cache = asset_cache
|
self._asset_cache = asset_cache
|
||||||
# Optional shared session owned by the caller; when supplied we reuse
|
# Optional shared session owned by the caller; when supplied we reuse
|
||||||
# its connection pool instead of opening a fresh per-dispatch session
|
# its connection pool instead of opening a fresh per-dispatch session.
|
||||||
# (saves a TLS handshake per outbound call).
|
|
||||||
self._shared_session = session
|
self._shared_session = session
|
||||||
|
|
||||||
@contextlib.asynccontextmanager
|
@contextlib.asynccontextmanager
|
||||||
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
|
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
|
||||||
"""Yield an aiohttp session, reusing the shared one if provided.
|
|
||||||
|
|
||||||
When a shared session was passed in ``__init__`` we yield it without
|
|
||||||
closing (the caller owns its lifetime). Otherwise we open a
|
|
||||||
short-lived session with our default timeout and close it on exit.
|
|
||||||
"""
|
|
||||||
if self._shared_session is not None and not self._shared_session.closed:
|
if self._shared_session is not None and not self._shared_session.closed:
|
||||||
yield self._shared_session
|
yield self._shared_session
|
||||||
return
|
return
|
||||||
@@ -115,11 +123,9 @@ class NotificationDispatcher:
|
|||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Send event notification to all targets.
|
"""Send event notification to all targets.
|
||||||
|
|
||||||
Returns list of results (one per target).
|
Returns one result per target. Per-target failures are isolated;
|
||||||
|
a single bad target cannot poison the batch.
|
||||||
"""
|
"""
|
||||||
# Bind a dispatch_id so every log line emitted by the target sends
|
|
||||||
# (including deep in TelegramClient) can be correlated to the same
|
|
||||||
# upstream event.
|
|
||||||
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
with bind_log_context(dispatch_id=new_id):
|
with bind_log_context(dispatch_id=new_id):
|
||||||
@@ -128,20 +134,36 @@ class NotificationDispatcher:
|
|||||||
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
||||||
getattr(event, "collection_name", None), len(targets),
|
getattr(event, "collection_name", None), len(targets),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sem = asyncio.Semaphore(_DISPATCH_CONCURRENCY)
|
||||||
|
|
||||||
|
async def run_one(t: TargetConfig) -> dict[str, Any]:
|
||||||
|
async with sem:
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._send_to_target(event, t),
|
||||||
|
timeout=_TARGET_TIMEOUT_S,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Target dispatch timed out after {_TARGET_TIMEOUT_S}s",
|
||||||
|
}
|
||||||
|
|
||||||
raw_results = await asyncio.gather(
|
raw_results = await asyncio.gather(
|
||||||
*[self._send_to_target(event, t) for t in targets],
|
*[run_one(t) for t in targets],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
results = []
|
results: list[dict[str, Any]] = []
|
||||||
failures = 0
|
failures = 0
|
||||||
for target, raw in zip(targets, raw_results):
|
for target, raw in zip(targets, raw_results):
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
failures += 1
|
failures += 1
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
"Dispatch to target type=%s failed: %s",
|
"Dispatch to target type=%s failed: %s",
|
||||||
target.type, raw, exc_info=raw,
|
target.type, redact_exc(raw),
|
||||||
)
|
)
|
||||||
results.append({"success": False, "error": str(raw)})
|
results.append({"success": False, "error": redact_exc(raw)})
|
||||||
else:
|
else:
|
||||||
if isinstance(raw, dict) and not raw.get("success"):
|
if isinstance(raw, dict) and not raw.get("success"):
|
||||||
failures += 1
|
failures += 1
|
||||||
@@ -155,7 +177,6 @@ class NotificationDispatcher:
|
|||||||
def _resolve_template(
|
def _resolve_template(
|
||||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Resolve template string for an event, with locale fallback."""
|
|
||||||
template_str = DEFAULT_TEMPLATE
|
template_str = DEFAULT_TEMPLATE
|
||||||
if target.template_slots:
|
if target.template_slots:
|
||||||
locale_map = target.template_slots.get(event.event_type.value)
|
locale_map = target.template_slots.get(event.event_type.value)
|
||||||
@@ -166,7 +187,6 @@ class NotificationDispatcher:
|
|||||||
def _render_message(
|
def _render_message(
|
||||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Resolve template and render message for a given locale."""
|
|
||||||
template_str = self._resolve_template(event, target, locale)
|
template_str = self._resolve_template(event, target, locale)
|
||||||
ctx = build_template_context(
|
ctx = build_template_context(
|
||||||
event, target_type=target.type,
|
event, target_type=target.type,
|
||||||
@@ -179,7 +199,6 @@ class NotificationDispatcher:
|
|||||||
self, receiver: Receiver, default_message: str,
|
self, receiver: Receiver, default_message: str,
|
||||||
event: ServiceEvent, target: TargetConfig,
|
event: ServiceEvent, target: TargetConfig,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return per-receiver message, re-rendering if receiver has a different locale."""
|
|
||||||
if receiver.locale and receiver.locale != target.locale:
|
if receiver.locale and receiver.locale != target.locale:
|
||||||
return self._render_message(event, target, receiver.locale)
|
return self._render_message(event, target, receiver.locale)
|
||||||
return default_message
|
return default_message
|
||||||
@@ -187,21 +206,16 @@ class NotificationDispatcher:
|
|||||||
async def _send_to_target(
|
async def _send_to_target(
|
||||||
self, event: ServiceEvent, target: TargetConfig
|
self, event: ServiceEvent, target: TargetConfig
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Send event to a single target (potentially multiple receivers)."""
|
"""Dispatch to a single target via the registered handler."""
|
||||||
default_message = self._render_message(event, target, target.locale)
|
default_message = self._render_message(event, target, target.locale)
|
||||||
|
send_method = _PROVIDER_HANDLERS.get(target.type)
|
||||||
send_method = {
|
if send_method is None:
|
||||||
"telegram": self._send_telegram,
|
|
||||||
"webhook": self._send_webhook,
|
|
||||||
"email": self._send_email,
|
|
||||||
"discord": self._send_discord,
|
|
||||||
"slack": self._send_slack,
|
|
||||||
"ntfy": self._send_ntfy,
|
|
||||||
"matrix": self._send_matrix,
|
|
||||||
}.get(target.type)
|
|
||||||
if send_method:
|
|
||||||
return await send_method(target, default_message, event)
|
|
||||||
return {"success": False, "error": f"Unknown target type: {target.type}"}
|
return {"success": False, "error": f"Unknown target type: {target.type}"}
|
||||||
|
return await send_method(self, target, default_message, event)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Asset preload (Telegram-specific)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def _preload_asset_data(
|
async def _preload_asset_data(
|
||||||
self,
|
self,
|
||||||
@@ -210,36 +224,13 @@ class NotificationDispatcher:
|
|||||||
session: aiohttp.ClientSession,
|
session: aiohttp.ClientSession,
|
||||||
max_size: int | None,
|
max_size: int | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Download each non-cached asset's bytes once and attach to the entry.
|
"""Download each non-cached asset's bytes once, with SSRF guard."""
|
||||||
|
|
||||||
Three benefits:
|
|
||||||
* ``TelegramClient`` sees ``entry["data"]`` and skips its own download,
|
|
||||||
so we don't fetch each URL twice.
|
|
||||||
* We know the exact upload size, which lets the oversize warning in
|
|
||||||
the rendered text compare against real bytes (for Immich videos,
|
|
||||||
the transcoded ``/video/playback``), not the original ``file_size``.
|
|
||||||
* Assets already in the Telegram file_id cache are skipped, and their
|
|
||||||
stored size (if any) is used to populate ``playback_size`` — so
|
|
||||||
templates see consistent sizes for repeat sends without re-download.
|
|
||||||
|
|
||||||
Entries whose download fails or exceeds ``max_size`` are left without
|
|
||||||
``data``; ``TelegramClient`` will then fall back to its own download
|
|
||||||
path and apply the same checks — no regression, just no preload win.
|
|
||||||
|
|
||||||
Concurrency is bounded by ``_PRELOAD_CONCURRENCY`` so peak memory
|
|
||||||
stays predictable: at most N assets worth of bytes held in RAM at
|
|
||||||
once, regardless of ``max_media_to_send``. Total wall-clock is
|
|
||||||
unchanged for small batches and only marginally slower for large
|
|
||||||
ones (most assets fit in a single RTT and SSL negotiation cost
|
|
||||||
dominates, so 6-way parallelism is sufficient).
|
|
||||||
"""
|
|
||||||
if not assets:
|
if not assets:
|
||||||
return
|
return
|
||||||
|
|
||||||
sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY)
|
sem = asyncio.Semaphore(_PRELOAD_CONCURRENCY)
|
||||||
|
|
||||||
async def _fetch(entry: dict[str, Any], media: Any) -> None:
|
async def fetch(entry: dict[str, Any], media: Any) -> None:
|
||||||
# Cache hit → skip download; populate playback_size from stored size.
|
|
||||||
cache, key = self._cache_for_entry(entry)
|
cache, key = self._cache_for_entry(entry)
|
||||||
if cache and key:
|
if cache and key:
|
||||||
cached = cache.get(key)
|
cached = cache.get(key)
|
||||||
@@ -251,28 +242,40 @@ class NotificationDispatcher:
|
|||||||
|
|
||||||
url = entry["url"]
|
url = entry["url"]
|
||||||
headers = entry.get("headers") or {}
|
headers = entry.get("headers") or {}
|
||||||
|
try:
|
||||||
|
# Defense-in-depth: validate even though TelegramClient
|
||||||
|
# also validates. The dispatcher is what triggers the
|
||||||
|
# download, so the guard belongs here too.
|
||||||
|
await avalidate_outbound_url(url)
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Asset preload skipped: unsafe URL (%s)", redact_exc(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
async with sem:
|
async with sem:
|
||||||
try:
|
try:
|
||||||
async with session.get(url, headers=headers) as resp:
|
async with session.get(url, headers=headers) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
return
|
return
|
||||||
data = await resp.read()
|
data = await resp.read()
|
||||||
except aiohttp.ClientError:
|
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
|
||||||
return
|
return
|
||||||
if max_size is not None and len(data) > max_size:
|
if max_size is not None and len(data) > max_size:
|
||||||
return
|
return
|
||||||
entry["data"] = data
|
entry["data"] = data
|
||||||
media.extra["playback_size"] = len(data)
|
media.extra["playback_size"] = len(data)
|
||||||
|
|
||||||
await asyncio.gather(*(_fetch(e, m) for e, m in zip(assets, media_assets)))
|
raw = await asyncio.gather(
|
||||||
|
*(fetch(e, m) for e, m in zip(assets, media_assets)),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
for r in raw:
|
||||||
|
if isinstance(r, Exception):
|
||||||
|
_LOGGER.warning("Asset preload raised: %s", redact_exc(r))
|
||||||
|
|
||||||
def _cache_for_entry(
|
def _cache_for_entry(
|
||||||
self, entry: dict[str, Any],
|
self, entry: dict[str, Any],
|
||||||
) -> tuple[TelegramFileCache | None, str | None]:
|
) -> tuple[TelegramFileCache | None, str | None]:
|
||||||
"""Resolve (cache, key) for an asset entry — mirrors TelegramClient logic.
|
|
||||||
|
|
||||||
Returns (None, None) if no cache is configured or no key can be derived.
|
|
||||||
"""
|
|
||||||
cache_key = entry.get("cache_key")
|
cache_key = entry.get("cache_key")
|
||||||
if cache_key:
|
if cache_key:
|
||||||
cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache
|
cache = self._asset_cache if is_asset_cache_key(cache_key) else self._url_cache
|
||||||
@@ -287,6 +290,10 @@ class NotificationDispatcher:
|
|||||||
return self._url_cache, url
|
return self._url_cache, url
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Per-provider handlers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def _send_telegram(
|
async def _send_telegram(
|
||||||
self, target: TargetConfig, default_message: str, event: ServiceEvent
|
self, target: TargetConfig, default_message: str, event: ServiceEvent
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@@ -296,27 +303,25 @@ class NotificationDispatcher:
|
|||||||
max_media = target.config.get("max_media_to_send", 50)
|
max_media = target.config.get("max_media_to_send", 50)
|
||||||
max_group = target.config.get("max_media_per_group", 10)
|
max_group = target.config.get("max_media_per_group", 10)
|
||||||
chunk_delay = target.config.get("media_delay", 500)
|
chunk_delay = target.config.get("media_delay", 500)
|
||||||
max_size = target.config.get("max_asset_size")
|
max_size_mb = target.config.get("max_asset_size")
|
||||||
if max_size:
|
max_size_bytes = max_size_mb * 1024 * 1024 if max_size_mb else None
|
||||||
max_size = max_size * 1024 * 1024 # MB to bytes
|
|
||||||
send_large_as_docs = target.config.get("send_large_photos_as_documents", False)
|
send_large_as_docs = target.config.get("send_large_photos_as_documents", False)
|
||||||
|
|
||||||
if not bot_token:
|
if not bot_token:
|
||||||
return {"success": False, "error": "Missing bot_token"}
|
return {"success": False, "error": "Missing bot_token"}
|
||||||
|
|
||||||
if not target.receivers:
|
if not target.receivers:
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
# Prepare assets list once (shared across receivers)
|
|
||||||
# Prefer internal URL for fetching (LAN speed vs public internet)
|
|
||||||
internal_url = (target.provider_internal_url or "").rstrip("/")
|
internal_url = (target.provider_internal_url or "").rstrip("/")
|
||||||
external_url = (target.provider_external_url or "").rstrip("/")
|
external_url = (target.provider_external_url or "").rstrip("/")
|
||||||
assets = []
|
assets: list[dict[str, Any]] = []
|
||||||
media_assets: list[Any] = [] # aligned with `assets` for preload
|
media_assets: list[Any] = []
|
||||||
for asset in event.added_assets[:max_media]:
|
for asset in event.added_assets[:max_media]:
|
||||||
url = asset.preview_url or asset.thumbnail_url or asset.full_url
|
url = asset.preview_url or asset.thumbnail_url or asset.full_url
|
||||||
|
if not url:
|
||||||
|
continue
|
||||||
asset_entry = build_telegram_asset_entry(
|
asset_entry = build_telegram_asset_entry(
|
||||||
url=url or "",
|
url=url,
|
||||||
media_type=asset.type.value,
|
media_type=asset.type.value,
|
||||||
api_key=target.provider_api_key,
|
api_key=target.provider_api_key,
|
||||||
internal_url=internal_url,
|
internal_url=internal_url,
|
||||||
@@ -327,26 +332,15 @@ class NotificationDispatcher:
|
|||||||
assets.append(asset_entry)
|
assets.append(asset_entry)
|
||||||
media_assets.append(asset)
|
media_assets.append(asset)
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
# Preload all asset bytes once so (a) TelegramClient can skip its
|
await self._preload_asset_data(assets, media_assets, session, max_size_bytes)
|
||||||
# own download and (b) we know exact upload sizes in time for the
|
|
||||||
# oversize warning in the rendered text.
|
|
||||||
await self._preload_asset_data(assets, media_assets, session, max_size)
|
|
||||||
default_message = self._render_message(event, target, target.locale)
|
|
||||||
|
|
||||||
# Asset cache (when in thumbhash mode) invalidates entries when the
|
|
||||||
# asset's visual content changes. The resolver maps asset id → its
|
|
||||||
# current thumbhash. Providers that expose thumbhash put it in
|
|
||||||
# ``asset.extra["thumbhash"]`` (currently Immich).
|
|
||||||
thumbhash_map = {
|
thumbhash_map = {
|
||||||
asset.id: asset.extra.get("thumbhash")
|
asset.id: asset.extra.get("thumbhash")
|
||||||
for asset in event.added_assets
|
for asset in event.added_assets
|
||||||
if asset.extra.get("thumbhash")
|
if asset.extra.get("thumbhash")
|
||||||
}
|
}
|
||||||
thumbhash_resolver = (
|
thumbhash_resolver = thumbhash_map.get if thumbhash_map else None
|
||||||
(lambda key: thumbhash_map.get(key)) if thumbhash_map else None
|
|
||||||
)
|
|
||||||
|
|
||||||
client = TelegramClient(
|
client = TelegramClient(
|
||||||
session, bot_token,
|
session, bot_token,
|
||||||
@@ -355,39 +349,51 @@ class NotificationDispatcher:
|
|||||||
thumbhash_resolver=thumbhash_resolver,
|
thumbhash_resolver=thumbhash_resolver,
|
||||||
)
|
)
|
||||||
|
|
||||||
for receiver in target.receivers:
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
|
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
|
||||||
results.append({"success": False, "error": "Invalid telegram receiver"})
|
return {"success": False, "error": "Invalid telegram receiver"}
|
||||||
continue
|
|
||||||
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
|
|
||||||
text_result = await client.send_message(
|
text_result = await client.send_message(
|
||||||
chat_id=receiver.chat_id,
|
chat_id=receiver.chat_id,
|
||||||
text=message,
|
text=message,
|
||||||
disable_web_page_preview=bool(disable_preview),
|
disable_web_page_preview=bool(disable_preview),
|
||||||
)
|
)
|
||||||
if not text_result.get("success"):
|
if not text_result.get("success"):
|
||||||
_LOGGER.warning("Failed to send to chat %s: %s", receiver.chat_id, text_result.get("error"))
|
_LOGGER.warning(
|
||||||
results.append(text_result)
|
"Failed to send to chat %s: %s",
|
||||||
continue
|
receiver.chat_id, text_result.get("error"),
|
||||||
|
)
|
||||||
|
return text_result
|
||||||
|
|
||||||
if assets:
|
if assets:
|
||||||
reply_to = text_result.get("message_id")
|
|
||||||
media_result = await client.send_notification(
|
media_result = await client.send_notification(
|
||||||
chat_id=receiver.chat_id,
|
chat_id=receiver.chat_id,
|
||||||
assets=assets,
|
assets=assets,
|
||||||
reply_to_message_id=reply_to,
|
reply_to_message_id=text_result.get("message_id"),
|
||||||
max_group_size=max_group,
|
max_group_size=max_group,
|
||||||
chunk_delay=chunk_delay,
|
chunk_delay=chunk_delay,
|
||||||
max_asset_data_size=max_size,
|
max_asset_data_size=max_size_bytes,
|
||||||
send_large_photos_as_documents=send_large_as_docs,
|
send_large_photos_as_documents=send_large_as_docs,
|
||||||
chat_action=chat_action or None,
|
chat_action=chat_action or None,
|
||||||
)
|
)
|
||||||
if not media_result.get("success"):
|
if not media_result.get("success"):
|
||||||
_LOGGER.warning("Text sent OK but media failed for chat %s: %s", receiver.chat_id, media_result.get("error"))
|
_LOGGER.warning(
|
||||||
|
"Text sent OK but media failed for chat %s: %s",
|
||||||
|
receiver.chat_id, media_result.get("error"),
|
||||||
|
)
|
||||||
|
# Preserve both outcomes — text succeeded, media
|
||||||
|
# didn't. Operators losing media-failure detail
|
||||||
|
# in the result dict made root-cause analysis
|
||||||
|
# impossible.
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message_id": text_result.get("message_id"),
|
||||||
|
"media_error": media_result.get("error"),
|
||||||
|
"media_failed_at_chunk": media_result.get("failed_at_chunk"),
|
||||||
|
}
|
||||||
|
return text_result
|
||||||
|
|
||||||
results.append(text_result)
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
@@ -397,17 +403,10 @@ class NotificationDispatcher:
|
|||||||
if not target.receivers:
|
if not target.receivers:
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
for receiver in target.receivers:
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
||||||
results.append({"success": False, "error": "Invalid webhook receiver"})
|
return {"success": False, "error": "Invalid webhook receiver"}
|
||||||
continue
|
|
||||||
try:
|
|
||||||
await avalidate_outbound_url(receiver.url)
|
|
||||||
except UnsafeURLError as err:
|
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
payload = {
|
payload = {
|
||||||
"message": message,
|
"message": message,
|
||||||
@@ -417,8 +416,10 @@ class NotificationDispatcher:
|
|||||||
"collection_id": event.collection_id,
|
"collection_id": event.collection_id,
|
||||||
"timestamp": event.timestamp.isoformat(),
|
"timestamp": event.timestamp.isoformat(),
|
||||||
}
|
}
|
||||||
client = WebhookClient(session, receiver.url, receiver.headers)
|
client = WebhookClient(session, receiver.url, safe_headers(receiver.headers))
|
||||||
results.append(await client.send(payload))
|
return await client.send(payload)
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
@@ -431,7 +432,7 @@ class NotificationDispatcher:
|
|||||||
if not smtp_cfg.get("host"):
|
if not smtp_cfg.get("host"):
|
||||||
return {"success": False, "error": "SMTP not configured"}
|
return {"success": False, "error": "SMTP not configured"}
|
||||||
|
|
||||||
client = EmailClient(SmtpConfig(
|
email_client = EmailClient(SmtpConfig(
|
||||||
host=smtp_cfg["host"],
|
host=smtp_cfg["host"],
|
||||||
port=int(smtp_cfg.get("port", 587)),
|
port=int(smtp_cfg.get("port", 587)),
|
||||||
username=smtp_cfg.get("username", ""),
|
username=smtp_cfg.get("username", ""),
|
||||||
@@ -439,27 +440,28 @@ class NotificationDispatcher:
|
|||||||
from_address=smtp_cfg.get("from_address", ""),
|
from_address=smtp_cfg.get("from_address", ""),
|
||||||
from_name=smtp_cfg.get("from_name", "Notify Bridge"),
|
from_name=smtp_cfg.get("from_name", "Notify Bridge"),
|
||||||
use_tls=smtp_cfg.get("use_tls", True),
|
use_tls=smtp_cfg.get("use_tls", True),
|
||||||
|
tls_mode=smtp_cfg.get("tls_mode", "auto"),
|
||||||
))
|
))
|
||||||
|
|
||||||
if not target.receivers:
|
if not target.receivers:
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}"
|
subject = f"[Notify Bridge] {event.event_type.value}: {event.collection_name}"
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
for receiver in target.receivers:
|
|
||||||
if not isinstance(receiver, EmailReceiver) or not receiver.email:
|
if not isinstance(receiver, EmailReceiver) or not receiver.email:
|
||||||
results.append({"success": False, "error": "Invalid email receiver"})
|
return {"success": False, "error": "Invalid email receiver"}
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
result = await client.send(
|
# body_html=None lets EmailClient build a safely-escaped HTML
|
||||||
|
# alternative from body_text instead of trusting user content.
|
||||||
|
return await email_client.send(
|
||||||
to_email=receiver.email,
|
to_email=receiver.email,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
body_text=message,
|
body_text=message,
|
||||||
body_html=message,
|
body_html=None,
|
||||||
to_name=receiver.name,
|
to_name=receiver.name,
|
||||||
)
|
)
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
async def _send_discord(
|
async def _send_discord(
|
||||||
@@ -471,20 +473,16 @@ class NotificationDispatcher:
|
|||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
username = target.config.get("username")
|
username = target.config.get("username")
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
client = DiscordClient(session)
|
client = DiscordClient(session)
|
||||||
for receiver in target.receivers:
|
|
||||||
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
||||||
results.append({"success": False, "error": "Invalid discord receiver"})
|
return {"success": False, "error": "Invalid discord receiver"}
|
||||||
continue
|
|
||||||
try:
|
|
||||||
await avalidate_outbound_url(receiver.webhook_url)
|
|
||||||
except UnsafeURLError as err:
|
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
results.append(await client.send(receiver.webhook_url, message, username=username))
|
return await client.send(receiver.webhook_url, message, username=username)
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
@@ -497,20 +495,16 @@ class NotificationDispatcher:
|
|||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
username = target.config.get("username")
|
username = target.config.get("username")
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
client = SlackClient(session)
|
client = SlackClient(session)
|
||||||
for receiver in target.receivers:
|
|
||||||
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
||||||
results.append({"success": False, "error": "Invalid slack receiver"})
|
return {"success": False, "error": "Invalid slack receiver"}
|
||||||
continue
|
|
||||||
try:
|
|
||||||
await avalidate_outbound_url(receiver.webhook_url)
|
|
||||||
except UnsafeURLError as err:
|
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
results.append(await client.send(receiver.webhook_url, message, username=username))
|
return await client.send(receiver.webhook_url, message, username=username)
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
@@ -526,22 +520,23 @@ class NotificationDispatcher:
|
|||||||
try:
|
try:
|
||||||
await avalidate_outbound_url(server_url)
|
await avalidate_outbound_url(server_url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
return {"success": False, "error": f"Unsafe ntfy server_url: {redact_exc(err)}"}
|
||||||
|
|
||||||
title = f"{event.event_type.value}: {event.collection_name}"
|
title = f"{event.event_type.value}: {event.collection_name}"
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
client = NtfyClient(session)
|
client = NtfyClient(session)
|
||||||
for receiver in target.receivers:
|
|
||||||
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
||||||
results.append({"success": False, "error": "Invalid ntfy receiver"})
|
return {"success": False, "error": "Invalid ntfy receiver"}
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
results.append(await client.send(
|
return await client.send(
|
||||||
server_url, receiver.topic, message,
|
server_url, receiver.topic, message,
|
||||||
title=title, priority=receiver.priority, auth_token=auth_token,
|
title=title, priority=receiver.priority, auth_token=auth_token,
|
||||||
))
|
)
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
@@ -557,33 +552,108 @@ class NotificationDispatcher:
|
|||||||
try:
|
try:
|
||||||
await avalidate_outbound_url(homeserver)
|
await avalidate_outbound_url(homeserver)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
return {"success": False, "error": f"Unsafe matrix homeserver_url: {redact_exc(err)}"}
|
||||||
|
|
||||||
if not target.receivers:
|
if not target.receivers:
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
async with self._session_ctx() as session:
|
async with self._session_ctx() as session:
|
||||||
client = MatrixClient(session, homeserver, access_token)
|
client = MatrixClient(session, homeserver, access_token)
|
||||||
for receiver in target.receivers:
|
|
||||||
|
async def send_one(receiver: Receiver) -> dict[str, Any]:
|
||||||
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
||||||
results.append({"success": False, "error": "Invalid matrix receiver"})
|
return {"success": False, "error": "Invalid matrix receiver"}
|
||||||
continue
|
|
||||||
message = self._message_for_receiver(receiver, default_message, event, target)
|
message = self._message_for_receiver(receiver, default_message, event, target)
|
||||||
results.append(await client.send_message(
|
# body_html is the same plain text — Matrix accepts the
|
||||||
receiver.room_id, message, html_message=message,
|
# raw message as both ``body`` and ``formatted_body``.
|
||||||
))
|
# If templates emit HTML in the future, generate a
|
||||||
|
# separate HTML body upstream rather than aliasing here.
|
||||||
|
return await client.send_message(
|
||||||
|
receiver.room_id, message, html_message=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await self._fan_out(target.receivers, send_one)
|
||||||
|
|
||||||
return self._aggregate_results(results)
|
return self._aggregate_results(results)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Aggregation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _fan_out(
|
||||||
|
receivers: list[Receiver],
|
||||||
|
send_one: Callable[[Receiver], Awaitable[dict[str, Any]]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Run ``send_one`` per receiver with bounded concurrency.
|
||||||
|
|
||||||
|
Per-receiver exceptions are converted to failure dicts so a single
|
||||||
|
bad receiver can't cancel its peers.
|
||||||
|
"""
|
||||||
|
sem = asyncio.Semaphore(_RECEIVER_CONCURRENCY)
|
||||||
|
|
||||||
|
async def guarded(receiver: Receiver) -> dict[str, Any]:
|
||||||
|
async with sem:
|
||||||
|
try:
|
||||||
|
return await send_one(receiver)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
_LOGGER.error("Receiver send raised: %s", redact_exc(exc))
|
||||||
|
return {"success": False, "error": redact_exc(exc)}
|
||||||
|
|
||||||
|
return await asyncio.gather(*(guarded(r) for r in receivers))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
|
def _aggregate_results(results: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
"""Aggregate broadcast results into a single result dict."""
|
"""Aggregate per-receiver results into a single target-level result.
|
||||||
successes = sum(1 for r in results if r.get("success"))
|
|
||||||
if successes == len(results) and results:
|
Preserves the per-receiver detail under ``receivers`` so a caller
|
||||||
return {"success": True, "receivers": len(results)}
|
can see exactly which receivers failed, instead of getting only
|
||||||
elif successes > 0:
|
the first error.
|
||||||
return {"success": True, "receivers": len(results), "partial_failures": len(results) - successes}
|
"""
|
||||||
elif results:
|
if not results:
|
||||||
return results[0]
|
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
|
successes = sum(1 for r in results if r.get("success"))
|
||||||
|
failures = len(results) - successes
|
||||||
|
out: dict[str, Any] = {
|
||||||
|
"success": successes > 0,
|
||||||
|
"receivers": len(results),
|
||||||
|
"successes": successes,
|
||||||
|
"failures": failures,
|
||||||
|
"results": results,
|
||||||
|
}
|
||||||
|
if failures:
|
||||||
|
out["errors"] = [
|
||||||
|
r.get("error") for r in results if not r.get("success")
|
||||||
|
]
|
||||||
|
if successes == 0:
|
||||||
|
# Surface the first error at the top level for back-compat
|
||||||
|
# with callers that only check ``error``.
|
||||||
|
out["error"] = results[0].get("error", "All receivers failed")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# Provider registry — replaces the if/elif chain so adding a provider
|
||||||
|
# means just registering it here, not editing dispatch logic.
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
_PROVIDER_HANDLERS: dict[str, _SendMethod] = {
|
||||||
|
"telegram": NotificationDispatcher._send_telegram,
|
||||||
|
"webhook": NotificationDispatcher._send_webhook,
|
||||||
|
"email": NotificationDispatcher._send_email,
|
||||||
|
"discord": NotificationDispatcher._send_discord,
|
||||||
|
"slack": NotificationDispatcher._send_slack,
|
||||||
|
"ntfy": NotificationDispatcher._send_ntfy,
|
||||||
|
"matrix": NotificationDispatcher._send_matrix,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_provider(name: str, handler: _SendMethod) -> None:
|
||||||
|
"""Register a new dispatcher provider at runtime.
|
||||||
|
|
||||||
|
Allows out-of-tree providers to extend the dispatcher without
|
||||||
|
forking. The handler must follow the
|
||||||
|
``async (dispatcher, target, default_message, event) -> dict`` shape.
|
||||||
|
"""
|
||||||
|
_PROVIDER_HANDLERS[name] = handler
|
||||||
|
|||||||
@@ -2,14 +2,32 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import html
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import ssl
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.headerregistry import Address
|
||||||
from email.mime.text import MIMEText
|
from email.message import EmailMessage
|
||||||
from typing import Any
|
from typing import Any, Final, Literal
|
||||||
|
|
||||||
|
try: # Optional dependency — fail at first send rather than at import.
|
||||||
|
import aiosmtplib
|
||||||
|
from aiosmtplib import SMTPException
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
aiosmtplib = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
class SMTPException(Exception): # type: ignore[no-redef]
|
||||||
|
pass
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_TIMEOUT_S: Final = 30.0
|
||||||
|
_TlsMode = Literal["auto", "implicit", "starttls", "none"]
|
||||||
|
# RFC 5322 lite: catches the obvious-bad addresses ("foo bar", "no-at",
|
||||||
|
# embedded CRLF) without pretending to fully validate addresses.
|
||||||
|
_EMAIL_RE: Final = re.compile(r"^[^\s@\r\n,;<>]+@[^\s@\r\n,;<>]+\.[^\s@\r\n,;<>]+$")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SmtpConfig:
|
class SmtpConfig:
|
||||||
@@ -22,6 +40,55 @@ class SmtpConfig:
|
|||||||
from_address: str = ""
|
from_address: str = ""
|
||||||
from_name: str = "Notify Bridge"
|
from_name: str = "Notify Bridge"
|
||||||
use_tls: bool = True
|
use_tls: bool = True
|
||||||
|
# Explicit TLS mode. ``auto`` (back-compat) infers from ``use_tls`` and
|
||||||
|
# ``port``: 465 → implicit; 587 with use_tls=False → starttls; 25 → none.
|
||||||
|
tls_mode: _TlsMode = "auto"
|
||||||
|
timeout_s: float = _DEFAULT_TIMEOUT_S
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_header(value: str) -> str:
|
||||||
|
"""Reject CRLF and bare CR/LF in header-bound strings.
|
||||||
|
|
||||||
|
SMTP header injection turns user-controlled subject/name strings into
|
||||||
|
arbitrary headers (``\\r\\nBcc: attacker@x``). The Python stdlib
|
||||||
|
accepts CRLF when followed by SP/HT (header folding), so explicit
|
||||||
|
sanitization is required even though :class:`EmailMessage` does some
|
||||||
|
validation of its own.
|
||||||
|
"""
|
||||||
|
return re.sub(r"[\r\n]+", " ", str(value or "")).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_email(addr: str) -> str:
|
||||||
|
addr = _strip_header(addr)
|
||||||
|
if not addr:
|
||||||
|
raise ValueError("email address is empty")
|
||||||
|
if not _EMAIL_RE.match(addr):
|
||||||
|
raise ValueError("email address is invalid")
|
||||||
|
return addr
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tls(cfg: SmtpConfig) -> tuple[bool, bool]:
|
||||||
|
"""Resolve ``(use_tls, start_tls)`` flags from the config.
|
||||||
|
|
||||||
|
``tls_mode`` overrides ``use_tls``/port heuristics when provided.
|
||||||
|
"""
|
||||||
|
mode = cfg.tls_mode
|
||||||
|
if mode == "implicit":
|
||||||
|
return True, False
|
||||||
|
if mode == "starttls":
|
||||||
|
return False, True
|
||||||
|
if mode == "none":
|
||||||
|
return False, False
|
||||||
|
# auto — preserve the historical "use_tls bool + port heuristic" behavior
|
||||||
|
# but make the path explicit.
|
||||||
|
if cfg.use_tls:
|
||||||
|
return True, False
|
||||||
|
return False, cfg.port != 25
|
||||||
|
|
||||||
|
|
||||||
|
def _to_html(text: str) -> str:
|
||||||
|
"""Convert plain text to a minimal HTML body, escaped for safety."""
|
||||||
|
return "<html><body><pre>" + html.escape(text) + "</pre></body></html>"
|
||||||
|
|
||||||
|
|
||||||
class EmailClient:
|
class EmailClient:
|
||||||
@@ -30,30 +97,39 @@ class EmailClient:
|
|||||||
def __init__(self, smtp_config: SmtpConfig) -> None:
|
def __init__(self, smtp_config: SmtpConfig) -> None:
|
||||||
self._config = smtp_config
|
self._config = smtp_config
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ssl_context() -> ssl.SSLContext:
|
||||||
|
# Explicit context so the TLS posture is auditable; aiosmtplib
|
||||||
|
# defaults look correct today but past regressions (and downstream
|
||||||
|
# repackaging) make implicit reliance fragile.
|
||||||
|
return ssl.create_default_context()
|
||||||
|
|
||||||
async def verify_connection(self) -> dict[str, Any]:
|
async def verify_connection(self) -> dict[str, Any]:
|
||||||
"""Test SMTP connection and authentication without sending an email."""
|
"""Test SMTP connection and authentication without sending an email."""
|
||||||
try:
|
if aiosmtplib is None:
|
||||||
import aiosmtplib
|
|
||||||
except ImportError:
|
|
||||||
return {"success": False, "error": "aiosmtplib not installed"}
|
return {"success": False, "error": "aiosmtplib not installed"}
|
||||||
|
|
||||||
cfg = self._config
|
cfg = self._config
|
||||||
if not cfg.host:
|
if not cfg.host:
|
||||||
return {"success": False, "error": "SMTP host not configured"}
|
return {"success": False, "error": "SMTP host not configured"}
|
||||||
|
|
||||||
|
use_tls, start_tls = _resolve_tls(cfg)
|
||||||
try:
|
try:
|
||||||
smtp = aiosmtplib.SMTP(
|
smtp = aiosmtplib.SMTP(
|
||||||
hostname=cfg.host,
|
hostname=cfg.host,
|
||||||
port=cfg.port,
|
port=cfg.port,
|
||||||
use_tls=cfg.use_tls,
|
use_tls=use_tls,
|
||||||
start_tls=not cfg.use_tls and cfg.port != 25,
|
start_tls=start_tls,
|
||||||
|
tls_context=self._ssl_context(),
|
||||||
|
timeout=cfg.timeout_s,
|
||||||
|
validate_certs=True,
|
||||||
)
|
)
|
||||||
await smtp.connect()
|
await smtp.connect()
|
||||||
if cfg.username and cfg.password:
|
if cfg.username and cfg.password:
|
||||||
await smtp.login(cfg.username, cfg.password)
|
await smtp.login(cfg.username, cfg.password)
|
||||||
await smtp.quit()
|
await smtp.quit()
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
except Exception as e:
|
except (SMTPException, OSError) as e:
|
||||||
_LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e)
|
_LOGGER.warning("SMTP verification failed for %s:%d: %s", cfg.host, cfg.port, e)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
@@ -65,27 +141,52 @@ class EmailClient:
|
|||||||
body_html: str | None = None,
|
body_html: str | None = None,
|
||||||
to_name: str = "",
|
to_name: str = "",
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Send an email. Returns {"success": True} or {"success": False, "error": "..."}."""
|
"""Send an email.
|
||||||
try:
|
|
||||||
import aiosmtplib
|
Returns ``{"success": True}`` or ``{"success": False, "error": "..."}``.
|
||||||
except ImportError:
|
|
||||||
|
``body_html`` is treated as already-safe markup. Pass ``None`` to
|
||||||
|
derive a safe HTML alternative from ``body_text`` automatically.
|
||||||
|
"""
|
||||||
|
if aiosmtplib is None:
|
||||||
return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"}
|
return {"success": False, "error": "aiosmtplib not installed. Run: pip install aiosmtplib"}
|
||||||
|
|
||||||
cfg = self._config
|
cfg = self._config
|
||||||
|
|
||||||
if not cfg.host or not cfg.from_address:
|
if not cfg.host or not cfg.from_address:
|
||||||
return {"success": False, "error": "SMTP not configured (missing host or from_address)"}
|
return {"success": False, "error": "SMTP not configured (missing host or from_address)"}
|
||||||
|
|
||||||
# Build email message
|
try:
|
||||||
msg = MIMEMultipart("alternative")
|
to_addr = _validate_email(to_email)
|
||||||
msg["From"] = f"{cfg.from_name} <{cfg.from_address}>" if cfg.from_name else cfg.from_address
|
from_addr = _validate_email(cfg.from_address)
|
||||||
msg["To"] = f"{to_name} <{to_email}>" if to_name else to_email
|
except ValueError as exc:
|
||||||
msg["Subject"] = subject
|
return {"success": False, "error": f"Invalid email address: {exc}"}
|
||||||
|
|
||||||
msg.attach(MIMEText(body_text, "plain", "utf-8"))
|
# EmailMessage with structured Address objects rejects CRLF and
|
||||||
if body_html:
|
# framework-folds long headers safely. We still strip first because
|
||||||
msg.attach(MIMEText(body_html, "html", "utf-8"))
|
# EmailMessage's display-name slot is a pure string.
|
||||||
|
msg = EmailMessage()
|
||||||
|
from_display = _strip_header(cfg.from_name) or ""
|
||||||
|
to_display = _strip_header(to_name) or ""
|
||||||
|
try:
|
||||||
|
from_user, _, from_domain = from_addr.partition("@")
|
||||||
|
to_user, _, to_domain = to_addr.partition("@")
|
||||||
|
msg["From"] = Address(from_display, from_user, from_domain) if from_display else from_addr
|
||||||
|
msg["To"] = Address(to_display, to_user, to_domain) if to_display else to_addr
|
||||||
|
except ValueError as exc:
|
||||||
|
return {"success": False, "error": f"Invalid email address: {exc}"}
|
||||||
|
msg["Subject"] = _strip_header(subject)
|
||||||
|
|
||||||
|
msg.set_content(body_text or "", subtype="plain", charset="utf-8")
|
||||||
|
# If the caller provided HTML explicitly, honor it; otherwise build a
|
||||||
|
# safe escaped version so a stray "<" in the rendered template can't
|
||||||
|
# break the markup.
|
||||||
|
msg.add_alternative(
|
||||||
|
body_html if body_html is not None else _to_html(body_text or ""),
|
||||||
|
subtype="html",
|
||||||
|
charset="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
use_tls, start_tls = _resolve_tls(cfg)
|
||||||
try:
|
try:
|
||||||
await aiosmtplib.send(
|
await aiosmtplib.send(
|
||||||
msg,
|
msg,
|
||||||
@@ -93,11 +194,14 @@ class EmailClient:
|
|||||||
port=cfg.port,
|
port=cfg.port,
|
||||||
username=cfg.username or None,
|
username=cfg.username or None,
|
||||||
password=cfg.password or None,
|
password=cfg.password or None,
|
||||||
use_tls=cfg.use_tls,
|
use_tls=use_tls,
|
||||||
start_tls=not cfg.use_tls and cfg.port != 25,
|
start_tls=start_tls,
|
||||||
|
tls_context=self._ssl_context(),
|
||||||
|
timeout=cfg.timeout_s,
|
||||||
|
validate_certs=True,
|
||||||
)
|
)
|
||||||
_LOGGER.info("Email sent to %s", to_email)
|
_LOGGER.info("Email sent to %s", to_addr)
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
except Exception as e:
|
except (SMTPException, OSError) as e:
|
||||||
_LOGGER.error("Failed to send email to %s: %s", to_email, e)
|
_LOGGER.error("Failed to send email to %s: %s", to_addr, e)
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
|||||||
@@ -0,0 +1,196 @@
|
|||||||
|
"""Shared HTTP infrastructure for notification provider clients.
|
||||||
|
|
||||||
|
Slack/Discord/ntfy/Matrix/Webhook all follow the same pattern: build a
|
||||||
|
JSON payload, POST/PUT it, decode 200-range as success, decode 4xx/5xx
|
||||||
|
into a stable error dict, and retry transient 429/503 responses with a
|
||||||
|
capped ``Retry-After``. ``HttpProviderClient`` centralizes that pattern
|
||||||
|
so every provider gets the same SSRF guard, timeouts, secret-redacted
|
||||||
|
errors, and bounded retry policy by construction — adding a new
|
||||||
|
provider doesn't get to forget any one of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Final, Mapping
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from .redact import redact, redact_exc
|
||||||
|
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
_MAX_RETRIES: Final = 3
|
||||||
|
_MAX_RETRY_AFTER_S: Final = 60.0
|
||||||
|
_RETRY_STATUSES: Final = frozenset({429, 503})
|
||||||
|
|
||||||
|
# Hop-by-hop / framing headers a caller must not be able to override via
|
||||||
|
# user-supplied target config. Letting them through enables request
|
||||||
|
# smuggling, host-header bypasses of WAFs, and cache poisoning.
|
||||||
|
_FORBIDDEN_HEADERS: Final = frozenset({
|
||||||
|
"host",
|
||||||
|
"content-length",
|
||||||
|
"transfer-encoding",
|
||||||
|
"connection",
|
||||||
|
"keep-alive",
|
||||||
|
"te",
|
||||||
|
"upgrade",
|
||||||
|
"expect",
|
||||||
|
"proxy-authorization",
|
||||||
|
"proxy-connection",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def safe_headers(headers: Mapping[str, str] | None) -> dict[str, str]:
|
||||||
|
"""Return a copy of ``headers`` with hop-by-hop/forbidden names dropped
|
||||||
|
and CRLF-bearing values rejected.
|
||||||
|
|
||||||
|
A target config that lets a user inject ``"X-Foo": "bar\\r\\nHost: evil"``
|
||||||
|
can perform request smuggling depending on aiohttp's framing. We strip
|
||||||
|
those values rather than letting them reach the wire.
|
||||||
|
"""
|
||||||
|
if not headers:
|
||||||
|
return {}
|
||||||
|
safe: dict[str, str] = {}
|
||||||
|
for raw_name, raw_value in headers.items():
|
||||||
|
name = str(raw_name).strip()
|
||||||
|
if not name or name.lower() in _FORBIDDEN_HEADERS:
|
||||||
|
continue
|
||||||
|
if any(c in name for c in "\r\n:"):
|
||||||
|
continue
|
||||||
|
value = str(raw_value)
|
||||||
|
if "\r" in value or "\n" in value:
|
||||||
|
continue
|
||||||
|
safe[name] = value
|
||||||
|
return safe
|
||||||
|
|
||||||
|
|
||||||
|
def make_error(message: str, *, status: int | None = None, body: str | None = None) -> dict[str, Any]:
|
||||||
|
"""Build a stable failure dict shape used by every provider client."""
|
||||||
|
err: dict[str, Any] = {"success": False, "error": redact(message)}
|
||||||
|
if status is not None:
|
||||||
|
err["status_code"] = status
|
||||||
|
if body:
|
||||||
|
err["body"] = redact(body)[:200]
|
||||||
|
return err
|
||||||
|
|
||||||
|
|
||||||
|
def make_success(**extra: Any) -> dict[str, Any]:
|
||||||
|
"""Build a stable success dict shape used by every provider client."""
|
||||||
|
out: dict[str, Any] = {"success": True}
|
||||||
|
out.update(extra)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_after_seconds(headers: Mapping[str, str], cap_s: float) -> float:
|
||||||
|
raw = headers.get("Retry-After") or headers.get("retry-after") or "2"
|
||||||
|
try:
|
||||||
|
seconds = float(raw)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
seconds = 2.0
|
||||||
|
return max(0.0, min(seconds, cap_s))
|
||||||
|
|
||||||
|
|
||||||
|
class HttpProviderClient:
|
||||||
|
"""Base for JSON-over-HTTP notification providers.
|
||||||
|
|
||||||
|
Subclasses call :meth:`request` instead of using ``self._session``
|
||||||
|
directly. ``request`` runs the SSRF guard (skippable for known-safe
|
||||||
|
upstreams via ``ssrf_validate=False``), enforces a per-request
|
||||||
|
timeout, retries 429/503 with a capped ``Retry-After``, and turns
|
||||||
|
transport/HTTP errors into the canonical ``{"success": False, ...}``
|
||||||
|
shape with secrets redacted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_max_retries: int = _MAX_RETRIES
|
||||||
|
# Settable per-instance so tests / hostile-upstream tuning can
|
||||||
|
# tighten the cap. Reads of this attribute fall through to the
|
||||||
|
# class default when no instance value has been set.
|
||||||
|
_MAX_RETRY_AFTER: float = _MAX_RETRY_AFTER_S
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
*,
|
||||||
|
timeout: aiohttp.ClientTimeout | None = None,
|
||||||
|
provider_name: str = "http",
|
||||||
|
) -> None:
|
||||||
|
self._session = session
|
||||||
|
self._timeout = timeout or _DEFAULT_TIMEOUT
|
||||||
|
self._provider = provider_name
|
||||||
|
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
json: Any = None,
|
||||||
|
headers: Mapping[str, str] | None = None,
|
||||||
|
ssrf_validate: bool = True,
|
||||||
|
retry_statuses: frozenset[int] = _RETRY_STATUSES,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a single request with retry + redaction. Always returns a dict.
|
||||||
|
|
||||||
|
On 2xx returns ``{"success": True, "status_code": int, "json": ...
|
||||||
|
OR "body": str}``. On non-2xx returns the canonical error dict.
|
||||||
|
"""
|
||||||
|
if ssrf_validate:
|
||||||
|
try:
|
||||||
|
await avalidate_outbound_url(url)
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
return make_error(f"Unsafe URL: {redact_exc(err)}")
|
||||||
|
|
||||||
|
outbound_headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||||
|
outbound_headers.update(safe_headers(headers))
|
||||||
|
|
||||||
|
for attempt in range(1, self._max_retries + 1):
|
||||||
|
try:
|
||||||
|
async with self._session.request(
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
json=json,
|
||||||
|
headers=outbound_headers,
|
||||||
|
timeout=self._timeout,
|
||||||
|
allow_redirects=False,
|
||||||
|
) as resp:
|
||||||
|
if resp.status in retry_statuses and attempt < self._max_retries:
|
||||||
|
delay = _retry_after_seconds(resp.headers, self._MAX_RETRY_AFTER)
|
||||||
|
_LOGGER.warning(
|
||||||
|
"%s %s %s: HTTP %d, retrying after %.2fs (attempt %d/%d)",
|
||||||
|
self._provider, method, redact(url), resp.status,
|
||||||
|
delay, attempt, self._max_retries,
|
||||||
|
)
|
||||||
|
await resp.read() # drain body so connection can return to pool
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
continue
|
||||||
|
if 200 <= resp.status < 300:
|
||||||
|
try:
|
||||||
|
payload: Any = await resp.json(content_type=None)
|
||||||
|
except (aiohttp.ContentTypeError, ValueError):
|
||||||
|
payload = await resp.text()
|
||||||
|
return make_success(status_code=resp.status, json=payload)
|
||||||
|
body = await resp.text()
|
||||||
|
return make_error(
|
||||||
|
f"HTTP {resp.status}",
|
||||||
|
status=resp.status,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
|
||||||
|
# asyncio.CancelledError inherits from BaseException on
|
||||||
|
# 3.8+, so it is not caught here — good: cancellation
|
||||||
|
# must propagate.
|
||||||
|
if attempt < self._max_retries and isinstance(err, asyncio.TimeoutError):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"%s %s %s: timeout, retrying (attempt %d/%d)",
|
||||||
|
self._provider, method, redact(url),
|
||||||
|
attempt, self._max_retries,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(min(2 ** (attempt - 1), 5))
|
||||||
|
continue
|
||||||
|
return make_error(redact_exc(err))
|
||||||
|
|
||||||
|
# Retry budget exhausted on a retriable status.
|
||||||
|
return make_error("Rate limited (retries exhausted)")
|
||||||
@@ -2,22 +2,36 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import re
|
||||||
from typing import Any
|
import uuid
|
||||||
|
from typing import Any, Final
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from ..http_base import _MAX_RETRY_AFTER_S, safe_headers
|
||||||
|
from ..redact import redact, redact_exc
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Monotonically increasing transaction counter for idempotent sends
|
# Matrix room IDs are ``!opaque:server.name`` per the spec. We also allow
|
||||||
_txn_counter = int(time.time() * 1000)
|
# the ``#alias:server`` form because some callers may pass aliases. The
|
||||||
|
# pattern's purpose is to reject obvious path-injection (``/``, ``..``,
|
||||||
|
# control chars, query/fragment chars) before the value reaches a URL.
|
||||||
|
_ROOM_ID_RE: Final = re.compile(r"^[!#][^\x00-\x1f\s/?#]{1,255}:[A-Za-z0-9.\-:]{1,255}$")
|
||||||
|
|
||||||
|
_DEFAULT_TIMEOUT: Final = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
_MAX_RETRIES: Final = 3
|
||||||
|
|
||||||
|
|
||||||
def _next_txn_id() -> str:
|
def _validate_room_id(room_id: str) -> str:
|
||||||
global _txn_counter
|
if not room_id:
|
||||||
_txn_counter += 1
|
raise ValueError("room_id is empty")
|
||||||
return str(_txn_counter)
|
if not _ROOM_ID_RE.match(room_id):
|
||||||
|
raise ValueError("room_id format is invalid")
|
||||||
|
return room_id
|
||||||
|
|
||||||
|
|
||||||
class MatrixClient:
|
class MatrixClient:
|
||||||
@@ -33,49 +47,67 @@ class MatrixClient:
|
|||||||
self._homeserver = homeserver_url.rstrip("/")
|
self._homeserver = homeserver_url.rstrip("/")
|
||||||
self._token = access_token
|
self._token = access_token
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _txn_id() -> str:
|
||||||
|
# uuid4 hex is collision-resistant across processes/restarts;
|
||||||
|
# eliminates the previous module-level counter race.
|
||||||
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
html_message: str | None = None,
|
html_message: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Send a text message to a Matrix room.
|
"""Send a text message to a Matrix room."""
|
||||||
|
try:
|
||||||
|
room_id = _validate_room_id(room_id)
|
||||||
|
except ValueError as exc:
|
||||||
|
return {"success": False, "error": f"Invalid room_id: {exc}"}
|
||||||
|
|
||||||
Args:
|
encoded_room = quote(room_id, safe="")
|
||||||
room_id: Internal room ID (e.g. !abc:matrix.org)
|
url = (
|
||||||
message: Plain text body
|
f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}"
|
||||||
html_message: Optional HTML-formatted body
|
f"/send/m.room.message/{self._txn_id()}"
|
||||||
"""
|
)
|
||||||
if not room_id:
|
|
||||||
return {"success": False, "error": "Missing room_id"}
|
|
||||||
|
|
||||||
txn_id = _next_txn_id()
|
body: dict[str, Any] = {"msgtype": "m.text", "body": message}
|
||||||
# URL-encode the room_id (! and : need encoding)
|
|
||||||
encoded_room = room_id.replace("!", "%21").replace(":", "%3A")
|
|
||||||
url = f"{self._homeserver}/_matrix/client/v3/rooms/{encoded_room}/send/m.room.message/{txn_id}"
|
|
||||||
|
|
||||||
body: dict[str, Any] = {
|
|
||||||
"msgtype": "m.text",
|
|
||||||
"body": message,
|
|
||||||
}
|
|
||||||
if html_message:
|
if html_message:
|
||||||
body["format"] = "org.matrix.custom.html"
|
body["format"] = "org.matrix.custom.html"
|
||||||
body["formatted_body"] = html_message
|
body["formatted_body"] = html_message
|
||||||
|
|
||||||
headers = {
|
headers = safe_headers({
|
||||||
"Authorization": f"Bearer {self._token}",
|
"Authorization": f"Bearer {self._token}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
}
|
})
|
||||||
|
|
||||||
|
for attempt in range(1, _MAX_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
async with self._session.put(
|
async with self._session.put(
|
||||||
url, json=body, headers=headers, allow_redirects=False,
|
url, json=body, headers=headers,
|
||||||
|
timeout=_DEFAULT_TIMEOUT, allow_redirects=False,
|
||||||
) as resp:
|
) as resp:
|
||||||
if 200 <= resp.status < 300:
|
if 200 <= resp.status < 300:
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
resp_body = await resp.text()
|
resp_body = await resp.text()
|
||||||
if resp.status == 429:
|
if resp.status == 429 and attempt < _MAX_RETRIES:
|
||||||
_LOGGER.warning("Matrix rate limited: %s", resp_body[:200])
|
try:
|
||||||
return {"success": False, "error": f"HTTP {resp.status}: {resp_body[:200]}"}
|
wait_s = float(resp.headers.get("Retry-After", "2"))
|
||||||
except aiohttp.ClientError as e:
|
except (TypeError, ValueError):
|
||||||
return {"success": False, "error": str(e)}
|
wait_s = 2.0
|
||||||
|
wait_s = max(0.0, min(wait_s, _MAX_RETRY_AFTER_S))
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Matrix rate limited, retrying after %.2fs (attempt %d/%d)",
|
||||||
|
wait_s, attempt, _MAX_RETRIES,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait_s)
|
||||||
|
continue
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"HTTP {resp.status}: {redact(resp_body)[:200]}",
|
||||||
|
"status_code": resp.status,
|
||||||
|
}
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e:
|
||||||
|
return {"success": False, "error": redact_exc(e)}
|
||||||
|
|
||||||
|
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||||
|
|||||||
@@ -3,18 +3,33 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from ..http_base import HttpProviderClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PRIORITY_MIN: Final = 1
|
||||||
|
_PRIORITY_MAX: Final = 5
|
||||||
|
_DEFAULT_PRIORITY: Final = 3
|
||||||
|
_MAX_TAGS: Final = 10
|
||||||
|
_MAX_TAG_LEN: Final = 64
|
||||||
|
|
||||||
class NtfyClient:
|
|
||||||
|
def _strip_crlf(value: str) -> str:
|
||||||
|
"""Remove CR/LF — ntfy's JSON path is safe today, but the same fields
|
||||||
|
are used by the header API; defensive sanitization here means a future
|
||||||
|
refactor can't accidentally re-introduce header injection."""
|
||||||
|
return value.replace("\r", " ").replace("\n", " ")
|
||||||
|
|
||||||
|
|
||||||
|
class NtfyClient(HttpProviderClient):
|
||||||
"""Sends push notifications via ntfy server."""
|
"""Sends push notifications via ntfy server."""
|
||||||
|
|
||||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||||
self._session = session
|
super().__init__(session, provider_name="ntfy")
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
@@ -22,41 +37,48 @@ class NtfyClient:
|
|||||||
topic: str,
|
topic: str,
|
||||||
message: str,
|
message: str,
|
||||||
title: str | None = None,
|
title: str | None = None,
|
||||||
priority: int = 3,
|
priority: int = _DEFAULT_PRIORITY,
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
click_url: str | None = None,
|
click_url: str | None = None,
|
||||||
auth_token: str | None = None,
|
auth_token: str | None = None,
|
||||||
|
markdown: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Send a push notification to an ntfy topic."""
|
"""Send a push notification to an ntfy topic."""
|
||||||
if not server_url or not topic:
|
if not server_url or not topic:
|
||||||
return {"success": False, "error": "Missing server_url or topic"}
|
return {"success": False, "error": "Missing server_url or topic"}
|
||||||
|
|
||||||
url = f"{server_url.rstrip('/')}"
|
topic = _strip_crlf(topic).strip()
|
||||||
|
if not topic:
|
||||||
|
return {"success": False, "error": "Topic is empty after sanitization"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
priority_int = int(priority) if priority is not None else _DEFAULT_PRIORITY
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
priority_int = _DEFAULT_PRIORITY
|
||||||
|
priority_int = max(_PRIORITY_MIN, min(priority_int, _PRIORITY_MAX))
|
||||||
|
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
"message": message,
|
"message": message,
|
||||||
"markdown": True,
|
"markdown": bool(markdown),
|
||||||
}
|
}
|
||||||
if title:
|
if title:
|
||||||
payload["title"] = title
|
payload["title"] = _strip_crlf(title)
|
||||||
if priority != 3:
|
if priority_int != _DEFAULT_PRIORITY:
|
||||||
payload["priority"] = priority
|
payload["priority"] = priority_int
|
||||||
if tags:
|
if tags:
|
||||||
payload["tags"] = tags
|
cleaned = [
|
||||||
|
_strip_crlf(str(t))[:_MAX_TAG_LEN]
|
||||||
|
for t in tags[:_MAX_TAGS]
|
||||||
|
if t
|
||||||
|
]
|
||||||
|
if cleaned:
|
||||||
|
payload["tags"] = cleaned
|
||||||
if click_url:
|
if click_url:
|
||||||
payload["click"] = click_url
|
payload["click"] = _strip_crlf(click_url)
|
||||||
|
|
||||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
headers: dict[str, str] = {}
|
||||||
if auth_token:
|
if auth_token:
|
||||||
headers["Authorization"] = f"Bearer {auth_token}"
|
headers["Authorization"] = f"Bearer {auth_token}"
|
||||||
|
|
||||||
try:
|
return await self.request("POST", server_url.rstrip("/"), json=payload, headers=headers)
|
||||||
async with self._session.post(
|
|
||||||
url, json=payload, headers=headers, allow_redirects=False,
|
|
||||||
) as resp:
|
|
||||||
if 200 <= resp.status < 300:
|
|
||||||
return {"success": True}
|
|
||||||
body = await resp.text()
|
|
||||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|||||||
@@ -2,47 +2,88 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
from notify_bridge_core.storage import StorageBackend
|
from notify_bridge_core.storage import StorageBackend
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Bound on queue length. Without a cap, a misconfigured quiet-hour
|
||||||
|
# window plus high event throughput grows the persisted file unboundedly
|
||||||
|
# and every enqueue rewrites the whole file (O(n²) total writes). When
|
||||||
|
# the cap is hit we drop the oldest entry (FIFO) so the most recent
|
||||||
|
# events still reach the recipient when the window opens.
|
||||||
|
DEFAULT_MAX_QUEUE_SIZE: Final = 1000
|
||||||
|
|
||||||
|
|
||||||
class NotificationQueue:
|
class NotificationQueue:
|
||||||
"""Persistent queue for notifications deferred during quiet hours."""
|
"""Persistent queue for notifications deferred during quiet hours."""
|
||||||
|
|
||||||
def __init__(self, backend: StorageBackend) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
backend: StorageBackend,
|
||||||
|
*,
|
||||||
|
max_size: int = DEFAULT_MAX_QUEUE_SIZE,
|
||||||
|
) -> None:
|
||||||
self._backend = backend
|
self._backend = backend
|
||||||
self._data: dict[str, Any] | None = None
|
self._data: dict[str, Any] | None = None
|
||||||
|
self._max_size = max_size
|
||||||
|
# Coordinates load / enqueue / clear / remove so a write-while-load
|
||||||
|
# race can't leave the in-memory copy out of sync with disk and so
|
||||||
|
# bulk operations don't interleave their reads-then-writes.
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_schema(data: Any) -> dict[str, Any]:
|
||||||
|
if not isinstance(data, dict) or not isinstance(data.get("queue"), list):
|
||||||
|
return {"queue": []}
|
||||||
|
return data
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
self._data = await self._backend.load() or {"queue": []}
|
async with self._lock:
|
||||||
|
raw = await self._backend.load()
|
||||||
|
self._data = self._ensure_schema(raw)
|
||||||
|
|
||||||
async def async_enqueue(self, notification_params: dict[str, Any]) -> None:
|
async def async_enqueue(self, notification_params: dict[str, Any]) -> None:
|
||||||
|
async with self._lock:
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
self._data = {"queue": []}
|
self._data = {"queue": []}
|
||||||
self._data["queue"].append({
|
queue: list[dict[str, Any]] = self._data["queue"]
|
||||||
|
queue.append({
|
||||||
"params": notification_params,
|
"params": notification_params,
|
||||||
"queued_at": datetime.now(timezone.utc).isoformat(),
|
"queued_at": datetime.now(timezone.utc).isoformat(),
|
||||||
})
|
})
|
||||||
|
if self._max_size > 0 and len(queue) > self._max_size:
|
||||||
|
# Drop oldest (FIFO) so a new event can still land.
|
||||||
|
drop = len(queue) - self._max_size
|
||||||
|
_LOGGER.warning(
|
||||||
|
"NotificationQueue: dropping %d oldest entries (cap=%d)",
|
||||||
|
drop, self._max_size,
|
||||||
|
)
|
||||||
|
del queue[:drop]
|
||||||
await self._backend.save(self._data)
|
await self._backend.save(self._data)
|
||||||
|
|
||||||
def get_all(self) -> list[dict[str, Any]]:
|
def get_all(self) -> list[dict[str, Any]]:
|
||||||
if not self._data:
|
if not self._data:
|
||||||
return []
|
return []
|
||||||
return list(self._data.get("queue", []))
|
# Deep copy so callers can iterate / mutate without corrupting the
|
||||||
|
# in-memory queue. The cost is bounded by ``max_size``.
|
||||||
|
return copy.deepcopy(list(self._data.get("queue", [])))
|
||||||
|
|
||||||
def has_pending(self) -> bool:
|
def has_pending(self) -> bool:
|
||||||
return bool(self._data and self._data.get("queue"))
|
return bool(self._data and self._data.get("queue"))
|
||||||
|
|
||||||
async def async_clear(self) -> None:
|
async def async_clear(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
if self._data:
|
if self._data:
|
||||||
self._data["queue"] = []
|
self._data["queue"] = []
|
||||||
await self._backend.save(self._data)
|
await self._backend.save(self._data)
|
||||||
|
|
||||||
async def async_remove(self) -> None:
|
async def async_remove(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
await self._backend.remove()
|
await self._backend.remove()
|
||||||
self._data = None
|
self._data = None
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -70,51 +70,64 @@ class MatrixReceiver(Receiver):
|
|||||||
room_id: str = ""
|
room_id: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
_ReceiverFactory = Callable[[str, dict[str, Any]], Receiver]
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_int(value: Any, default: int) -> int:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
_RECEIVER_FACTORIES: dict[str, _ReceiverFactory] = {
|
||||||
|
"telegram": lambda locale, config: TelegramReceiver(
|
||||||
|
locale=locale, config=config, chat_id=str(config.get("chat_id", "")),
|
||||||
|
),
|
||||||
|
"webhook": lambda locale, config: WebhookReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
url=str(config.get("url", "")),
|
||||||
|
headers=dict(config.get("headers", {}) or {}),
|
||||||
|
),
|
||||||
|
"email": lambda locale, config: EmailReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
email=str(config.get("email", "")),
|
||||||
|
name=str(config.get("name", "")),
|
||||||
|
),
|
||||||
|
"discord": lambda locale, config: DiscordReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
webhook_url=str(config.get("webhook_url", "")),
|
||||||
|
),
|
||||||
|
"slack": lambda locale, config: SlackReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
webhook_url=str(config.get("webhook_url", "")),
|
||||||
|
),
|
||||||
|
"ntfy": lambda locale, config: NtfyReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
topic=str(config.get("topic", "")),
|
||||||
|
priority=_coerce_int(config.get("priority"), 3),
|
||||||
|
),
|
||||||
|
"matrix": lambda locale, config: MatrixReceiver(
|
||||||
|
locale=locale, config=config,
|
||||||
|
room_id=str(config.get("room_id", "")),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_receiver_factory(target_type: str, factory: _ReceiverFactory) -> None:
|
||||||
|
"""Register a receiver factory for an out-of-tree target type."""
|
||||||
|
_RECEIVER_FACTORIES[target_type] = factory
|
||||||
|
|
||||||
|
|
||||||
def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver:
|
def build_receiver(target_type: str, config: dict[str, Any], locale: str = "") -> Receiver:
|
||||||
"""Factory: build typed Receiver from target type and config dict."""
|
"""Factory: build typed Receiver from target type and config dict.
|
||||||
if target_type == "telegram":
|
|
||||||
return TelegramReceiver(
|
Falls back to a base ``Receiver`` for unknown target types so callers
|
||||||
locale=locale,
|
that handle types defensively still receive a usable object — but the
|
||||||
config=config,
|
dispatcher rejects them with ``"Unknown target type"`` so a typo can't
|
||||||
chat_id=str(config.get("chat_id", "")),
|
silently route to nowhere.
|
||||||
)
|
"""
|
||||||
if target_type == "webhook":
|
factory = _RECEIVER_FACTORIES.get(target_type)
|
||||||
return WebhookReceiver(
|
if factory is None:
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
url=config.get("url", ""),
|
|
||||||
headers=config.get("headers", {}),
|
|
||||||
)
|
|
||||||
if target_type == "email":
|
|
||||||
return EmailReceiver(
|
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
email=config.get("email", ""),
|
|
||||||
name=config.get("name", ""),
|
|
||||||
)
|
|
||||||
if target_type == "discord":
|
|
||||||
return DiscordReceiver(
|
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
webhook_url=config.get("webhook_url", ""),
|
|
||||||
)
|
|
||||||
if target_type == "slack":
|
|
||||||
return SlackReceiver(
|
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
webhook_url=config.get("webhook_url", ""),
|
|
||||||
)
|
|
||||||
if target_type == "ntfy":
|
|
||||||
return NtfyReceiver(
|
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
topic=config.get("topic", ""),
|
|
||||||
priority=config.get("priority", 3),
|
|
||||||
)
|
|
||||||
if target_type == "matrix":
|
|
||||||
return MatrixReceiver(
|
|
||||||
locale=locale,
|
|
||||||
config=config,
|
|
||||||
room_id=config.get("room_id", ""),
|
|
||||||
)
|
|
||||||
return Receiver(locale=locale, config=config)
|
return Receiver(locale=locale, config=config)
|
||||||
|
return factory(locale, config)
|
||||||
|
|||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""Secret-redaction helpers for log lines and error strings.
|
||||||
|
|
||||||
|
Notification clients embed secrets in URLs (Telegram bot tokens) and
|
||||||
|
Authorization headers (Matrix access tokens, ntfy bearer tokens). When
|
||||||
|
those secrets surface in ``aiohttp.ClientError.__str__``, response
|
||||||
|
bodies, or operator-visible error fields, they leak into logs and into
|
||||||
|
the per-target result dict that callers may forward upstream. ``redact``
|
||||||
|
returns a defanged copy safe for both contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
# api.telegram.org/bot<digits>:<token>/<method>
|
||||||
|
_TELEGRAM_BOT_TOKEN_RE: Final = re.compile(
|
||||||
|
r"(api\.telegram\.org/bot)\d+:[A-Za-z0-9_-]+", re.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Authorization: Bearer <token> (header form, case-insensitive)
|
||||||
|
_BEARER_RE: Final = re.compile(r"(Bearer\s+)[A-Za-z0-9._\-+/=]+", re.IGNORECASE)
|
||||||
|
# Discord webhook: /api/webhooks/<id>/<token>
|
||||||
|
_DISCORD_WEBHOOK_RE: Final = re.compile(
|
||||||
|
r"(discord(?:app)?\.com/api/webhooks/\d+/)[A-Za-z0-9_-]+",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Slack webhook path: /services/T.../B.../<token>
|
||||||
|
_SLACK_WEBHOOK_RE: Final = re.compile(
|
||||||
|
r"(hooks\.slack\.com/services/[A-Z0-9]+/[A-Z0-9]+/)[A-Za-z0-9]+",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
# URL userinfo: scheme://user:password@host
|
||||||
|
_URL_USERINFO_RE: Final = re.compile(
|
||||||
|
r"([a-z][a-z0-9+\-.]*://)[^/@\s]+:[^/@\s]+@",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
# Common token query parameters
|
||||||
|
_QUERY_TOKEN_RE: Final = re.compile(
|
||||||
|
r"([?&](?:token|access_token|api_key|key|secret|password)=)[^&\s]+",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def redact(text: str) -> str:
|
||||||
|
"""Return ``text`` with known secret patterns replaced by ``***``.
|
||||||
|
|
||||||
|
Idempotent and safe to call on already-redacted strings. Always
|
||||||
|
returns a ``str``; non-strings are coerced via ``str()`` so callers
|
||||||
|
can pass exception instances directly.
|
||||||
|
"""
|
||||||
|
if not isinstance(text, str):
|
||||||
|
text = str(text)
|
||||||
|
text = _TELEGRAM_BOT_TOKEN_RE.sub(r"\1***", text)
|
||||||
|
text = _DISCORD_WEBHOOK_RE.sub(r"\1***", text)
|
||||||
|
text = _SLACK_WEBHOOK_RE.sub(r"\1***", text)
|
||||||
|
text = _BEARER_RE.sub(r"\1***", text)
|
||||||
|
text = _URL_USERINFO_RE.sub(r"\1***@", text)
|
||||||
|
text = _QUERY_TOKEN_RE.sub(r"\1***", text)
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def redact_exc(err: BaseException) -> str:
|
||||||
|
"""Redact-and-stringify an exception. Convenience for error fields."""
|
||||||
|
return redact(str(err))
|
||||||
@@ -7,14 +7,16 @@ from typing import Any
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from ..http_base import HttpProviderClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SlackClient:
|
class SlackClient(HttpProviderClient):
|
||||||
"""Sends messages via Slack incoming webhook URLs."""
|
"""Sends messages via Slack incoming webhook URLs."""
|
||||||
|
|
||||||
def __init__(self, session: aiohttp.ClientSession) -> None:
|
def __init__(self, session: aiohttp.ClientSession) -> None:
|
||||||
self._session = session
|
super().__init__(session, provider_name="slack")
|
||||||
|
|
||||||
async def send(
|
async def send(
|
||||||
self,
|
self,
|
||||||
@@ -33,19 +35,4 @@ class SlackClient:
|
|||||||
if icon_emoji:
|
if icon_emoji:
|
||||||
payload["icon_emoji"] = icon_emoji
|
payload["icon_emoji"] = icon_emoji
|
||||||
|
|
||||||
try:
|
return await self.request("POST", webhook_url, json=payload)
|
||||||
async with self._session.post(
|
|
||||||
webhook_url,
|
|
||||||
json=payload,
|
|
||||||
headers={"Content-Type": "application/json"},
|
|
||||||
allow_redirects=False,
|
|
||||||
) as resp:
|
|
||||||
if resp.status == 429:
|
|
||||||
_LOGGER.warning("Slack rate limited")
|
|
||||||
return {"success": False, "error": "Rate limited by Slack"}
|
|
||||||
if 200 <= resp.status < 300:
|
|
||||||
return {"success": True}
|
|
||||||
body = await resp.text()
|
|
||||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|||||||
@@ -1,10 +1,22 @@
|
|||||||
"""Outbound URL validation to mitigate SSRF attacks.
|
"""Outbound URL validation to mitigate SSRF attacks.
|
||||||
|
|
||||||
User-controlled URLs (provider `url`, webhook target `url`, shared-link
|
User-controlled URLs (provider ``url``, webhook target ``url``,
|
||||||
base URLs, image downloads) must be validated before any HTTP request is
|
shared-link base URLs, image downloads) must be validated before any
|
||||||
issued. This module rejects schemes other than http/https and blocks
|
HTTP request is issued. This module rejects schemes other than
|
||||||
destinations that resolve to private, loopback, link-local, or unspecified
|
http/https and blocks destinations that resolve to private, loopback,
|
||||||
address ranges.
|
link-local, unspecified, CGNAT (100.64.0.0/10), or IPv4-mapped IPv6
|
||||||
|
ranges.
|
||||||
|
|
||||||
|
DNS rebinding mitigation
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
``avalidate_outbound_url`` returns the original URL on success, but
|
||||||
|
also returns the resolved IP it actually validated. Callers that pass
|
||||||
|
the validated URL straight into ``aiohttp`` are vulnerable to a
|
||||||
|
DNS-rebinding attack: the validator's ``getaddrinfo`` returns a public
|
||||||
|
IP; aiohttp's connect-time resolution returns ``127.0.0.1``. To close
|
||||||
|
that gap, use :func:`build_ssrf_safe_session` (or
|
||||||
|
:class:`PinnedResolver`) so the resolved IP from the validation step is
|
||||||
|
the one aiohttp connects to.
|
||||||
|
|
||||||
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
|
Set ``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1`` in the environment for
|
||||||
development against localhost services.
|
development against localhost services.
|
||||||
@@ -17,12 +29,20 @@ import ipaddress
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
from dataclasses import dataclass
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
||||||
_ALLOWED_SCHEMES = {"http", "https"}
|
_ALLOWED_SCHEMES = frozenset({"http", "https"})
|
||||||
|
|
||||||
|
# Carrier-grade NAT range. Not in stdlib's ``is_private``; an attacker
|
||||||
|
# pointing a domain at a CGNAT IP could reach the operator's ISP-side
|
||||||
|
# routing infrastructure. RFC 6598.
|
||||||
|
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||||
|
|
||||||
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@@ -36,7 +56,29 @@ class UnsafeURLError(ValueError):
|
|||||||
"""Raised when a URL targets a disallowed network destination."""
|
"""Raised when a URL targets a disallowed network destination."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ValidatedURL:
|
||||||
|
"""Result of validating an outbound URL.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
url: The original URL string (unchanged).
|
||||||
|
host: Hostname extracted from the URL (lower-cased, IDN-encoded).
|
||||||
|
ip: Resolved IP address that passed the block-range check, as a
|
||||||
|
string. Pass to :class:`PinnedResolver` to defeat DNS
|
||||||
|
rebinding by reusing this exact IP at connect time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
url: str
|
||||||
|
host: str
|
||||||
|
ip: str
|
||||||
|
|
||||||
|
|
||||||
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||||
|
# An IPv4-mapped IPv6 like ``::ffff:127.0.0.1`` is NOT considered
|
||||||
|
# ``is_private`` etc. by stdlib — the v4 view holds those flags. So
|
||||||
|
# we unwrap before checking.
|
||||||
|
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped is not None:
|
||||||
|
ip = ip.ipv4_mapped
|
||||||
return (
|
return (
|
||||||
ip.is_private
|
ip.is_private
|
||||||
or ip.is_loopback
|
or ip.is_loopback
|
||||||
@@ -44,22 +86,54 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
|||||||
or ip.is_multicast
|
or ip.is_multicast
|
||||||
or ip.is_reserved
|
or ip.is_reserved
|
||||||
or ip.is_unspecified
|
or ip.is_unspecified
|
||||||
|
or (isinstance(ip, ipaddress.IPv4Address) and ip in _CGNAT_NETWORK)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_host_repr(host: str) -> str:
|
||||||
|
"""Return ``host`` shortened/escaped for safe inclusion in error text."""
|
||||||
|
h = host[:64].replace("\r", "").replace("\n", "")
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_host(parsed_host: str) -> str:
|
||||||
|
"""Normalize a hostname: lowercase, strip trailing dot, IDN-encode."""
|
||||||
|
host = parsed_host.lower()
|
||||||
|
if host.endswith("."):
|
||||||
|
host = host[:-1]
|
||||||
|
# Strip IPv6 zone id ("fe80::1%eth0") — must not reach the resolver.
|
||||||
|
if "%" in host:
|
||||||
|
host = host.split("%", 1)[0]
|
||||||
|
# IDN-encode unicode hostnames so we don't downgrade to confusables
|
||||||
|
# in any later log/output and so getaddrinfo gets the ascii form.
|
||||||
|
try:
|
||||||
|
if any(ord(c) > 127 for c in host):
|
||||||
|
host = host.encode("idna").decode("ascii")
|
||||||
|
except UnicodeError:
|
||||||
|
# Caller will fail on resolution; leave as-is so the error path
|
||||||
|
# surfaces a "DNS resolution failed" rather than a stack trace.
|
||||||
|
pass
|
||||||
|
return host
|
||||||
|
|
||||||
|
|
||||||
def _check_scheme_host(url: str) -> tuple[str, str]:
|
def _check_scheme_host(url: str) -> tuple[str, str]:
|
||||||
if not isinstance(url, str) or not url:
|
if not isinstance(url, str) or not url:
|
||||||
raise UnsafeURLError("URL is empty")
|
raise UnsafeURLError("URL is empty")
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
if parsed.scheme not in _ALLOWED_SCHEMES:
|
scheme = parsed.scheme.lower()
|
||||||
raise UnsafeURLError(f"Scheme '{parsed.scheme}' not allowed")
|
if scheme not in _ALLOWED_SCHEMES:
|
||||||
|
raise UnsafeURLError(f"Scheme '{scheme[:16]}' not allowed")
|
||||||
host = parsed.hostname
|
host = parsed.hostname
|
||||||
if not host:
|
if not host:
|
||||||
raise UnsafeURLError("URL has no host")
|
raise UnsafeURLError("URL has no host")
|
||||||
return parsed.scheme, host
|
return scheme, _normalize_host(host)
|
||||||
|
|
||||||
|
|
||||||
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
def _select_addresses(
|
||||||
|
host: str, infos: list[tuple],
|
||||||
|
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||||
|
"""Return parsed, non-blocked IPs from ``getaddrinfo`` results."""
|
||||||
|
addrs: list[ipaddress.IPv4Address | ipaddress.IPv6Address] = []
|
||||||
for info in infos:
|
for info in infos:
|
||||||
sockaddr = info[4]
|
sockaddr = info[4]
|
||||||
try:
|
try:
|
||||||
@@ -67,64 +141,143 @@ def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
continue
|
continue
|
||||||
if _is_blocked_ip(ip):
|
if _is_blocked_ip(ip):
|
||||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
raise UnsafeURLError(
|
||||||
|
f"Host {_safe_host_repr(host)} resolves to blocked address {ip}"
|
||||||
|
)
|
||||||
|
addrs.append(ip)
|
||||||
|
if not addrs:
|
||||||
|
raise UnsafeURLError(f"Host {_safe_host_repr(host)} has no usable address")
|
||||||
|
return addrs
|
||||||
|
|
||||||
|
|
||||||
def validate_outbound_url(url: str) -> str:
|
def validate_outbound_url(url: str) -> str:
|
||||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||||
|
|
||||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
.. deprecated::
|
||||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
|
||||||
private addresses are permitted but the scheme check still applies.
|
|
||||||
|
|
||||||
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||||
:func:`avalidate_outbound_url` from async code paths.
|
:func:`avalidate_outbound_url` from async code paths so the
|
||||||
|
event loop isn't blocked, and use :func:`build_ssrf_safe_session`
|
||||||
|
to defeat DNS rebinding.
|
||||||
"""
|
"""
|
||||||
_, host = _check_scheme_host(url)
|
_, host = _check_scheme_host(url)
|
||||||
|
|
||||||
if _ALLOW_PRIVATE:
|
if _ALLOW_PRIVATE:
|
||||||
return url
|
return url
|
||||||
|
|
||||||
# Literal IP host
|
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(host)
|
ip = ipaddress.ip_address(host)
|
||||||
if _is_blocked_ip(ip):
|
if _is_blocked_ip(ip):
|
||||||
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
|
||||||
return url
|
return url
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
infos = socket.getaddrinfo(host, None)
|
infos = socket.getaddrinfo(host, None)
|
||||||
except socket.gaierror as exc:
|
except (socket.gaierror, UnicodeError, OSError) as exc:
|
||||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
# ``UnicodeError`` covers IDNA failures (labels >63 chars, malformed
|
||||||
_check_resolved_addresses(host, infos)
|
# unicode) which getaddrinfo surfaces as encoding errors rather than
|
||||||
|
# gaierror. ``OSError`` covers transient resolver failures on some
|
||||||
|
# platforms.
|
||||||
|
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
|
||||||
|
_select_addresses(host, infos)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
async def avalidate_outbound_url(url: str) -> str:
|
async def avalidate_outbound_url(url: str) -> str:
|
||||||
"""Async variant that resolves DNS via the running loop's resolver.
|
"""Async variant — returns the URL on success.
|
||||||
|
|
||||||
Use this from ``async def`` code paths to avoid blocking the event
|
For DNS-rebinding-safe usage, prefer :func:`avalidate_outbound_url_full`
|
||||||
loop on DNS lookups.
|
which also returns the resolved IP for connect-time pinning.
|
||||||
|
"""
|
||||||
|
result = await avalidate_outbound_url_full(url)
|
||||||
|
return result.url
|
||||||
|
|
||||||
|
|
||||||
|
async def avalidate_outbound_url_full(url: str) -> ValidatedURL:
|
||||||
|
"""Validate ``url`` and return a :class:`ValidatedURL` on success.
|
||||||
|
|
||||||
|
The returned ``ip`` field is the IP that passed the block-range
|
||||||
|
check. Pair with :class:`PinnedResolver` so aiohttp connects to that
|
||||||
|
exact IP and a malicious DNS server can't swap in a private address
|
||||||
|
after validation.
|
||||||
"""
|
"""
|
||||||
_, host = _check_scheme_host(url)
|
_, host = _check_scheme_host(url)
|
||||||
|
|
||||||
if _ALLOW_PRIVATE:
|
if _ALLOW_PRIVATE:
|
||||||
return url
|
# In dev mode we still resolve to give a usable IP, but we don't
|
||||||
|
# gate on the result.
|
||||||
try:
|
try:
|
||||||
ip = ipaddress.ip_address(host)
|
ip = str(ipaddress.ip_address(host))
|
||||||
if _is_blocked_ip(ip):
|
except ValueError:
|
||||||
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
try:
|
||||||
return url
|
loop = asyncio.get_running_loop()
|
||||||
|
infos = await loop.getaddrinfo(host, None)
|
||||||
|
ip = infos[0][4][0] if infos else host
|
||||||
|
except (socket.gaierror, OSError):
|
||||||
|
ip = host
|
||||||
|
return ValidatedURL(url=url, host=host, ip=ip)
|
||||||
|
|
||||||
|
# Literal IP host
|
||||||
|
try:
|
||||||
|
ip_obj = ipaddress.ip_address(host)
|
||||||
|
if _is_blocked_ip(ip_obj):
|
||||||
|
raise UnsafeURLError(f"Host {_safe_host_repr(host)} is in a blocked range")
|
||||||
|
return ValidatedURL(url=url, host=host, ip=str(ip_obj))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
try:
|
try:
|
||||||
infos = await loop.getaddrinfo(host, None)
|
infos = await loop.getaddrinfo(host, None)
|
||||||
except socket.gaierror as exc:
|
except (socket.gaierror, UnicodeError, OSError) as exc:
|
||||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
raise UnsafeURLError(f"DNS resolution failed for {_safe_host_repr(host)}") from exc
|
||||||
_check_resolved_addresses(host, infos)
|
addrs = _select_addresses(host, infos)
|
||||||
return url
|
return ValidatedURL(url=url, host=host, ip=str(addrs[0]))
|
||||||
|
|
||||||
|
|
||||||
|
class PinnedResolver(aiohttp.abc.AbstractResolver):
|
||||||
|
"""aiohttp resolver that returns a fixed (host, ip) mapping.
|
||||||
|
|
||||||
|
Used to pin the resolved IP from :func:`avalidate_outbound_url_full`
|
||||||
|
so aiohttp's connect-time resolution can't be tricked by DNS
|
||||||
|
rebinding into using a different IP than the one we validated.
|
||||||
|
|
||||||
|
Falls back to :class:`aiohttp.AsyncResolver` (or default) for any
|
||||||
|
host not explicitly pinned, so a single resolver instance can be
|
||||||
|
reused across multiple validated URLs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mapping: dict[str, str] | None = None) -> None:
|
||||||
|
self._map: dict[str, str] = dict(mapping or {})
|
||||||
|
self._fallback: aiohttp.abc.AbstractResolver | None = None
|
||||||
|
|
||||||
|
def pin(self, host: str, ip: str) -> None:
|
||||||
|
self._map[host.lower()] = ip
|
||||||
|
|
||||||
|
async def resolve(
|
||||||
|
self, host: str, port: int = 0, family: int = socket.AF_INET,
|
||||||
|
) -> list[dict]:
|
||||||
|
ip = self._map.get(host.lower())
|
||||||
|
if ip is not None:
|
||||||
|
try:
|
||||||
|
ip_obj = ipaddress.ip_address(ip)
|
||||||
|
except ValueError:
|
||||||
|
ip_obj = None
|
||||||
|
if ip_obj is not None:
|
||||||
|
fam = socket.AF_INET6 if ip_obj.version == 6 else socket.AF_INET
|
||||||
|
return [{
|
||||||
|
"hostname": host,
|
||||||
|
"host": ip,
|
||||||
|
"port": port,
|
||||||
|
"family": fam,
|
||||||
|
"proto": 0,
|
||||||
|
"flags": socket.AI_NUMERICHOST,
|
||||||
|
}]
|
||||||
|
if self._fallback is None:
|
||||||
|
self._fallback = aiohttp.ThreadedResolver()
|
||||||
|
return await self._fallback.resolve(host, port, family)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._fallback is not None:
|
||||||
|
await self._fallback.close()
|
||||||
|
|||||||
@@ -2,16 +2,29 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any, Final
|
||||||
|
|
||||||
from notify_bridge_core.storage import StorageBackend
|
from notify_bridge_core.storage import StorageBackend
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
DEFAULT_TELEGRAM_CACHE_TTL = 48 * 60 * 60
|
DEFAULT_TELEGRAM_CACHE_TTL: Final = 48 * 60 * 60
|
||||||
DEFAULT_MAX_ENTRIES = 5000
|
DEFAULT_MAX_ENTRIES: Final = 5000
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_iso(value: str | None) -> datetime | None:
|
||||||
|
"""Parse an ISO-8601 timestamp tolerantly. Returns ``None`` on failure."""
|
||||||
|
if not value or not isinstance(value, str):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
# Python <3.11 doesn't accept "Z"; normalize to +00:00.
|
||||||
|
v = value.replace("Z", "+00:00") if value.endswith("Z") else value
|
||||||
|
return datetime.fromisoformat(v)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class TelegramFileCache:
|
class TelegramFileCache:
|
||||||
@@ -25,7 +38,17 @@ class TelegramFileCache:
|
|||||||
Intended for content-addressable assets (e.g. Immich) where re-uploads
|
Intended for content-addressable assets (e.g. Immich) where re-uploads
|
||||||
should be triggered by visual change, not elapsed time.
|
should be triggered by visual change, not elapsed time.
|
||||||
|
|
||||||
``max_entries`` always applies as an LRU size cap (by ``cached_at``).
|
``max_entries`` always applies as a FIFO size cap (oldest-cached first).
|
||||||
|
|
||||||
|
Concurrency
|
||||||
|
~~~~~~~~~~~
|
||||||
|
All mutators take an internal ``asyncio.Lock`` so concurrent
|
||||||
|
media-group sends can't interleave a read-time invalidation with a
|
||||||
|
bulk write and corrupt the underlying dict (``RuntimeError:
|
||||||
|
dictionary changed size during iteration``) or lose just-written
|
||||||
|
entries. Reads do not take the lock — they are O(1) dict lookups —
|
||||||
|
but ``get`` uses a snapshot reference so it cannot mutate the data
|
||||||
|
structure under another task.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -40,35 +63,40 @@ class TelegramFileCache:
|
|||||||
self._ttl_seconds = ttl_seconds
|
self._ttl_seconds = ttl_seconds
|
||||||
self._use_thumbhash = use_thumbhash
|
self._use_thumbhash = use_thumbhash
|
||||||
self._max_entries = max_entries
|
self._max_entries = max_entries
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
async def async_load(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
self._data = await self._backend.load() or {"files": {}}
|
self._data = await self._backend.load() or {"files": {}}
|
||||||
await self._cleanup_expired()
|
await self._cleanup_expired_locked()
|
||||||
|
|
||||||
async def _cleanup_expired(self) -> None:
|
async def _cleanup_expired_locked(self) -> None:
|
||||||
|
"""Caller must hold ``self._lock``."""
|
||||||
if not self._data or "files" not in self._data:
|
if not self._data or "files" not in self._data:
|
||||||
return
|
return
|
||||||
files = self._data["files"]
|
files: dict[str, dict[str, Any]] = self._data["files"]
|
||||||
changed = False
|
changed = False
|
||||||
|
|
||||||
# TTL sweep — only when TTL validation is active (i.e. no thumbhash
|
|
||||||
# mode and a positive TTL). In thumbhash mode we rely entirely on
|
|
||||||
# content validation; in "TTL disabled" mode (ttl_seconds <= 0) we
|
|
||||||
# cache forever, subject only to the size cap.
|
|
||||||
if not self._use_thumbhash and self._ttl_seconds > 0:
|
if not self._use_thumbhash and self._ttl_seconds > 0:
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
expired = [
|
expired: list[str] = []
|
||||||
url for url, entry in files.items()
|
for url, entry in list(files.items()):
|
||||||
if entry.get("cached_at") and
|
cached_at = _parse_iso(entry.get("cached_at"))
|
||||||
(now - datetime.fromisoformat(entry["cached_at"])).total_seconds() > self._ttl_seconds
|
if cached_at is None:
|
||||||
]
|
continue
|
||||||
|
if cached_at.tzinfo is None:
|
||||||
|
cached_at = cached_at.replace(tzinfo=timezone.utc)
|
||||||
|
if (now - cached_at).total_seconds() > self._ttl_seconds:
|
||||||
|
expired.append(url)
|
||||||
for key in expired:
|
for key in expired:
|
||||||
del files[key]
|
del files[key]
|
||||||
changed = True
|
changed = True
|
||||||
|
|
||||||
# LRU cap — always enforced. Evicts oldest-cached entries first.
|
|
||||||
if self._max_entries > 0 and len(files) > self._max_entries:
|
if self._max_entries > 0 and len(files) > self._max_entries:
|
||||||
sorted_keys = sorted(files, key=lambda k: files[k].get("cached_at", ""))
|
sorted_keys = sorted(
|
||||||
|
files,
|
||||||
|
key=lambda k: _parse_iso(files[k].get("cached_at")) or datetime.min.replace(tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
for key in sorted_keys[: len(files) - self._max_entries]:
|
for key in sorted_keys[: len(files) - self._max_entries]:
|
||||||
del files[key]
|
del files[key]
|
||||||
changed = True
|
changed = True
|
||||||
@@ -80,7 +108,10 @@ class TelegramFileCache:
|
|||||||
if not self._data or "files" not in self._data:
|
if not self._data or "files" not in self._data:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
entry = self._data["files"].get(key)
|
# Take a local reference so a concurrent ``async_set`` rebuilding
|
||||||
|
# the dict cannot pull the rug out mid-read.
|
||||||
|
files = self._data["files"]
|
||||||
|
entry = files.get(key)
|
||||||
if not entry:
|
if not entry:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -88,19 +119,23 @@ class TelegramFileCache:
|
|||||||
if thumbhash is not None:
|
if thumbhash is not None:
|
||||||
stored = entry.get("thumbhash")
|
stored = entry.get("thumbhash")
|
||||||
if stored and stored != thumbhash:
|
if stored and stored != thumbhash:
|
||||||
del self._data["files"][key]
|
# Mark stale — actual deletion happens lock-protected
|
||||||
|
# in the next mutation. Returning None is sufficient
|
||||||
|
# for the caller to skip the cache hit.
|
||||||
return None
|
return None
|
||||||
elif self._ttl_seconds > 0:
|
elif self._ttl_seconds > 0:
|
||||||
cached_at_str = entry.get("cached_at")
|
cached_at = _parse_iso(entry.get("cached_at"))
|
||||||
if cached_at_str:
|
if cached_at is not None:
|
||||||
age = (datetime.now(timezone.utc) - datetime.fromisoformat(cached_at_str)).total_seconds()
|
if cached_at.tzinfo is None:
|
||||||
|
cached_at = cached_at.replace(tzinfo=timezone.utc)
|
||||||
|
age = (datetime.now(timezone.utc) - cached_at).total_seconds()
|
||||||
if age > self._ttl_seconds:
|
if age > self._ttl_seconds:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"file_id": entry.get("file_id"),
|
"file_id": entry.get("file_id"),
|
||||||
"type": entry.get("type"),
|
"type": entry.get("type"),
|
||||||
"size": entry.get("size"), # bytes of what was uploaded; None for legacy entries
|
"size": entry.get("size"),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def async_set(
|
async def async_set(
|
||||||
@@ -111,6 +146,7 @@ class TelegramFileCache:
|
|||||||
thumbhash: str | None = None,
|
thumbhash: str | None = None,
|
||||||
size: int | None = None,
|
size: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
async with self._lock:
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
self._data = {"files": {}}
|
self._data = {"files": {}}
|
||||||
|
|
||||||
@@ -139,6 +175,7 @@ class TelegramFileCache:
|
|||||||
"""
|
"""
|
||||||
if not entries:
|
if not entries:
|
||||||
return
|
return
|
||||||
|
async with self._lock:
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
self._data = {"files": {}}
|
self._data = {"files": {}}
|
||||||
|
|
||||||
@@ -163,6 +200,7 @@ class TelegramFileCache:
|
|||||||
await self._backend.save(self._data)
|
await self._backend.save(self._data)
|
||||||
|
|
||||||
async def async_remove(self) -> None:
|
async def async_remove(self) -> None:
|
||||||
|
async with self._lock:
|
||||||
await self._backend.remove()
|
await self._backend.remove()
|
||||||
self._data = None
|
self._data = None
|
||||||
|
|
||||||
@@ -172,25 +210,33 @@ class TelegramFileCache:
|
|||||||
Includes the number of cached entries, total tracked size in bytes
|
Includes the number of cached entries, total tracked size in bytes
|
||||||
(only counts entries with a recorded ``size``), and the oldest /
|
(only counts entries with a recorded ``size``), and the oldest /
|
||||||
newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty).
|
newest ``cached_at`` timestamps (ISO strings, or ``None`` if empty).
|
||||||
|
Timestamps are compared as parsed ``datetime`` objects so mixed
|
||||||
|
timezone formats (``Z`` vs ``+00:00``) order correctly.
|
||||||
"""
|
"""
|
||||||
files = self._data.get("files", {}) if self._data else {}
|
files = self._data.get("files", {}) if self._data else {}
|
||||||
count = len(files)
|
count = len(files)
|
||||||
total_size = 0
|
total_size = 0
|
||||||
oldest: str | None = None
|
oldest_dt: datetime | None = None
|
||||||
newest: str | None = None
|
newest_dt: datetime | None = None
|
||||||
|
oldest_str: str | None = None
|
||||||
|
newest_str: str | None = None
|
||||||
for entry in files.values():
|
for entry in files.values():
|
||||||
size = entry.get("size")
|
size = entry.get("size")
|
||||||
if isinstance(size, int):
|
if isinstance(size, int):
|
||||||
total_size += size
|
total_size += size
|
||||||
cached_at = entry.get("cached_at")
|
cached_at = entry.get("cached_at")
|
||||||
if cached_at:
|
dt = _parse_iso(cached_at)
|
||||||
if oldest is None or cached_at < oldest:
|
if dt is None or not cached_at:
|
||||||
oldest = cached_at
|
continue
|
||||||
if newest is None or cached_at > newest:
|
if dt.tzinfo is None:
|
||||||
newest = cached_at
|
dt = dt.replace(tzinfo=timezone.utc)
|
||||||
|
if oldest_dt is None or dt < oldest_dt:
|
||||||
|
oldest_dt, oldest_str = dt, cached_at
|
||||||
|
if newest_dt is None or dt > newest_dt:
|
||||||
|
newest_dt, newest_str = dt, cached_at
|
||||||
return {
|
return {
|
||||||
"count": count,
|
"count": count,
|
||||||
"total_size_bytes": total_size,
|
"total_size_bytes": total_size,
|
||||||
"oldest": oldest,
|
"oldest": oldest_str,
|
||||||
"newest": newest,
|
"newest": newest_str,
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,20 +2,35 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Final
|
from typing import Any, Final
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Telegram constants
|
# Telegram constants
|
||||||
TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot"
|
TELEGRAM_API_BASE_URL: Final = "https://api.telegram.org/bot"
|
||||||
TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB
|
TELEGRAM_MAX_PHOTO_SIZE: Final = 10 * 1024 * 1024 # 10 MB
|
||||||
TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB
|
TELEGRAM_MAX_VIDEO_SIZE: Final = 50 * 1024 * 1024 # 50 MB
|
||||||
TELEGRAM_MAX_DIMENSION_SUM: Final = 10000
|
TELEGRAM_MAX_DIMENSION_SUM: Final = 10000
|
||||||
|
# Telegram message-text limit (sendMessage) and caption limit
|
||||||
|
# (sendPhoto/sendVideo/sendDocument/first item of sendMediaGroup).
|
||||||
|
TELEGRAM_MAX_TEXT_LENGTH: Final = 4096
|
||||||
|
TELEGRAM_MAX_CAPTION_LENGTH: Final = 1024
|
||||||
|
|
||||||
# Generic UUID pattern for asset IDs
|
# Strict canonical-UUID pattern (8-4-4-4-12) for asset IDs. The previous
|
||||||
_ASSET_ID_PATTERN = re.compile(r"^[a-f0-9-]{36}$")
|
# loose ``[a-f0-9-]{36}`` matched 36 hyphens / arbitrary digit groupings,
|
||||||
|
# which could collide across providers when used as a cache key.
|
||||||
|
_ASSET_ID_PATTERN = re.compile(
|
||||||
|
r"^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
# Cache key: "host:uuid" or bare "uuid"
|
# Cache key: "host:uuid" or bare "uuid"
|
||||||
_ASSET_CACHE_KEY_PATTERN = re.compile(r"^(?:[^:]+:)?[a-f0-9-]{36}$")
|
_ASSET_CACHE_KEY_PATTERN = re.compile(
|
||||||
|
r"^(?:[^:]+:)?[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
# URL patterns to extract asset IDs (generic enough for Immich-style URLs)
|
# URL patterns to extract asset IDs (generic enough for Immich-style URLs)
|
||||||
_ASSET_ID_URL_PATTERNS = [
|
_ASSET_ID_URL_PATTERNS = [
|
||||||
@@ -162,5 +177,10 @@ def check_photo_limits(
|
|||||||
return False, None, width, height
|
return False, None, width, height
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False, None, None, None
|
return False, None, None, None
|
||||||
except Exception:
|
except (OSError, ValueError, MemoryError) as exc:
|
||||||
|
# PIL surfaces ``UnidentifiedImageError`` (subclass of OSError),
|
||||||
|
# truncated-image / decompression-bomb errors here. Log so a
|
||||||
|
# corrupt asset isn't silently passed to Telegram and rejected
|
||||||
|
# downstream with a less actionable error.
|
||||||
|
_LOGGER.warning("check_photo_limits: failed to inspect image (%d bytes): %s", len(data), exc)
|
||||||
return False, None, None, None
|
return False, None, None, None
|
||||||
|
|||||||
@@ -7,37 +7,29 @@ from typing import Any
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from ..ssrf import UnsafeURLError, avalidate_outbound_url
|
from ..http_base import HttpProviderClient
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
|
||||||
|
|
||||||
|
class WebhookClient(HttpProviderClient):
|
||||||
|
"""Send JSON payloads to a webhook URL.
|
||||||
|
|
||||||
class WebhookClient:
|
The URL is SSRF-validated on every send (defense-in-depth: re-validating
|
||||||
"""Send JSON payloads to a webhook URL."""
|
catches DNS rebinding between calls and a misconfigured target). Headers
|
||||||
|
pass through :func:`safe_headers` so a target config can't inject
|
||||||
|
framing/hop-by-hop headers like ``Host`` or ``Transfer-Encoding``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, session: aiohttp.ClientSession, url: str, headers: dict[str, str] | None = None) -> None:
|
def __init__(
|
||||||
self._session = session
|
self,
|
||||||
|
session: aiohttp.ClientSession,
|
||||||
|
url: str,
|
||||||
|
headers: dict[str, str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(session, provider_name="webhook")
|
||||||
self._url = url
|
self._url = url
|
||||||
self._headers = headers or {}
|
self._extra_headers = headers or {}
|
||||||
|
|
||||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
try:
|
return await self.request("POST", self._url, json=payload, headers=self._extra_headers)
|
||||||
await avalidate_outbound_url(self._url)
|
|
||||||
except UnsafeURLError as err:
|
|
||||||
return {"success": False, "error": f"Unsafe URL: {err}"}
|
|
||||||
try:
|
|
||||||
async with self._session.post(
|
|
||||||
self._url,
|
|
||||||
json=payload,
|
|
||||||
headers={"Content-Type": "application/json", **self._headers},
|
|
||||||
timeout=_DEFAULT_TIMEOUT,
|
|
||||||
allow_redirects=False,
|
|
||||||
) as response:
|
|
||||||
if 200 <= response.status < 300:
|
|
||||||
return {"success": True, "status_code": response.status}
|
|
||||||
body = await response.text()
|
|
||||||
return {"success": False, "error": f"HTTP {response.status}", "body": body[:200]}
|
|
||||||
except aiohttp.ClientError as err:
|
|
||||||
return {"success": False, "error": str(err)}
|
|
||||||
|
|||||||
@@ -218,6 +218,19 @@ async def get_supported_locales(
|
|||||||
return locales or ["en"]
|
return locales or ["en"]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/external-url")
|
||||||
|
async def get_external_url(
|
||||||
|
user: User = Depends(get_current_user),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""Return the configured external base URL (available to all users).
|
||||||
|
|
||||||
|
Used by the UI to render absolute provider webhook URLs. Returns empty
|
||||||
|
string when unset so the UI falls back to the relative path.
|
||||||
|
"""
|
||||||
|
return {"external_url": (await get_setting(session, "external_url")).rstrip("/")}
|
||||||
|
|
||||||
|
|
||||||
async def _reregister_webhooks(
|
async def _reregister_webhooks(
|
||||||
session: AsyncSession, base_url: str, secret: str
|
session: AsyncSession, base_url: str, secret: str
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Dispatcher result aggregation: per-receiver detail must survive."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_all_success() -> None:
|
||||||
|
out = NotificationDispatcher._aggregate_results([
|
||||||
|
{"success": True, "message_id": 1},
|
||||||
|
{"success": True, "message_id": 2},
|
||||||
|
])
|
||||||
|
assert out["success"] is True
|
||||||
|
assert out["receivers"] == 2
|
||||||
|
assert out["successes"] == 2
|
||||||
|
assert out["failures"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_partial() -> None:
|
||||||
|
out = NotificationDispatcher._aggregate_results([
|
||||||
|
{"success": True},
|
||||||
|
{"success": False, "error": "boom"},
|
||||||
|
])
|
||||||
|
assert out["success"] is True # at least one succeeded
|
||||||
|
assert out["successes"] == 1
|
||||||
|
assert out["failures"] == 1
|
||||||
|
assert "boom" in out["errors"]
|
||||||
|
assert "results" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_all_fail_preserves_all_errors() -> None:
|
||||||
|
out = NotificationDispatcher._aggregate_results([
|
||||||
|
{"success": False, "error": "first"},
|
||||||
|
{"success": False, "error": "second"},
|
||||||
|
])
|
||||||
|
assert out["success"] is False
|
||||||
|
assert out["error"] == "first" # back-compat top-level field
|
||||||
|
assert out["errors"] == ["first", "second"]
|
||||||
|
# Per-receiver details survive — operator can see exactly what failed.
|
||||||
|
assert len(out["results"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_empty() -> None:
|
||||||
|
out = NotificationDispatcher._aggregate_results([])
|
||||||
|
assert out["success"] is False
|
||||||
|
assert "error" in out
|
||||||
@@ -0,0 +1,77 @@
|
|||||||
|
"""Email client header-injection / address-validation regression tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.email.client import (
|
||||||
|
EmailClient,
|
||||||
|
SmtpConfig,
|
||||||
|
_strip_header,
|
||||||
|
_validate_email,
|
||||||
|
_to_html,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_header_removes_crlf() -> None:
|
||||||
|
out = _strip_header("Subject\r\nBcc: attacker@example.com")
|
||||||
|
assert "\r" not in out
|
||||||
|
assert "\n" not in out
|
||||||
|
# The injected "Bcc:" line is folded to a single header line; the SMTP
|
||||||
|
# server will treat it as part of the subject text, not a header.
|
||||||
|
assert "Bcc:" in out # value preserved as plain text
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_header_removes_bare_lf() -> None:
|
||||||
|
out = _strip_header("Hello\nWorld")
|
||||||
|
assert "\n" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_strip_header_handles_non_string() -> None:
|
||||||
|
assert _strip_header(None) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_email_rejects_crlf() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_email("user@example.com\r\nBcc: x@y")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_email_rejects_no_at() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_email("not-an-email")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_email_rejects_empty() -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_validate_email("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_email_accepts_normal() -> None:
|
||||||
|
assert _validate_email("user@example.com") == "user@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_html_escapes_brackets() -> None:
|
||||||
|
out = _to_html("<script>alert(1)</script>")
|
||||||
|
assert "<script>" not in out
|
||||||
|
assert "<script>" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_returns_error_on_invalid_to() -> None:
|
||||||
|
cfg = SmtpConfig(host="smtp.example.com", from_address="from@example.com")
|
||||||
|
client = EmailClient(cfg)
|
||||||
|
result = await client.send(
|
||||||
|
to_email="user@example.com\r\nBcc: attacker@example.com",
|
||||||
|
subject="hi",
|
||||||
|
body_text="body",
|
||||||
|
)
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "Invalid email" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_returns_error_on_no_host() -> None:
|
||||||
|
cfg = SmtpConfig(host="", from_address="from@example.com")
|
||||||
|
client = EmailClient(cfg)
|
||||||
|
result = await client.send("u@x.com", "s", "b")
|
||||||
|
assert result["success"] is False
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
"""HttpProviderClient + safe_headers tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.http_base import safe_headers
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeHeaders:
|
||||||
|
def test_drops_hop_by_hop(self) -> None:
|
||||||
|
out = safe_headers({
|
||||||
|
"X-Custom": "ok",
|
||||||
|
"Host": "evil.example.com",
|
||||||
|
"Content-Length": "999",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
"Connection": "close",
|
||||||
|
})
|
||||||
|
assert out == {"X-Custom": "ok"}
|
||||||
|
|
||||||
|
def test_rejects_crlf_in_value(self) -> None:
|
||||||
|
out = safe_headers({
|
||||||
|
"X-Custom": "ok",
|
||||||
|
"X-Bad": "value\r\nInjected: yes",
|
||||||
|
})
|
||||||
|
assert "X-Custom" in out
|
||||||
|
assert "X-Bad" not in out
|
||||||
|
|
||||||
|
def test_rejects_crlf_in_name(self) -> None:
|
||||||
|
out = safe_headers({
|
||||||
|
"X-Custom": "ok",
|
||||||
|
"X-Bad\r\nInject": "value",
|
||||||
|
})
|
||||||
|
assert out == {"X-Custom": "ok"}
|
||||||
|
|
||||||
|
def test_empty_input(self) -> None:
|
||||||
|
assert safe_headers(None) == {}
|
||||||
|
assert safe_headers({}) == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_base_returns_safe_error_on_invalid_url() -> None:
|
||||||
|
"""An obviously-bad URL must not panic or leak the URL verbatim."""
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.http_base import HttpProviderClient
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = HttpProviderClient(sess, provider_name="test")
|
||||||
|
# file:// is rejected by the SSRF guard before any HTTP call.
|
||||||
|
result = await client.request("POST", "file:///etc/passwd", json={})
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "Unsafe URL" in result["error"]
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""Matrix client validation: room_id format and quoting."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
from aioresponses import aioresponses
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||||
|
|
||||||
|
|
||||||
|
HOMESERVER = "https://matrix.example.com"
|
||||||
|
TOKEN = "secret-bearer-token-1234567890"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_path_injection_room_id() -> None:
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = MatrixClient(sess, HOMESERVER, TOKEN)
|
||||||
|
result = await client.send_message("!abc:host/../../etc/passwd", "hi")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "room_id" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_empty_room_id() -> None:
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = MatrixClient(sess, HOMESERVER, TOKEN)
|
||||||
|
result = await client.send_message("", "hi")
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "room_id" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_unicode_control_chars_in_room_id() -> None:
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = MatrixClient(sess, HOMESERVER, TOKEN)
|
||||||
|
result = await client.send_message("!abc\x00:host", "hi")
|
||||||
|
assert result["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_url_encodes_room_id_special_chars() -> None:
|
||||||
|
"""``!`` and ``:`` must reach the server URL-encoded."""
|
||||||
|
captured: list[str] = []
|
||||||
|
|
||||||
|
with aioresponses() as mocked:
|
||||||
|
# Match any PUT under the rooms path; capture the URL we got.
|
||||||
|
mocked.put(
|
||||||
|
"https://matrix.example.com/_matrix/client/v3/rooms/%21abc%3Ahost.example/send/m.room.message",
|
||||||
|
status=200, body='{}', repeat=True,
|
||||||
|
)
|
||||||
|
# aioresponses doesn't expose URL templates well, so use a regex mock.
|
||||||
|
import re
|
||||||
|
mocked.put(
|
||||||
|
re.compile(r"https://matrix\.example\.com/_matrix/client/v3/rooms/[^/]+/send/m\.room\.message/.*"),
|
||||||
|
status=200, body='{}', repeat=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = MatrixClient(sess, HOMESERVER, TOKEN)
|
||||||
|
result = await client.send_message("!abc:host.example", "hi")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redacts_bearer_in_error() -> None:
|
||||||
|
"""A 4xx response body must not echo the Authorization Bearer back to caller."""
|
||||||
|
import re
|
||||||
|
with aioresponses() as mocked:
|
||||||
|
mocked.put(
|
||||||
|
re.compile(r".*"),
|
||||||
|
status=403,
|
||||||
|
body='{"errcode": "M_FORBIDDEN", "Authorization": "Bearer ' + TOKEN + '"}',
|
||||||
|
repeat=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = MatrixClient(sess, HOMESERVER, TOKEN)
|
||||||
|
result = await client.send_message("!abc:host.example", "hi")
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert TOKEN not in result["error"]
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
"""NotificationQueue bound + concurrency regression tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.queue import (
|
||||||
|
DEFAULT_MAX_QUEUE_SIZE,
|
||||||
|
NotificationQueue,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _MemBackend:
|
||||||
|
"""In-memory storage backend stub for tests."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._data: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
async def load(self) -> dict[str, Any] | None:
|
||||||
|
return self._data
|
||||||
|
|
||||||
|
async def save(self, data: dict[str, Any]) -> None:
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
async def remove(self) -> None:
|
||||||
|
self._data = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_with_garbage_falls_back_to_empty() -> None:
|
||||||
|
backend = _MemBackend()
|
||||||
|
backend._data = {"queue": "not-a-list"} # type: ignore[assignment]
|
||||||
|
q = NotificationQueue(backend)
|
||||||
|
await q.async_load()
|
||||||
|
assert q.get_all() == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enqueue_caps_at_max_size() -> None:
|
||||||
|
backend = _MemBackend()
|
||||||
|
q = NotificationQueue(backend, max_size=3)
|
||||||
|
await q.async_load()
|
||||||
|
for i in range(10):
|
||||||
|
await q.async_enqueue({"i": i})
|
||||||
|
items = q.get_all()
|
||||||
|
assert len(items) == 3
|
||||||
|
# FIFO drop: most recent three are kept (i=7..9).
|
||||||
|
assert [it["params"]["i"] for it in items] == [7, 8, 9]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all_returns_deep_copy() -> None:
|
||||||
|
backend = _MemBackend()
|
||||||
|
q = NotificationQueue(backend, max_size=10)
|
||||||
|
await q.async_load()
|
||||||
|
await q.async_enqueue({"key": "value"})
|
||||||
|
snap = q.get_all()
|
||||||
|
snap[0]["params"]["key"] = "MUTATED"
|
||||||
|
snap2 = q.get_all()
|
||||||
|
assert snap2[0]["params"]["key"] == "value"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_enqueue_and_clear_no_corruption() -> None:
|
||||||
|
backend = _MemBackend()
|
||||||
|
q = NotificationQueue(backend, max_size=DEFAULT_MAX_QUEUE_SIZE)
|
||||||
|
await q.async_load()
|
||||||
|
|
||||||
|
async def producer() -> None:
|
||||||
|
for i in range(50):
|
||||||
|
await q.async_enqueue({"i": i})
|
||||||
|
|
||||||
|
async def clearer() -> None:
|
||||||
|
for _ in range(10):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
await q.async_clear()
|
||||||
|
|
||||||
|
await asyncio.gather(producer(), clearer())
|
||||||
|
# No exceptions = no race-induced "dictionary changed size during iteration".
|
||||||
|
items = q.get_all()
|
||||||
|
assert isinstance(items, list)
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
"""Secret-redaction helper regression tests.
|
||||||
|
|
||||||
|
Locks in the patterns that surface from real provider error paths:
|
||||||
|
Telegram bot URLs in aiohttp.ClientError messages, Authorization Bearer
|
||||||
|
tokens in Matrix/ntfy responses, Discord/Slack webhook tokens, URL
|
||||||
|
userinfo, and common ?token= query params.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.redact import redact, redact_exc
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"raw,expected_substr,not_in",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
"Cannot connect to host api.telegram.org/bot1234567:AABBCC-secret-token/sendMessage",
|
||||||
|
"api.telegram.org/bot***",
|
||||||
|
"AABBCC-secret-token",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Authorization: Bearer ey.JhbGciOiJIUzI1NiJ9.payload.sig",
|
||||||
|
"Bearer ***",
|
||||||
|
"ey.JhbGciOiJIUzI1NiJ9",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"POST https://discord.com/api/webhooks/12345/abcdefg-token failed",
|
||||||
|
"discord.com/api/webhooks/12345/***",
|
||||||
|
"abcdefg-token",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"POST https://hooks.slack.com/services/T01/B02/zzzzz failed",
|
||||||
|
"hooks.slack.com/services/T01/B02/***",
|
||||||
|
"zzzzz",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"fetch http://user:supersecret@example.com/foo",
|
||||||
|
"http://***@example.com/foo",
|
||||||
|
"supersecret",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"https://api.example.com/x?token=mytoken123&extra=ok",
|
||||||
|
"token=***",
|
||||||
|
"mytoken123",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_redact_known_secrets(raw: str, expected_substr: str, not_in: str) -> None:
|
||||||
|
out = redact(raw)
|
||||||
|
assert expected_substr in out
|
||||||
|
assert not_in not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_idempotent() -> None:
|
||||||
|
once = redact("Bearer abcdefghij1234567890")
|
||||||
|
twice = redact(once)
|
||||||
|
assert once == twice
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_exc_returns_str() -> None:
|
||||||
|
err = RuntimeError("Bearer abcdefghij1234567890")
|
||||||
|
out = redact_exc(err)
|
||||||
|
assert isinstance(out, str)
|
||||||
|
assert "Bearer ***" in out
|
||||||
|
assert "abcdefghij1234567890" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_redact_handles_non_string() -> None:
|
||||||
|
# Coercion path should not raise.
|
||||||
|
out = redact(12345) # type: ignore[arg-type]
|
||||||
|
assert out == "12345"
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
"""SSRF hardening regression tests.
|
||||||
|
|
||||||
|
Covers cases the original guard missed: IPv4-mapped IPv6, CGNAT,
|
||||||
|
trailing-dot hostnames, IPv6 zone identifiers, and the safe-host repr
|
||||||
|
used in error messages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.ssrf import (
|
||||||
|
UnsafeURLError,
|
||||||
|
PinnedResolver,
|
||||||
|
avalidate_outbound_url_full,
|
||||||
|
validate_outbound_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlockedRanges:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url",
|
||||||
|
[
|
||||||
|
"http://[::ffff:127.0.0.1]/", # IPv4-mapped IPv6 → loopback
|
||||||
|
"http://[::ffff:10.0.0.1]/", # IPv4-mapped IPv6 → RFC1918
|
||||||
|
"http://100.64.0.1/", # CGNAT (RFC 6598)
|
||||||
|
"http://0.0.0.0/", # unspecified
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rejects_extra_ranges(self, url: str) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url(url)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHostnameNormalization:
|
||||||
|
def test_strips_trailing_dot(self) -> None:
|
||||||
|
# ``localhost.`` should normalize to ``localhost`` and still resolve
|
||||||
|
# to the loopback address — and be blocked.
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url("http://localhost./")
|
||||||
|
|
||||||
|
def test_rejects_bad_scheme_uppercase(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url("FILE:///etc/passwd")
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorMessages:
|
||||||
|
def test_error_does_not_leak_long_hosts(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError) as ei:
|
||||||
|
validate_outbound_url("http://" + "a" * 1024 + ".invalid/")
|
||||||
|
# Truncated to 64 chars in the error string.
|
||||||
|
assert len(str(ei.value)) < 256
|
||||||
|
|
||||||
|
|
||||||
|
class TestPinnedResolverSync:
|
||||||
|
def test_pin_returns_pinned_ip(self) -> None:
|
||||||
|
resolver = PinnedResolver({"example.com": "93.184.216.34"})
|
||||||
|
# Just exercise the dict path — full resolve runs in async tests.
|
||||||
|
assert resolver._map["example.com"] == "93.184.216.34" # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncFullValidator:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_resolved_ip(self) -> None:
|
||||||
|
# Literal IP — no DNS lookup; we still get back a ValidatedURL.
|
||||||
|
result = await avalidate_outbound_url_full("http://8.8.8.8/")
|
||||||
|
assert result.ip == "8.8.8.8"
|
||||||
|
assert result.host == "8.8.8.8"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_blocked_literal(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
await avalidate_outbound_url_full("http://[::ffff:127.0.0.1]/")
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Telegram media-group mixed-type partitioning regression test.
|
||||||
|
|
||||||
|
Telegram rejects sendMediaGroup payloads that mix ``document`` with
|
||||||
|
``photo``/``video``. The client must partition before chunking so a
|
||||||
|
mixed input list still delivers all assets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_keeps_photo_video_together() -> None:
|
||||||
|
parts = TelegramClient._partition_media_by_kind([
|
||||||
|
{"type": "photo", "url": "p1"},
|
||||||
|
{"type": "video", "url": "v1"},
|
||||||
|
{"type": "photo", "url": "p2"},
|
||||||
|
])
|
||||||
|
assert len(parts) == 1
|
||||||
|
assert [a["url"] for a in parts[0]] == ["p1", "v1", "p2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_separates_documents_from_media() -> None:
|
||||||
|
parts = TelegramClient._partition_media_by_kind([
|
||||||
|
{"type": "photo", "url": "p1"},
|
||||||
|
{"type": "document", "url": "d1"},
|
||||||
|
{"type": "video", "url": "v1"},
|
||||||
|
])
|
||||||
|
assert len(parts) == 3
|
||||||
|
assert parts[0][0]["url"] == "p1"
|
||||||
|
assert parts[1][0]["url"] == "d1"
|
||||||
|
assert parts[2][0]["url"] == "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_groups_consecutive_documents() -> None:
|
||||||
|
parts = TelegramClient._partition_media_by_kind([
|
||||||
|
{"type": "document", "url": "d1"},
|
||||||
|
{"type": "document", "url": "d2"},
|
||||||
|
{"type": "photo", "url": "p1"},
|
||||||
|
])
|
||||||
|
assert len(parts) == 2
|
||||||
|
assert [a["url"] for a in parts[0]] == ["d1", "d2"]
|
||||||
|
assert parts[1][0]["url"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_empty() -> None:
|
||||||
|
assert TelegramClient._partition_media_by_kind([]) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_defaults_missing_type_to_photo() -> None:
|
||||||
|
"""Items without an explicit type are treated as photos for grouping."""
|
||||||
|
parts = TelegramClient._partition_media_by_kind([
|
||||||
|
{"url": "x"}, # no type
|
||||||
|
{"type": "video", "url": "v"},
|
||||||
|
])
|
||||||
|
assert len(parts) == 1
|
||||||
Reference in New Issue
Block a user