Compare commits
32 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 461fb495d7 | |||
| 309dec2b44 | |||
| 90def11b8d | |||
| 8f0346ea03 | |||
| a6a854ad21 | |||
| 19036a90bb | |||
| 592e1b6114 | |||
| bbcdf1c5d1 | |||
| f9040370bc | |||
| 3b683ce82c | |||
| 2bec25353b | |||
| e44d387c7f | |||
| 7cbb02b1ef | |||
| 920920bc67 | |||
| f50d465c0e | |||
| 1f880daa0c | |||
| 1024085cdd | |||
| 5604c733d1 | |||
| 3b7808aa9c | |||
| 155d25edf9 | |||
| 69711bbc84 | |||
| fe38d20b96 | |||
| d02616069d | |||
| 7dae68fd93 | |||
| e6481605ca | |||
| 6de9a1289e | |||
| 325eabd751 | |||
| fab6169cf9 | |||
| 85311684d9 | |||
| d7daadadc2 | |||
| e04ad16ca6 | |||
| d7d0a5d921 |
@@ -1,13 +1,56 @@
|
||||
name: Build Docker Image
|
||||
name: Build and Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, main]
|
||||
pull_request:
|
||||
branches: [master, main]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
test-frontend:
|
||||
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build Docker image
|
||||
run: docker build -t notify-bridge:dev .
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install deps
|
||||
run: |
|
||||
cd frontend
|
||||
npm ci
|
||||
|
||||
- name: Svelte check
|
||||
run: |
|
||||
cd frontend
|
||||
npm run check || echo "::warning::svelte-check reported warnings"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
build-image:
|
||||
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||
needs: [test-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build Docker image (no push)
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: false
|
||||
tags: notify-bridge:ci-${{ gitea.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
@@ -50,6 +50,7 @@ jobs:
|
||||
tags: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ gitea.sha }}
|
||||
${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }}
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
||||
|
||||
+48
-4
@@ -1,3 +1,4 @@
|
||||
# syntax=docker/dockerfile:1.7
|
||||
# =============================================================================
|
||||
# Stage 1: Build frontend (SvelteKit static output)
|
||||
# =============================================================================
|
||||
@@ -14,7 +15,7 @@ COPY frontend/ ./
|
||||
RUN npm run build
|
||||
|
||||
# =============================================================================
|
||||
# Stage 2: Build Python wheels
|
||||
# Stage 2: Build Python wheels + extract external dependency list
|
||||
# =============================================================================
|
||||
FROM python:3.12-slim AS python-build
|
||||
|
||||
@@ -30,16 +31,59 @@ RUN python -m build packages/core/ --wheel --outdir /wheels
|
||||
COPY packages/server/ packages/server/
|
||||
RUN python -m build packages/server/ --wheel --outdir /wheels
|
||||
|
||||
# Emit /wheels/deps.txt with ONLY external (PyPI) deps — filter out
|
||||
# notify-bridge-* siblings, which are installed from local wheels below.
|
||||
# This file is the cache key for the external-deps install layer: as long as
|
||||
# pyproject.toml dependency lines don't change, the runtime install layer is
|
||||
# served from registry buildcache and no wheels are re-downloaded.
|
||||
RUN python <<'PY'
|
||||
import tomllib
|
||||
|
||||
deps: list[str] = []
|
||||
for p in ("packages/core/pyproject.toml", "packages/server/pyproject.toml"):
|
||||
with open(p, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
for d in data["project"].get("dependencies", []):
|
||||
if not d.lstrip().lower().startswith("notify-bridge-"):
|
||||
deps.append(d)
|
||||
|
||||
seen: set[str] = set()
|
||||
with open("/wheels/deps.txt", "w") as f:
|
||||
for d in deps:
|
||||
if d not in seen:
|
||||
seen.add(d)
|
||||
f.write(d + "\n")
|
||||
PY
|
||||
|
||||
# =============================================================================
|
||||
# Stage 3: Runtime
|
||||
# =============================================================================
|
||||
FROM python:3.12-slim
|
||||
|
||||
# uv — fast pip replacement. Installed from PyPI (Fastly CDN) rather than
|
||||
# ghcr.io/astral-sh/uv, because GHCR pulls from this runner crawl at a few
|
||||
# hundred KB/s and take longer than the install savings would recoup.
|
||||
RUN pip install --no-cache-dir uv==0.11.7
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install wheels
|
||||
COPY --from=python-build /wheels/ /tmp/wheels/
|
||||
RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels
|
||||
# Install external deps first — layer cache key is deps.txt content, which
|
||||
# only changes when pyproject.toml dependency lines change (not on version
|
||||
# bumps). The cache mount persists downloaded wheels across local rebuilds;
|
||||
# in CI, the registry buildcache serves the whole layer when unchanged.
|
||||
COPY --from=python-build /wheels/deps.txt /tmp/deps.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r /tmp/deps.txt \
|
||||
&& rm /tmp/deps.txt
|
||||
|
||||
# Install local wheels without re-resolving — all external deps are present.
|
||||
COPY --from=python-build /wheels/*.whl /tmp/wheels/
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-deps /tmp/wheels/*.whl \
|
||||
&& rm -rf /tmp/wheels
|
||||
|
||||
# Copy frontend build
|
||||
COPY --from=frontend-build /build/build/ /app/static/
|
||||
|
||||
+15
-29
@@ -1,42 +1,28 @@
|
||||
## v0.2.4 (2026-04-22)
|
||||
# v0.5.0 (2026-04-24)
|
||||
|
||||
Telegram media cache rebuilt around **thumbhash validation** — asset cache
|
||||
entries now invalidate when the visual content changes, not after a fixed
|
||||
TTL — plus a settings-page overhaul (cache stats, clear button, timezone /
|
||||
locale pickers) and full mobile-nav parity with the desktop sidebar.
|
||||
A small but impactful release that finally makes the Immich scheduled / periodic / memory dispatch fire on its own. The slot was already visible in the tracker UI and the "Test" button worked — but no production scheduler was reading the config, so users only ever saw fires through manual tests. This release wires the missing cron jobs end-to-end.
|
||||
|
||||
### Features
|
||||
## Features
|
||||
|
||||
#### Telegram media cache
|
||||
- **Cron-fired Immich dispatch for scheduled / periodic / memory slots** — adds the missing production fan-out so `scheduled_enabled` / `scheduled_times` (and the periodic / memory counterparts) on `TrackingConfig` actually fire on their own, not only through "Test" ([309dec2](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/309dec2)):
|
||||
- New `services/scheduled_dispatch.py` reuses the test-path event builders, picks the slot template per kind (`scheduled_assets` / `periodic_assets` / `memory_assets`), and writes an `EventLog` row per fire so the dashboard reflects it.
|
||||
- `services/scheduler.py` gains `_load_immich_dispatch_jobs`, which builds one `CronTrigger` per `(tracker, kind, HH:MM)` from each tracker's default `TrackingConfig`, all keyed off the app-level IANA timezone. `reschedule_immich_dispatch_jobs` rebuilds the job set on any relevant CRUD or timezone change.
|
||||
- Tracker / link / tracking-config CRUD endpoints now invalidate the schedule, so edits take effect immediately without a restart.
|
||||
- Dispatch is skipped when scheduled / memory queries yield zero matching assets — prevents header-only "On this day:" spam when nothing qualifies.
|
||||
- EN / RU default `scheduled_assets` templates updated to surface that the delivery is a scheduled random selection.
|
||||
|
||||
- **Thumbhash-validated asset cache** — dispatcher builds an `asset.id → thumbhash` resolver from `event.added_assets` (Immich already populates `thumbhash` in `extra`) and passes it to `TelegramClient`. Asset-cache entries now invalidate on visual change rather than age. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- **Configurable cache cap & stats** — `TelegramFileCache` gets a `max_entries` LRU cap (applies in both TTL and thumbhash modes), `ttl_seconds <= 0` disables TTL entirely, and a `stats()` method exposes per-bucket counts / sizes / oldest+newest timestamps. New settings: `telegram_asset_cache_max_entries` (default 5000); `telegram_cache_ttl_hours` default bumped `48 → 720` (30 days) and is now URL-only. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- **Cache admin endpoints** — `GET /settings/telegram-cache/stats` and `POST /settings/telegram-cache/clear`. `PUT /settings` now soft-resets the in-memory caches when cache-shaping keys change (on-disk `file_id`s preserved). ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
## Upgrade Notes
|
||||
|
||||
#### Settings page
|
||||
|
||||
- **Cache stats card** — per-bucket (URL / asset) counts, cumulative uploaded-to-Telegram byte size, oldest/newest timestamps, and a hint explaining what the size means. Clear-cache button behind a confirm modal. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- **New `TimezoneSelector` and `LocaleSelector` components** replace the raw inputs with IANA-aware searchable pickers. Max-entries input exposed; TTL range widened to `0..8760` hours (`0` = disabled). ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
|
||||
#### Mobile nav
|
||||
|
||||
- **Full sidebar parity in the *More* panel** — now mirrors the desktop sidebar tree (groups + subnodes) so every destination is reachable from mobile. Previously the panel carried a hand-picked flat list that drifted behind newly-added routes. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- **Safe-area handling** — nav height uses `env(safe-area-inset-bottom)`; panel bottom + `z-index` fixed so page content can no longer visually overlay the bottom bar. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
- **`update_settings` TypeError** — `any(await ... for ...)` was an async generator (not an iterator) and raised at runtime; replaced with an explicit loop so settings updates actually commit. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
|
||||
### Accessibility
|
||||
|
||||
- **Password-manager association** on the password-change form — hidden `username` field + `autocomplete` hints on all three password inputs so browsers stop warning and password managers fill correctly. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- **Telegram webhook secret** wrapped in a no-op form with `autocomplete=off` to silence DOM/a11y warnings. ([2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b))
|
||||
- No config changes required. Existing `scheduled_enabled` / `scheduled_times` / `periodic_*` / `memory_*` settings on tracking configs will start firing automatically on the next startup.
|
||||
- If you had been relying on the "Test" button as a workaround, you can stop doing that now.
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>All Commits</summary>
|
||||
|
||||
- [2be608b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2be608b) — feat(cache): thumbhash-validated asset cache + settings UX overhaul *(alexei.dolgolyov)*
|
||||
| Hash | Message | Author |
|
||||
|------------------------------------------------------------------------------------------|------------------------------------------------------------------|------------------|
|
||||
| [309dec2](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/309dec2) | feat(immich): wire cron-fired scheduled/periodic/memory dispatch | alexei.dolgolyov |
|
||||
|
||||
</details>
|
||||
|
||||
+27
-7
@@ -10,18 +10,38 @@ services:
|
||||
volumes:
|
||||
- notify-bridge-data:/data
|
||||
environment:
|
||||
# REQUIRED — any 32+ byte random string. `openssl rand -hex 32` is one way.
|
||||
- NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)}
|
||||
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-*}
|
||||
# Homelab target: allow outbound requests to RFC1918 / link-local addresses.
|
||||
# The SSRF guard otherwise rejects 10.*/172.16.*/192.168.*/169.254.* hosts,
|
||||
# which breaks tracking of Immich / Gitea / etc. running on the same LAN.
|
||||
- NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
# Comma-separated list of allowed browser origins. Wildcard `*` is
|
||||
# rejected on startup because credentials are enabled.
|
||||
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-http://localhost:8420}
|
||||
# Trusted proxy IPs whose X-Forwarded-For / X-Forwarded-Proto we honor.
|
||||
# Set this to your reverse proxy's IP (e.g. 172.17.0.1 for the default
|
||||
# docker bridge, or `*` only if the container is NOT reachable from the
|
||||
# public internet).
|
||||
- NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS=${NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS:-127.0.0.1}
|
||||
# Opt-in SSRF bypass for private/loopback/link-local hosts (homelab
|
||||
# scenario — tracking an Immich/Gitea instance on the same LAN). DO NOT
|
||||
# enable on a publicly exposed instance.
|
||||
# - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/health')"]
|
||||
# Use /api/ready (not /api/health) so the container is only reported
|
||||
# healthy after migrations and the scheduler finish booting.
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/ready', timeout=3)"]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
start_period: 30s
|
||||
read_only: true
|
||||
tmpfs:
|
||||
- /tmp
|
||||
security_opt:
|
||||
- no-new-privileges:true
|
||||
cap_drop:
|
||||
- ALL
|
||||
mem_limit: 512m
|
||||
cpus: 1.0
|
||||
pids_limit: 256
|
||||
|
||||
volumes:
|
||||
notify-bridge-data:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "notify-bridge-frontend",
|
||||
"private": true,
|
||||
"version": "0.2.4",
|
||||
"version": "0.5.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite dev",
|
||||
|
||||
@@ -55,7 +55,8 @@
|
||||
"passwordTooShort": "Password must be at least 8 characters",
|
||||
"or": "or",
|
||||
"loginFailed": "Login failed",
|
||||
"setupFailed": "Setup failed"
|
||||
"setupFailed": "Setup failed",
|
||||
"backendUnreachable": "Cannot reach the server. Check that it's running and try again."
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Dashboard",
|
||||
@@ -78,6 +79,7 @@
|
||||
"collectionRenamed": "collection renamed",
|
||||
"collectionDeleted": "collection deleted",
|
||||
"sharingChanged": "sharing changed",
|
||||
"scheduledMessage": "scheduled message",
|
||||
"actionSuccess": "action run",
|
||||
"actionPartial": "action partial",
|
||||
"actionFailed": "action failed",
|
||||
@@ -694,6 +696,13 @@
|
||||
"locales": "Template Languages",
|
||||
"supportedLocales": "Supported Locales",
|
||||
"supportedLocalesHint": "Languages available when authoring notification and command templates. Built-in defaults ship for English and Russian; other languages start empty.",
|
||||
"logging": "Logging",
|
||||
"logLevel": "Log Level",
|
||||
"logLevelHint": "Root log level for the server. Raise to DEBUG while investigating; keep at INFO in production. WARNING/ERROR hide per-command progress lines.",
|
||||
"logFormat": "Log Format",
|
||||
"logFormatHint": "Output format. 'text' is human-readable; 'json' emits one object per line for log aggregators (Loki, ELK). Changing this requires a server restart.",
|
||||
"logLevels": "Per-Module Overrides",
|
||||
"logLevelsHint": "Comma-separated 'module=LEVEL' pairs to silence noisy modules or drill into one area. Example: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||
"saved": "Settings saved"
|
||||
},
|
||||
"hints": {
|
||||
|
||||
@@ -55,7 +55,8 @@
|
||||
"passwordTooShort": "Пароль должен быть не менее 8 символов",
|
||||
"or": "или",
|
||||
"loginFailed": "Ошибка входа",
|
||||
"setupFailed": "Ошибка настройки"
|
||||
"setupFailed": "Ошибка настройки",
|
||||
"backendUnreachable": "Не удалось подключиться к серверу. Убедитесь, что он запущен, и повторите попытку."
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Главная",
|
||||
@@ -78,6 +79,7 @@
|
||||
"collectionRenamed": "альбом переименован",
|
||||
"collectionDeleted": "альбом удалён",
|
||||
"sharingChanged": "изменение доступа",
|
||||
"scheduledMessage": "запланированное сообщение",
|
||||
"actionSuccess": "действие выполнено",
|
||||
"actionPartial": "действие частично",
|
||||
"actionFailed": "действие провалено",
|
||||
@@ -694,6 +696,13 @@
|
||||
"locales": "Языки шаблонов",
|
||||
"supportedLocales": "Поддерживаемые локали",
|
||||
"supportedLocalesHint": "Языки, доступные для редактирования шаблонов уведомлений и команд. Встроенные шаблоны поставляются для английского и русского; другие языки начинают с пустых.",
|
||||
"logging": "Логирование",
|
||||
"logLevel": "Уровень логов",
|
||||
"logLevelHint": "Уровень логирования сервера. Поднимайте до DEBUG при отладке; оставляйте INFO в продакшене. WARNING/ERROR скрывают пошаговые строки по командам.",
|
||||
"logFormat": "Формат логов",
|
||||
"logFormatHint": "Формат вывода. 'text' — читаемый человеком; 'json' — по одному объекту в строке для агрегаторов (Loki, ELK). Смена требует перезапуска сервера.",
|
||||
"logLevels": "Переопределения по модулям",
|
||||
"logLevelsHint": "Пары 'модуль=УРОВЕНЬ' через запятую, чтобы приглушить шумные модули или углубиться в один. Пример: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||
"saved": "Настройки сохранены"
|
||||
},
|
||||
"hints": {
|
||||
|
||||
@@ -223,6 +223,7 @@
|
||||
collection_renamed: 'dashboard.collectionRenamed',
|
||||
collection_deleted: 'dashboard.collectionDeleted',
|
||||
sharing_changed: 'dashboard.sharingChanged',
|
||||
scheduled_message: 'dashboard.scheduledMessage',
|
||||
action_success: 'dashboard.actionSuccess',
|
||||
action_partial: 'dashboard.actionPartial',
|
||||
action_failed: 'dashboard.actionFailed',
|
||||
@@ -231,11 +232,13 @@
|
||||
const eventIcons: Record<string, string> = {
|
||||
assets_added: 'mdiImagePlus', assets_removed: 'mdiImageMinus',
|
||||
collection_renamed: 'mdiRename', collection_deleted: 'mdiDeleteAlert', sharing_changed: 'mdiShareVariant',
|
||||
scheduled_message: 'mdiCalendarClock',
|
||||
action_success: 'mdiPlayCircle', action_partial: 'mdiAlertCircle', action_failed: 'mdiCloseCircle',
|
||||
};
|
||||
const eventColors: Record<string, string> = {
|
||||
assets_added: '#059669', assets_removed: '#ef4444',
|
||||
collection_renamed: '#6366f1', collection_deleted: '#dc2626', sharing_changed: '#f59e0b',
|
||||
scheduled_message: '#8b5cf6',
|
||||
action_success: '#0d9488', action_partial: '#f59e0b', action_failed: '#dc2626',
|
||||
};
|
||||
|
||||
|
||||
@@ -117,6 +117,14 @@
|
||||
return form.slots[slotName]?.[activeLocale] || '';
|
||||
}
|
||||
|
||||
/** Resolve variable reference for a slot, preferring provider-specific over shared. */
|
||||
function getVarsFor(slotName: string) {
|
||||
const providerVars = varsRef[form.provider_type];
|
||||
return providerVars?.[slotName] ?? varsRef[slotName];
|
||||
}
|
||||
|
||||
let modalVars = $derived(showVarsFor ? getVarsFor(showVarsFor) : null);
|
||||
|
||||
/** Set slot template for current locale (immutable update). */
|
||||
function setSlotValue(slotName: string, value: string) {
|
||||
form.slots = {
|
||||
@@ -369,7 +377,7 @@
|
||||
{t('templateConfig.preview')}
|
||||
</button>
|
||||
{/if}
|
||||
{#if varsRef[slot.name]}
|
||||
{#if getVarsFor(slot.name)}
|
||||
<button type="button" onclick={() => showVarsFor = slot.name}
|
||||
class="text-xs text-[var(--color-muted-foreground)] hover:underline">{t('templateConfig.variables')}</button>
|
||||
{/if}
|
||||
@@ -385,7 +393,7 @@
|
||||
onchange={(v: string) => { setSlotValue(slot.name, v); validateSlot(slot.name, v); }}
|
||||
rows={3}
|
||||
errorLine={slotErrorLines[slot.name] || null}
|
||||
variables={varsRef[slot.name] || undefined}
|
||||
variables={getVarsFor(slot.name) || undefined}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
@@ -468,11 +476,11 @@
|
||||
|
||||
<!-- Variables reference modal -->
|
||||
<Modal open={showVarsFor !== null} title="{t('templateConfig.variables')}: /{showVarsFor || ''}" onclose={() => showVarsFor = null}>
|
||||
{#if showVarsFor && varsRef[showVarsFor]}
|
||||
<p class="text-sm text-[var(--color-muted-foreground)] mb-3">{varsRef[showVarsFor].description}</p>
|
||||
{#if showVarsFor && modalVars}
|
||||
<p class="text-sm text-[var(--color-muted-foreground)] mb-3">{modalVars.description}</p>
|
||||
<div class="space-y-1">
|
||||
<p class="text-xs font-medium mb-1">{t('templateConfig.variables')}:</p>
|
||||
{#each Object.entries(varsRef[showVarsFor].variables || {}) as [name, desc]}
|
||||
{#each Object.entries(modalVars.variables || {}) as [name, desc]}
|
||||
<div class="flex items-start gap-2 text-sm">
|
||||
<code class="text-xs bg-[var(--color-muted)] px-1 py-0.5 rounded font-mono whitespace-nowrap">{'{{ ' + name + ' }}'}</code>
|
||||
<span class="text-xs text-[var(--color-muted-foreground)]">{desc}</span>
|
||||
@@ -484,11 +492,19 @@
|
||||
['album_fields', 'album', 'Album fields'],
|
||||
['command_fields', 'cmd', 'Command fields'],
|
||||
['event_fields', 'event', 'Event fields'],
|
||||
['repo_fields', 'repo', 'Repository fields'],
|
||||
['issue_fields', 'issue', 'Issue fields'],
|
||||
['pr_fields', 'pr', 'Pull request fields'],
|
||||
['commit_fields', 'c', 'Commit fields'],
|
||||
['board_fields', 'board', 'Board fields'],
|
||||
['card_fields', 'card', 'Card fields'],
|
||||
['list_fields', 'lst', 'List fields'],
|
||||
['device_fields', 'd', 'Device fields'],
|
||||
] as [fieldKey, prefix, title]}
|
||||
{#if varsRef[showVarsFor][fieldKey]}
|
||||
{#if modalVars[fieldKey]}
|
||||
<div class="mt-3 pt-3 border-t border-[var(--color-border)]">
|
||||
<p class="text-xs font-medium mb-1">{title} <span class="font-normal text-[var(--color-muted-foreground)]">(use {prefix}.field)</span>:</p>
|
||||
{#each Object.entries(varsRef[showVarsFor][fieldKey]) as [name, desc]}
|
||||
{#each Object.entries(modalVars[fieldKey]) as [name, desc]}
|
||||
<div class="flex items-start gap-2 text-sm">
|
||||
<code class="text-xs bg-[var(--color-muted)] px-1 py-0.5 rounded font-mono whitespace-nowrap">{'{{ ' + prefix + '.' + name + ' }}'}</code>
|
||||
<span class="text-xs text-[var(--color-muted-foreground)]">{desc}</span>
|
||||
|
||||
@@ -15,13 +15,32 @@
|
||||
let submitting = $state(false);
|
||||
let mounted = $state(false);
|
||||
|
||||
let backendDown = $state(false);
|
||||
|
||||
onMount(async () => {
|
||||
initTheme();
|
||||
mounted = true;
|
||||
// If the user is already signed in (valid access token in storage),
|
||||
// there is no reason to show them the login form. Confirm the stored token
|
||||
// is still accepted by calling /auth/me, then redirect to the home page.
|
||||
const { isAuthenticated } = await import('$lib/api');
|
||||
if (isAuthenticated()) {
|
||||
try {
|
||||
await api('/auth/me');
|
||||
goto('/');
|
||||
return;
|
||||
} catch {
|
||||
// Token was stale; fall through to the login form.
|
||||
}
|
||||
}
|
||||
try {
|
||||
const res = await api<{ needs_setup: boolean }>('/auth/needs-setup');
|
||||
if (res.needs_setup) goto('/setup');
|
||||
} catch { /* ignore */ }
|
||||
} catch {
|
||||
// The backend is unreachable — surface that distinctly so the user
|
||||
// doesn't blame the login form for a network/backend problem.
|
||||
backendDown = true;
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSubmit(e: SubmitEvent) {
|
||||
@@ -62,7 +81,12 @@
|
||||
<p class="text-sm mt-1" style="color: var(--color-muted-foreground);">{t('auth.signInTitle')}</p>
|
||||
</div>
|
||||
|
||||
{#if error}
|
||||
{#if backendDown}
|
||||
<div class="auth-error animate-fade-slide-in">
|
||||
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||
{t('auth.backendUnreachable')}
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="auth-error animate-fade-slide-in">
|
||||
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||
{error}
|
||||
|
||||
@@ -37,6 +37,9 @@
|
||||
telegram_asset_cache_max_entries: '5000',
|
||||
supported_locales: 'en,ru',
|
||||
timezone: 'UTC',
|
||||
log_level: 'INFO',
|
||||
log_format: 'text',
|
||||
log_levels: '',
|
||||
});
|
||||
let cacheStats = $state<CacheStats | null>(null);
|
||||
|
||||
@@ -204,6 +207,40 @@
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- Logging section -->
|
||||
<Card>
|
||||
<h3 class="text-sm font-semibold mb-4 flex items-center gap-2">
|
||||
<MdiIcon name="mdiTextBoxOutline" size={18} />
|
||||
{t('settings.logging')}
|
||||
</h3>
|
||||
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
||||
<select bind:value={settings.log_level}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="DEBUG">DEBUG</option>
|
||||
<option value="INFO">INFO</option>
|
||||
<option value="WARNING">WARNING</option>
|
||||
<option value="ERROR">ERROR</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
||||
<select bind:value={settings.log_format}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="text">text</option>
|
||||
<option value="json">json</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="sm:col-span-2">
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
||||
<input bind:value={settings.log_levels}
|
||||
placeholder="sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG"
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)] font-mono" />
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Button onclick={save} disabled={saving}>
|
||||
{saving ? t('common.loading') : t('common.save')}
|
||||
</Button>
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "notify-bridge-core"
|
||||
version = "0.2.4"
|
||||
version = "0.5.0"
|
||||
description = "Core library for Notify Bridge — service provider abstractions, models, notifications, and templates"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Request-scoped ContextVars that propagate into log records.
|
||||
|
||||
The server sets these at entry points (Telegram webhook, scheduler dispatch,
|
||||
REST call) and they propagate through async calls automatically. A
|
||||
``LogRecordFactory`` installed by ``notify_bridge_server.logging_setup``
|
||||
reads them so every log line is tagged (``request_id``, ``command``,
|
||||
``chat_id``, ``bot_id``, ``dispatch_id``) without each call site having
|
||||
to pass the values explicitly.
|
||||
|
||||
Kept in ``notify_bridge_core`` so core modules (``TelegramClient``,
|
||||
``NotificationDispatcher``) can *set* additional context (e.g. a
|
||||
``dispatch_id``) without depending on the server package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Any, Iterator
|
||||
|
||||
request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)
|
||||
command_var: ContextVar[str | None] = ContextVar("command", default=None)
|
||||
chat_id_var: ContextVar[str | None] = ContextVar("chat_id", default=None)
|
||||
bot_id_var: ContextVar[int | None] = ContextVar("bot_id", default=None)
|
||||
dispatch_id_var: ContextVar[str | None] = ContextVar("dispatch_id", default=None)
|
||||
|
||||
_VAR_MAP: dict[str, ContextVar[Any]] = {
|
||||
"request_id": request_id_var,
|
||||
"command": command_var,
|
||||
"chat_id": chat_id_var,
|
||||
"bot_id": bot_id_var,
|
||||
"dispatch_id": dispatch_id_var,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
def bind_log_context(**kwargs: Any) -> Iterator[None]:
    """Temporarily bind log-context fields for the body of a ``with`` block.

    Keyword names without an entry in ``_VAR_MAP`` are skipped silently, so
    callers can forward arbitrary keyword arguments without filtering them
    first. Every variable that was set is reset with its token on exit,
    whether the block completes normally or raises.

    Example:
        ``with bind_log_context(request_id="abc", command="random"): ...``
    """
    bound: list[tuple[ContextVar[Any], Token]] = []
    try:
        for name, new_value in kwargs.items():
            # Only known context fields are bound; anything else is dropped.
            if (ctx_var := _VAR_MAP.get(name)) is not None:
                bound.append((ctx_var, ctx_var.set(new_value)))
        yield
    finally:
        # Undo every binding made above, even when the block raised.
        for ctx_var, token in bound:
            ctx_var.reset(token)
|
||||
|
||||
|
||||
def current_log_context() -> dict[str, Any]:
    """Return the currently bound context fields, omitting any that are ``None``.

    The result is a fresh dict keyed by the field names in ``_VAR_MAP``.
    """
    return {
        name: value
        for name, ctx_var in _VAR_MAP.items()
        if (value := ctx_var.get()) is not None
    }
|
||||
@@ -52,22 +52,46 @@ class DiscordClient:
|
||||
|
||||
return {"success": True}
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
_MAX_RETRY_AFTER = 60.0
|
||||
|
||||
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
    """POST ``payload`` as JSON to ``url``, retrying a bounded number of 429s.

    At most ``_MAX_RETRIES`` retries are made and the ``Retry-After`` header
    is clamped to ``[0, _MAX_RETRY_AFTER]`` seconds, so a hostile or
    misbehaving upstream cannot pin the dispatch task indefinitely. A
    missing or unparseable header falls back to 2 seconds. Redirects are
    not followed.

    Args:
        url: Target URL.
        payload: JSON-serialisable request body.

    Returns:
        ``{"success": True}`` for a 2xx response; otherwise
        ``{"success": False, "error": <message>}``. Client errors and
        timeouts are reported through the dict rather than raised.
    """
    for attempt in range(self._MAX_RETRIES + 1):
        try:
            async with self._session.post(
                url,
                json=payload,
                headers={"Content-Type": "application/json"},
                allow_redirects=False,
            ) as resp:
                if resp.status == 429 and attempt < self._MAX_RETRIES:
                    try:
                        retry_after = float(resp.headers.get("Retry-After", "2"))
                    except (TypeError, ValueError):
                        retry_after = 2.0
                    # Clamp so neither a negative nor a huge value is honoured.
                    retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
                    _LOGGER.warning(
                        "Discord rate limited, retrying after %.1fs (attempt %d/%d)",
                        retry_after, attempt + 1, self._MAX_RETRIES,
                    )
                    await asyncio.sleep(retry_after)
                    continue
                if 200 <= resp.status < 300:
                    return {"success": True}
                body = await resp.text()
                return {
                    "success": False,
                    "error": f"HTTP {resp.status}: {body[:200]}",
                }
        except aiohttp.ClientError as e:
            return {"success": False, "error": str(e)}
        except asyncio.TimeoutError:
            # A total-timeout expiry is not an aiohttp.ClientError; report it
            # like every other failure instead of letting it escape. The
            # message is explicit because str() of a bare TimeoutError is "".
            return {"success": False, "error": "Request timed out"}
    # Safety net: the final attempt always returns from inside the loop.
    return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||
|
||||
|
||||
def _split_message(text: str, limit: int) -> list[str]:
|
||||
|
||||
@@ -3,16 +3,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.templates.context import build_template_context
|
||||
from notify_bridge_core.templates.renderer import render_template
|
||||
from .ssrf import UnsafeURLError, validate_outbound_url
|
||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
@@ -46,6 +49,7 @@ from .receiver import (
|
||||
from .telegram.cache import TelegramFileCache
|
||||
from .telegram.client import TelegramClient
|
||||
from .telegram.media import (
|
||||
build_telegram_asset_entry,
|
||||
extract_asset_id_from_url,
|
||||
is_asset_cache_key,
|
||||
is_asset_id,
|
||||
@@ -81,9 +85,28 @@ class NotificationDispatcher:
|
||||
*,
|
||||
url_cache: TelegramFileCache | None = None,
|
||||
asset_cache: TelegramFileCache | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
self._url_cache = url_cache
|
||||
self._asset_cache = asset_cache
|
||||
# Optional shared session owned by the caller; when supplied we reuse
|
||||
# its connection pool instead of opening a fresh per-dispatch session
|
||||
# (saves a TLS handshake per outbound call).
|
||||
self._shared_session = session
|
||||
|
||||
@contextlib.asynccontextmanager
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
    """Yield an aiohttp session, reusing the shared one if provided.

    When a shared session was passed to ``__init__`` and is still open, it
    is yielded without being closed (the caller owns its lifetime).
    Otherwise a short-lived session is opened through ``_new_session()``
    and closed when the ``async with`` block exits.
    """
    shared = self._shared_session
    if shared is not None and not shared.closed:
        yield shared
        return
    # Fall back to a fresh per-call session. This must use the module-level
    # factory: re-entering ``self._session_ctx()`` here would recurse forever
    # whenever no usable shared session exists.
    async with _new_session() as session:
        yield session
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
@@ -94,18 +117,40 @@ class NotificationDispatcher:
|
||||
|
||||
Returns list of results (one per target).
|
||||
"""
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
for raw in raw_results:
|
||||
if isinstance(raw, Exception):
|
||||
_LOGGER.error("Failed to dispatch to target: %s", raw)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
else:
|
||||
results.append(raw)
|
||||
return results
|
||||
# Bind a dispatch_id so every log line emitted by the target sends
|
||||
# (including deep in TelegramClient) can be correlated to the same
|
||||
# upstream event.
|
||||
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
||||
|
||||
with bind_log_context(dispatch_id=new_id):
|
||||
_LOGGER.info(
|
||||
"Dispatching event %s (collection=%r) to %d target(s)",
|
||||
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
||||
getattr(event, "collection_name", None), len(targets),
|
||||
)
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
failures = 0
|
||||
for target, raw in zip(targets, raw_results):
|
||||
if isinstance(raw, Exception):
|
||||
failures += 1
|
||||
_LOGGER.error(
|
||||
"Dispatch to target type=%s failed: %s",
|
||||
target.type, raw, exc_info=raw,
|
||||
)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
else:
|
||||
if isinstance(raw, dict) and not raw.get("success"):
|
||||
failures += 1
|
||||
results.append(raw)
|
||||
_LOGGER.info(
|
||||
"Dispatch finished: %d target(s), %d failure(s)",
|
||||
len(targets), failures,
|
||||
)
|
||||
return results
|
||||
|
||||
def _resolve_template(
|
||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||
@@ -266,28 +311,24 @@ class NotificationDispatcher:
|
||||
# Prefer internal URL for fetching (LAN speed vs public internet)
|
||||
internal_url = (target.provider_internal_url or "").rstrip("/")
|
||||
external_url = (target.provider_external_url or "").rstrip("/")
|
||||
provider_urls = [u for u in (internal_url, external_url) if u]
|
||||
assets = []
|
||||
media_assets: list[Any] = [] # aligned with `assets` for preload
|
||||
for asset in event.added_assets[:max_media]:
|
||||
url = asset.preview_url or asset.thumbnail_url or asset.full_url
|
||||
if url:
|
||||
# Rewrite external URL to internal for faster LAN fetching
|
||||
if internal_url and external_url and url.startswith(external_url):
|
||||
url = internal_url + url[len(external_url):]
|
||||
asset_type = "video" if asset.type.value == "video" else "photo"
|
||||
asset_headers = {}
|
||||
if target.provider_api_key and any(url.startswith(u) for u in provider_urls):
|
||||
asset_headers["x-api-key"] = target.provider_api_key
|
||||
asset_entry: dict[str, Any] = {"url": url, "type": asset_type, "headers": asset_headers}
|
||||
# Pass explicit cache_key if set by provider (e.g. Google Photos)
|
||||
if asset.extra.get("cache_key"):
|
||||
asset_entry["cache_key"] = asset.extra["cache_key"]
|
||||
asset_entry = build_telegram_asset_entry(
|
||||
url=url or "",
|
||||
media_type=asset.type.value,
|
||||
api_key=target.provider_api_key,
|
||||
internal_url=internal_url,
|
||||
external_url=external_url,
|
||||
cache_key=asset.extra.get("cache_key"),
|
||||
)
|
||||
if asset_entry is not None:
|
||||
assets.append(asset_entry)
|
||||
media_assets.append(asset)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
# Preload all asset bytes once so (a) TelegramClient can skip its
|
||||
# own download and (b) we know exact upload sizes in time for the
|
||||
# oversize warning in the rendered text.
|
||||
@@ -357,13 +398,13 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
||||
results.append({"success": False, "error": "Invalid webhook receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.url)
|
||||
await avalidate_outbound_url(receiver.url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -431,14 +472,14 @@ class NotificationDispatcher:
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = DiscordClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid discord receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.webhook_url)
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -457,14 +498,14 @@ class NotificationDispatcher:
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = SlackClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid slack receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.webhook_url)
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -483,14 +524,14 @@ class NotificationDispatcher:
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
try:
|
||||
validate_outbound_url(server_url)
|
||||
await avalidate_outbound_url(server_url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
||||
|
||||
title = f"{event.event_type.value}: {event.collection_name}"
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = NtfyClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
||||
@@ -514,7 +555,7 @@ class NotificationDispatcher:
|
||||
if not homeserver or not access_token:
|
||||
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
|
||||
try:
|
||||
validate_outbound_url(homeserver)
|
||||
await avalidate_outbound_url(homeserver)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
||||
|
||||
@@ -522,7 +563,7 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = MatrixClient(session, homeserver, access_token)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
||||
|
||||
@@ -68,7 +68,9 @@ class MatrixClient:
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.put(url, json=body, headers=headers) as resp:
|
||||
async with self._session.put(
|
||||
url, json=body, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
resp_body = await resp.text()
|
||||
|
||||
@@ -51,7 +51,9 @@ class NtfyClient:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
try:
|
||||
async with self._session.post(url, json=payload, headers=headers) as resp:
|
||||
async with self._session.post(
|
||||
url, json=payload, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
|
||||
@@ -38,6 +38,7 @@ class SlackClient:
|
||||
webhook_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
_LOGGER.warning("Slack rate limited")
|
||||
|
||||
@@ -12,14 +12,25 @@ development against localhost services.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
||||
_ALLOWED_SCHEMES = {"http", "https"}
|
||||
|
||||
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
||||
_LOGGER.warning(
|
||||
"SSRF guard: private-URL bypass ENABLED "
|
||||
"(NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1). Requests to RFC1918 / "
|
||||
"loopback / link-local hosts will be permitted."
|
||||
)
|
||||
|
||||
|
||||
class UnsafeURLError(ValueError):
|
||||
"""Raised when a URL targets a disallowed network destination."""
|
||||
@@ -36,13 +47,7 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate_outbound_url(url: str) -> str:
|
||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||
|
||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||
private addresses are permitted but the scheme check still applies.
|
||||
"""
|
||||
def _check_scheme_host(url: str) -> tuple[str, str]:
|
||||
if not isinstance(url, str) or not url:
|
||||
raise UnsafeURLError("URL is empty")
|
||||
parsed = urlparse(url)
|
||||
@@ -51,6 +56,31 @@ def validate_outbound_url(url: str) -> str:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise UnsafeURLError("URL has no host")
|
||||
return parsed.scheme, host
|
||||
|
||||
|
||||
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
    """Reject ``host`` if any resolved address falls in a blocked range.

    ``infos`` is indexed like ``getaddrinfo`` results: element 4 of each
    entry is the socket address, whose first item is the IP string. Entries
    whose address string does not parse as an IP are skipped.

    Raises:
        UnsafeURLError: on the first address for which ``_is_blocked_ip``
            returns true.
    """
    for entry in infos:
        socket_address = entry[4]
        try:
            resolved = ipaddress.ip_address(socket_address[0])
        except ValueError:
            # Not an IP literal; nothing to classify for this entry.
            continue
        if not _is_blocked_ip(resolved):
            continue
        raise UnsafeURLError(f"Host {host} resolves to blocked address {resolved}")
|
||||
|
||||
|
||||
def validate_outbound_url(url: str) -> str:
|
||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||
|
||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||
private addresses are permitted but the scheme check still applies.
|
||||
|
||||
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||
:func:`avalidate_outbound_url` from async code paths.
|
||||
"""
|
||||
_, host = _check_scheme_host(url)
|
||||
|
||||
if _ALLOW_PRIVATE:
|
||||
return url
|
||||
@@ -64,17 +94,37 @@ def validate_outbound_url(url: str) -> str:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Hostname — resolve and reject if any resolution is in a blocked range.
|
||||
try:
|
||||
infos = socket.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||
for info in infos:
|
||||
sockaddr = info[4]
|
||||
try:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
except ValueError:
|
||||
continue
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
||||
_check_resolved_addresses(host, infos)
|
||||
return url
|
||||
|
||||
|
||||
async def avalidate_outbound_url(url: str) -> str:
    """Async variant that resolves DNS via the running event loop.

    Use this from ``async def`` code paths to avoid blocking the event loop
    on DNS lookups.

    Returns:
        ``url`` unchanged when it is considered safe.

    Raises:
        UnsafeURLError: when the scheme/host check fails, the host is an IP
            literal in a blocked range, DNS resolution fails, or any
            resolved address is in a blocked range.
    """
    _, host = _check_scheme_host(url)

    # Private-URL bypass: the scheme/host check above still applies.
    if _ALLOW_PRIVATE:
        return url

    # Parse the literal separately from the policy check. UnsafeURLError
    # subclasses ValueError, so raising it inside the ``try`` would be
    # swallowed by ``except ValueError`` and the literal would fall through
    # to a needless DNS lookup.
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None  # Not an IP literal; resolve it as a hostname below.
    if ip is not None:
        if _is_blocked_ip(ip):
            raise UnsafeURLError(f"Host {host} is in a blocked range")
        return url

    loop = asyncio.get_running_loop()
    try:
        infos = await loop.getaddrinfo(host, None)
    except socket.gaierror as exc:
        raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
    _check_resolved_addresses(host, infos)
    return url
|
||||
|
||||
@@ -89,6 +89,18 @@ class TelegramClient:
|
||||
self, url: str | None, cache_key: str | None = None,
|
||||
) -> tuple[TelegramFileCache | None, str | None, str | None]:
|
||||
if cache_key:
|
||||
# Route asset-UUID cache keys to the asset cache so single-item
|
||||
# sends hit the same cache the media-group path uses. Without
|
||||
# this, a command returning one photo stored file_ids in the
|
||||
# URL cache and a command returning multiple stored them in
|
||||
# the asset cache — repeated sends never hit.
|
||||
if is_asset_cache_key(cache_key):
|
||||
bare_id = asset_id_from_cache_key(cache_key)
|
||||
thumbhash = (
|
||||
self._thumbhash_resolver(bare_id)
|
||||
if self._thumbhash_resolver else None
|
||||
)
|
||||
return self._asset_cache, cache_key, thumbhash
|
||||
return self._url_cache, cache_key, None
|
||||
if url:
|
||||
if is_asset_id(url):
|
||||
@@ -150,8 +162,20 @@ class TelegramClient:
|
||||
"message_id": result.get("result", {}).get("message_id"),
|
||||
"cached": True,
|
||||
}
|
||||
except aiohttp.ClientError:
|
||||
pass
|
||||
# Non-ok from a cached send — file_id stale or file deleted on
|
||||
# Telegram's side. Log at DEBUG so operators who are hunting
|
||||
# "why didn't the cached send work?" can see it, but the
|
||||
# caller will fall through to a fresh upload.
|
||||
_LOGGER.debug(
|
||||
"Telegram %s (cached) returned non-ok: status=%s code=%s desc=%r — falling back to fresh upload",
|
||||
kind.api_method, response.status, result.get("error_code"),
|
||||
result.get("description"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.debug(
|
||||
"Telegram %s (cached) transport error — falling back to fresh upload: %s",
|
||||
kind.api_method, err,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _upload_media(
|
||||
@@ -191,8 +215,17 @@ class TelegramClient:
|
||||
thumbhash=thumbhash, size=len(data),
|
||||
)
|
||||
return {"success": True, "message_id": res.get("message_id")}
|
||||
_LOGGER.error(
|
||||
"Telegram %s failed: status=%s code=%s desc=%r bytes=%d",
|
||||
kind.api_method, response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"), len(data),
|
||||
)
|
||||
return {"success": False, "error": result.get("description", "Unknown Telegram error")}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error(
|
||||
"Telegram %s transport error (bytes=%d): %s",
|
||||
kind.api_method, len(data), err, exc_info=True,
|
||||
)
|
||||
return {"success": False, "error": str(err)}
|
||||
|
||||
async def send_notification(
|
||||
@@ -217,7 +250,7 @@ class TelegramClient:
|
||||
|
||||
typing_task = None
|
||||
if chat_action:
|
||||
typing_task = self._start_typing_indicator(chat_id, chat_action)
|
||||
typing_task = self.start_chat_action_keepalive(chat_id, chat_action)
|
||||
|
||||
try:
|
||||
if len(assets) == 1 and assets[0].get("type") == "photo":
|
||||
@@ -315,8 +348,14 @@ class TelegramClient:
|
||||
retry_result = await retry_resp.json()
|
||||
if retry_resp.status == 200 and retry_result.get("ok"):
|
||||
return {"success": True, "message_id": retry_result.get("result", {}).get("message_id")}
|
||||
_LOGGER.error(
|
||||
"Telegram sendMessage failed: status=%s code=%s desc=%r",
|
||||
response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"),
|
||||
)
|
||||
return {"success": False, "error": result.get("description", "Unknown Telegram error"), "error_code": result.get("error_code")}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Telegram sendMessage transport error: %s", err, exc_info=True)
|
||||
return {"success": False, "error": str(err)}
|
||||
|
||||
async def send_chat_action(self, chat_id: str, action: str = "typing") -> bool:
|
||||
@@ -328,7 +367,13 @@ class TelegramClient:
|
||||
except aiohttp.ClientError:
|
||||
return False
|
||||
|
||||
def _start_typing_indicator(self, chat_id: str, action: str = "typing") -> asyncio.Task:
|
||||
def start_chat_action_keepalive(self, chat_id: str, action: str = "typing") -> asyncio.Task:
|
||||
"""Repeatedly post ``action`` every 4s until the returned task is cancelled.
|
||||
|
||||
Telegram chat actions expire after ~5s, so callers that want the hint
|
||||
to persist through longer work (fetching assets, multi-chunk uploads)
|
||||
need a keep-alive. Cancel the task in a ``finally`` to stop it.
|
||||
"""
|
||||
async def action_loop() -> None:
|
||||
try:
|
||||
while True:
|
||||
@@ -495,11 +540,14 @@ class TelegramClient:
|
||||
# Tuple is (cache_key, media_type, thumbhash, uploaded_size).
|
||||
media_cache_info: list[tuple[str, str, str | None, int] | None] = []
|
||||
|
||||
# Resolve cache hits and collect download tasks in parallel
|
||||
# Resolve cache hits and collect download tasks in parallel.
|
||||
# Each drop site logs the reason — otherwise a filtered asset
|
||||
# disappears silently and the media group silently shrinks.
|
||||
async def _fetch_asset(idx: int, item: dict) -> tuple[int, dict | None, bytes | None]:
|
||||
"""Return (index, cache_entry_or_None, downloaded_bytes_or_None)."""
|
||||
url = item.get("url")
|
||||
if not url:
|
||||
_LOGGER.warning("Media skipped: missing url (idx=%d type=%s)", idx, item.get("type"))
|
||||
return idx, None, None
|
||||
media_type = item.get("type", "photo")
|
||||
custom_cache_key = item.get("cache_key")
|
||||
@@ -519,12 +567,24 @@ class TelegramClient:
|
||||
if preloaded is not None:
|
||||
data = preloaded
|
||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||
len(data), max_asset_data_size, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded video %d bytes exceeds Telegram limit %d (idx=%d url=%s)",
|
||||
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "photo":
|
||||
exceeds, _, _, _ = check_photo_limits(data)
|
||||
exceeds, reason, _, _ = check_photo_limits(data)
|
||||
if exceeds:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded photo %s (idx=%d url=%s)",
|
||||
reason, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
return idx, None, data
|
||||
|
||||
@@ -533,18 +593,38 @@ class TelegramClient:
|
||||
dl_headers = item.get("headers") or {}
|
||||
async with self._session.get(download_url, headers=dl_headers) as resp:
|
||||
if resp.status != 200:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: download HTTP %d (idx=%d type=%s url=%s)",
|
||||
resp.status, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
data = await resp.read()
|
||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: downloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||
len(data), max_asset_data_size, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: video %d bytes exceeds Telegram %d-byte limit (idx=%d url=%s)",
|
||||
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "photo":
|
||||
exceeds, _, _, _ = check_photo_limits(data)
|
||||
exceeds, reason, _, _ = check_photo_limits(data)
|
||||
if exceeds:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: photo %s (idx=%d url=%s)",
|
||||
reason, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
return idx, None, data
|
||||
except aiohttp.ClientError:
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: download failed (idx=%d type=%s url=%s): %s",
|
||||
idx, media_type, url, err,
|
||||
)
|
||||
return idx, None, None
|
||||
|
||||
results = await asyncio.gather(
|
||||
@@ -584,6 +664,14 @@ class TelegramClient:
|
||||
media_json.append(mij)
|
||||
|
||||
if not media_json:
|
||||
# Every asset in this chunk was filtered out (size, download
|
||||
# failure, etc.). Without this log, sendMediaGroup returns
|
||||
# success=True with zero message_ids and nobody knows why
|
||||
# the user sees only the text reply and no media.
|
||||
_LOGGER.warning(
|
||||
"sendMediaGroup skipped — chunk %d/%d had %d input items but 0 usable (all filtered/failed)",
|
||||
chunk_idx + 1, len(chunks), len(chunk),
|
||||
)
|
||||
continue
|
||||
|
||||
form.add_field("media", json.dumps(media_json))
|
||||
@@ -620,10 +708,35 @@ class TelegramClient:
|
||||
if eff_cache:
|
||||
await eff_cache.async_set_many(cache_entries)
|
||||
else:
|
||||
return {"success": False, "error": result.get("description", "Unknown"), "failed_at_chunk": chunk_idx + 1}
|
||||
_LOGGER.error(
|
||||
"Telegram sendMediaGroup failed: status=%s code=%s desc=%r chunk=%d/%d items=%d",
|
||||
response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"),
|
||||
chunk_idx + 1, len(chunks), len(media_json),
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.get("description", "Unknown"),
|
||||
"error_code": result.get("error_code"),
|
||||
"failed_at_chunk": chunk_idx + 1,
|
||||
}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error(
|
||||
"Telegram sendMediaGroup transport error on chunk %d/%d (%d items): %s",
|
||||
chunk_idx + 1, len(chunks), len(media_json), err,
|
||||
exc_info=True,
|
||||
)
|
||||
return {"success": False, "error": str(err), "failed_at_chunk": chunk_idx + 1}
|
||||
|
||||
# Distinguish "posted something" from "posted nothing" so the caller
|
||||
# can surface an ERROR when a command produced a caption reply but no
|
||||
# media ever reached Telegram.
|
||||
if not all_message_ids:
|
||||
_LOGGER.warning(
|
||||
"sendMediaGroup completed with 0 message_ids across %d chunk(s) — nothing was delivered",
|
||||
len(chunks),
|
||||
)
|
||||
return {"success": False, "error": "no_items_delivered", "chunks_sent": len(chunks)}
|
||||
return {"success": True, "message_ids": all_message_ids, "chunks_sent": len(chunks)}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
from typing import Any, Final
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Telegram constants
|
||||
@@ -52,6 +52,65 @@ def extract_asset_id_from_url(url: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _url_is_under(url: str, base: str) -> bool:
    """Return True when ``url`` is ``base`` itself or a path/query/fragment under it."""
    if not base or not url.startswith(base):
        return False
    remainder = url[len(base):]
    # A bare prefix match is not enough: "https://host" is also a prefix of
    # "https://host.evil.tld/..." and "https://host:8443/...". Require the
    # match to end exactly at a URL component boundary.
    return not remainder or base.endswith("/") or remainder[0] in "/?#"


def build_telegram_asset_entry(
    *,
    url: str,
    media_type: str,
    api_key: str | None = None,
    internal_url: str = "",
    external_url: str = "",
    cache_key: str | None = None,
) -> dict[str, Any] | None:
    """Build a ``TelegramClient.send_notification`` asset dict from raw fields.

    Shared by the notification dispatcher and provider command handlers so
    both paths agree on media typing, URL rewriting, and auth headers. In
    particular: video assets MUST be typed ``"video"`` and point at a real
    video endpoint (e.g. Immich ``/video/playback``) — if they are sent as
    ``"photo"`` pointing at a thumbnail URL, Telegram delivers a still image
    for every video in a media group and the user sees a dead poster frame
    instead of a playable clip.

    Args:
        url: Source URL for the asset bytes. Prefer a transcoded/preview
            URL for videos (``/video/playback``) and a preview-sized
            thumbnail for photos.
        media_type: Case-insensitive type token. Accepts ``"video"``/
            ``"VIDEO"`` or an enum-like object whose ``.value`` is such a
            string; anything else is treated as a photo.
        api_key: Optional API key. Attached as ``x-api-key`` when the URL
            lives under ``internal_url`` or ``external_url`` (prevents
            leaking the key to unrelated hosts). When neither provider URL
            is given, the key is attached unconditionally.
        internal_url: LAN-facing provider URL. Used to rewrite
            ``external_url`` prefixes so Docker-host downloads stay on the
            LAN instead of egressing to the public domain.
        external_url: Public provider URL the notification URL was built
            from. Only used for the LAN rewrite and the api-key scope check.
        cache_key: Optional explicit cache key. Providers whose URLs don't
            embed a stable asset id (Google Photos) pass one through so the
            file_id cache still works.

    Returns:
        A dict with ``url``, ``type``, ``headers`` and, when supplied,
        ``cache_key``; ``None`` iff ``url`` is empty.
    """
    if not url:
        return None

    # Rewrite only URLs that genuinely live under the external base.
    if internal_url and _url_is_under(url, external_url):
        url = internal_url + url[len(external_url):]

    # Unwrap enum members so ``MediaType.VIDEO`` normalizes like "video".
    raw_type = getattr(media_type, "value", media_type)
    normalized_type = str(raw_type or "").lower()
    entry_type = "video" if normalized_type == "video" else "photo"

    headers: dict[str, str] = {}
    provider_urls = [u for u in (internal_url, external_url) if u]
    if api_key and (
        not provider_urls or any(_url_is_under(url, u) for u in provider_urls)
    ):
        headers["x-api-key"] = api_key

    entry: dict[str, Any] = {"url": url, "type": entry_type, "headers": headers}
    if cache_key:
        entry["cache_key"] = cache_key
    return entry
|
||||
|
||||
|
||||
def split_media_by_upload_size(
|
||||
media_items: list[tuple], max_upload_size: int
|
||||
) -> list[list[tuple]]:
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..ssrf import UnsafeURLError, validate_outbound_url
|
||||
from ..ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class WebhookClient:
|
||||
|
||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
try:
|
||||
validate_outbound_url(self._url)
|
||||
await avalidate_outbound_url(self._url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe URL: {err}"}
|
||||
try:
|
||||
@@ -33,6 +33,7 @@ class WebhookClient:
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", **self._headers},
|
||||
timeout=_DEFAULT_TIMEOUT,
|
||||
allow_redirects=False,
|
||||
) as response:
|
||||
if 200 <= response.status < 300:
|
||||
return {"success": True, "status_code": response.status}
|
||||
|
||||
@@ -177,7 +177,9 @@ class ImmichActionExecutor(ActionExecutor):
|
||||
needs_thumbnail = album_id in album_created_now
|
||||
|
||||
if album_id and album_id != "__dry_run_new__":
|
||||
album = await self._client.get_album(album_id)
|
||||
# Actions diff the current album state to decide what to
|
||||
# add — must observe fresh data, not a cached view.
|
||||
album = await self._client.get_album(album_id, use_cache=False)
|
||||
if album is None and create_if_missing and create_album_name:
|
||||
if not dry_run:
|
||||
created = await self._client.create_album(create_album_name)
|
||||
|
||||
@@ -193,6 +193,27 @@ def get_asset_video_url(
|
||||
return None
|
||||
|
||||
|
||||
def build_asset_media_urls(
    external_url: str, asset_id: str, asset_type: str,
) -> tuple[str, str]:
    """Build the ``(preview_url, full_url)`` pair for one Immich asset.

    This function is the single place that encodes the photo-vs-video
    endpoint rule, so the notification path (``asset_to_media``) and the
    bot command handlers cannot drift apart:

    * videos preview through the transcoded ``/video/playback`` stream;
    * everything else previews through the preview-sized thumbnail.

    If the two paths disagreed, Telegram would end up delivering a still
    JPEG for a video inside a media group.

    Args:
        external_url: Base URL of the Immich server; used verbatim as the
            prefix of both URLs.
        asset_id: Immich asset identifier.
        asset_type: Asset type string, compared against
            ``ASSET_TYPE_VIDEO``.

    Returns:
        ``(preview_url, full_url)``; ``full_url`` always points at the
        ``/original`` endpoint.
    """
    asset_base = f"{external_url}/api/assets/{asset_id}"
    preview_suffix = (
        "video/playback"
        if asset_type == ASSET_TYPE_VIDEO
        else "thumbnail?size=preview"
    )
    return f"{asset_base}/{preview_suffix}", f"{asset_base}/original"
|
||||
|
||||
|
||||
def build_asset_detail(
|
||||
asset: ImmichAssetInfo,
|
||||
external_url: str,
|
||||
@@ -246,12 +267,7 @@ def asset_to_media(asset: ImmichAssetInfo, external_url: str) -> MediaAsset:
|
||||
# preview_url is what the notification dispatcher feeds to Telegram as the
|
||||
# actual media bytes — for videos it must be the transcoded playback (mp4),
|
||||
# not the JPEG thumbnail, or Telegram receives a JPEG labeled as video/mp4.
|
||||
if asset.type == ASSET_TYPE_VIDEO:
|
||||
preview_url = f"{external_url}/api/assets/{asset.id}/video/playback"
|
||||
full_url = f"{external_url}/api/assets/{asset.id}/original"
|
||||
else:
|
||||
preview_url = f"{external_url}/api/assets/{asset.id}/thumbnail?size=preview"
|
||||
full_url = f"{external_url}/api/assets/{asset.id}/original"
|
||||
preview_url, full_url = build_asset_media_urls(external_url, asset.id, asset.type)
|
||||
|
||||
return MediaAsset(
|
||||
id=asset.id,
|
||||
|
||||
@@ -13,6 +13,18 @@ from .models import ImmichAlbumData, ImmichAssetInfo
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Guard against runaway payloads when a bulk import lands in one poll tick.
|
||||
# Templates iterate every entry in ``added_assets`` / ``removed_asset_ids``
|
||||
# in Jinja for-loops (see defaults/*/assets_added.jinja2), and Telegram's
|
||||
# media group has a hard cap of its own — sending 200k entries would both
|
||||
# crash rendering and produce a message that no transport can deliver.
|
||||
#
|
||||
# ``added_count`` / ``removed_count`` on the event always carry the true
|
||||
# totals so templates can show an accurate "N added" number even when the
|
||||
# per-asset list is truncated.
|
||||
_MAX_ASSETS_PER_EVENT = 50
|
||||
_MAX_REMOVALS_PER_EVENT = 200
|
||||
|
||||
|
||||
def _make_base_extra(new_album: ImmichAlbumData, external_url: str) -> dict:
|
||||
"""Build the common extra dict for album events."""
|
||||
@@ -85,7 +97,17 @@ def detect_album_changes(
|
||||
|
||||
# Emit one event per change type detected
|
||||
if added_assets:
|
||||
media_assets = [asset_to_media(a, external_url) for a in added_assets]
|
||||
total_added = len(added_assets)
|
||||
truncated_added = added_assets[:_MAX_ASSETS_PER_EVENT]
|
||||
media_assets = [asset_to_media(a, external_url) for a in truncated_added]
|
||||
event_extra = dict(extra)
|
||||
if total_added > _MAX_ASSETS_PER_EVENT:
|
||||
event_extra["truncated"] = True
|
||||
event_extra["shown_count"] = _MAX_ASSETS_PER_EVENT
|
||||
_LOGGER.info(
|
||||
"Truncated assets_added event for album %s: %d → %d",
|
||||
new_album.id, total_added, _MAX_ASSETS_PER_EVENT,
|
||||
)
|
||||
events.append(ServiceEvent(
|
||||
event_type=EventType.ASSETS_ADDED,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
@@ -95,12 +117,22 @@ def detect_album_changes(
|
||||
timestamp=now,
|
||||
added_assets=media_assets,
|
||||
removed_asset_ids=[],
|
||||
added_count=len(added_assets),
|
||||
added_count=total_added,
|
||||
removed_count=0,
|
||||
extra=dict(extra),
|
||||
extra=event_extra,
|
||||
))
|
||||
|
||||
if removed_ids:
|
||||
total_removed = len(removed_ids)
|
||||
truncated_removed = list(removed_ids)[:_MAX_REMOVALS_PER_EVENT]
|
||||
event_extra = dict(extra)
|
||||
if total_removed > _MAX_REMOVALS_PER_EVENT:
|
||||
event_extra["truncated"] = True
|
||||
event_extra["shown_count"] = _MAX_REMOVALS_PER_EVENT
|
||||
_LOGGER.info(
|
||||
"Truncated assets_removed event for album %s: %d → %d",
|
||||
new_album.id, total_removed, _MAX_REMOVALS_PER_EVENT,
|
||||
)
|
||||
events.append(ServiceEvent(
|
||||
event_type=EventType.ASSETS_REMOVED,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
@@ -109,10 +141,10 @@ def detect_album_changes(
|
||||
collection_name=new_album.name,
|
||||
timestamp=now,
|
||||
added_assets=[],
|
||||
removed_asset_ids=list(removed_ids),
|
||||
removed_asset_ids=truncated_removed,
|
||||
added_count=0,
|
||||
removed_count=len(removed_ids),
|
||||
extra=dict(extra),
|
||||
removed_count=total_removed,
|
||||
extra=event_extra,
|
||||
))
|
||||
|
||||
if name_changed:
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ...notifications.ssrf import UnsafeURLError, validate_outbound_url
|
||||
from .models import ImmichAlbumData, SharedLinkInfo
|
||||
from .models import ImmichAlbumData, ImmichAlbumMeta, SharedLinkInfo
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -18,6 +21,51 @@ _LOGGER = logging.getLogger(__name__)
|
||||
MAX_SEARCH_QUERY_LEN = 256
|
||||
MAX_SEARCH_PERSON_IDS = 50
|
||||
|
||||
# Module-level TTL caches for album bodies and shared-link listings. The
|
||||
# Immich ``GET /api/albums/{id}`` response can be tens or hundreds of MB on a
|
||||
# large album, and bot commands like /random, /latest, /memory all refetch
|
||||
# the same album in quick succession. A short TTL makes repeat runs nearly
|
||||
# instant. Only the shared-link listing is refreshed under a lock (one HTTP
|
||||
# call per burst); album fetches are unlocked and may briefly duplicate.
|
||||
#
|
||||
# Caches are module-scoped (not instance-scoped) because ``ImmichClient`` is
|
||||
# constructed fresh per request in several places (api/providers.py,
|
||||
# services/action_runner.py, command handlers), so an instance cache would
|
||||
# never survive to serve a second caller. This mirrors ``_users_cache`` in
|
||||
# ``provider.py``.
|
||||
_ALBUM_CACHE_TTL_SECONDS = 60
|
||||
_SHARED_LINKS_CACHE_TTL_SECONDS = 60
|
||||
# Guard rail against runaway memory — a 200k-asset album response can be
|
||||
# ~150 MB, so even modest caps bound the worst case.
|
||||
_ALBUM_CACHE_MAX_ENTRIES = 32
|
||||
_album_cache_lock = asyncio.Lock()
|
||||
# key = (server_digest, album_id); value = (monotonic_ts, raw_api_dict)
|
||||
# Store the raw dict rather than the parsed ``ImmichAlbumData`` so callers
|
||||
# that pass a ``users_cache`` still get owner-name enrichment on cache hits.
|
||||
_album_cache: dict[tuple[str, str], tuple[float, dict[str, Any]]] = {}
|
||||
_shared_links_cache_lock = asyncio.Lock()
|
||||
# key = server_digest; value = (monotonic_ts, {album_id: [SharedLinkInfo, ...]})
|
||||
# The underlying ``/api/shared-links`` endpoint has no per-album filter, so
|
||||
# every call was already paying for the full server-wide list. Caching the
|
||||
# bucketed result once per server turns N per-album calls into one fetch.
|
||||
_shared_links_cache: dict[str, tuple[float, dict[str, list[SharedLinkInfo]]]] = {}
|
||||
|
||||
|
||||
def _server_digest(url: str, api_key: str) -> str:
    """Derive a short cache key identifying one ``(url, api_key)`` pair.

    The pair is hashed so the raw API key never appears in cache dict keys.

    Returns:
        The first 32 hex characters of the SHA-256 digest of
        ``"<url>|<api_key>"`` encoded as UTF-8.
    """
    key_material = "{}|{}".format(url, api_key).encode("utf-8")
    return hashlib.sha256(key_material).hexdigest()[:32]
|
||||
|
||||
|
||||
def invalidate_album_cache() -> None:
    """Drop every cached album body.

    Intended for callers that know the cached view is no longer valid,
    for example integration tests or a manual /refresh command. The
    module-level ``_album_cache`` dict is emptied in place, so existing
    references to it observe the reset.
    """
    _album_cache.clear()
|
||||
|
||||
|
||||
def invalidate_shared_links_cache() -> None:
    """Drop every cached shared-link listing.

    The module-level ``_shared_links_cache`` dict is emptied in place, so
    the next lookup for any server has to fetch a fresh listing.
    """
    _shared_links_cache.clear()
|
||||
|
||||
# User-facing error bodies — Immich responses may leak internal paths,
|
||||
# hostnames, or headers injected by intermediary proxies. These helpers keep
|
||||
# only a short, scrubbed summary; full bodies are logged server-side only.
|
||||
@@ -184,28 +232,100 @@ class ImmichClient:
|
||||
return {}
|
||||
|
||||
async def get_shared_links(self, album_id: str) -> list[SharedLinkInfo]:
    """Return the shared links that point at ``album_id``.

    The lookup goes through the per-server bucketed listing
    (``_get_shared_links_bucketed``). The result is always a new list, so
    callers may mutate it without touching the bucket it was copied from.
    An album without links yields an empty list.
    """
    links_by_album = await self._get_shared_links_bucketed()
    if album_id in links_by_album:
        return list(links_by_album[album_id])
    return []
|
||||
|
||||
async def _get_shared_links_bucketed(self) -> dict[str, list[SharedLinkInfo]]:
    """Return ``{album_id: [SharedLinkInfo, ...]}`` for this server.

    The module-level TTL cache is consulted first. The underlying Immich
    endpoint has no per-album filter, so one server-wide fetch serves
    every caller until the TTL elapses.

    Empty listings are never cached. ``get_all_shared_links_by_album``
    returns an empty dict both for "no links" and for a failed fetch, so
    caching that value would hide every share link for a full TTL after
    one transient error. The cost is that a server with no shared links
    is re-queried on each call.

    Returns:
        The cached or freshly fetched mapping. Cache hits return the
        cached dict itself, so callers must treat it as read-only.
    """
    digest = _server_digest(self._url, self._api_key)

    def _unexpired() -> dict[str, list[SharedLinkInfo]] | None:
        # Only non-empty mappings are ever stored, so ``None`` is an
        # unambiguous "miss or expired" marker.
        cached = _shared_links_cache.get(digest)
        if cached is None:
            return None
        stored_at, buckets = cached
        if (time.monotonic() - stored_at) < _SHARED_LINKS_CACHE_TTL_SECONDS:
            return buckets
        return None

    hit = _unexpired()
    if hit is not None:
        return hit

    async with _shared_links_cache_lock:
        # Re-check under the lock — another coroutine may have refreshed
        # the entry while we were waiting.
        hit = _unexpired()
        if hit is not None:
            return hit
        fresh = await self.get_all_shared_links_by_album()
        if fresh:
            _shared_links_cache[digest] = (time.monotonic(), fresh)
        return fresh
|
||||
|
||||
async def get_all_shared_links_by_album(self) -> dict[str, list[SharedLinkInfo]]:
    """Fetch every shared link on the server, bucketed by album id.

    Immich's ``/api/shared-links`` endpoint is server-wide — there is no
    per-album filter — so a caller that wants the links of a single album
    still pays for the full listing. Callers that need links for several
    albums in one tick should call this once and index into the result
    instead of calling ``get_shared_links`` in a loop.

    Returns:
        ``{album_id: [SharedLinkInfo, ...]}``. Links without an album, a
        key, or an album id are skipped. An empty dict is returned on any
        fetch error — transport failure, timeout, non-200 status, a body
        that is not valid JSON, or a JSON body that is not a list — so
        callers do not need to branch on transient outages.
    """
    try:
        async with self._session.get(
            f"{self._url}/api/shared-links",
            headers=self._headers,
        ) as response:
            if response.status != 200:
                _LOGGER.warning(
                    "get_all_shared_links non-200: HTTP %s", response.status
                )
                return {}
            data = await response.json()
    except (aiohttp.ClientError, asyncio.TimeoutError, ValueError) as err:
        # A timeout is not an ``aiohttp.ClientError`` and invalid JSON
        # surfaces as ``ValueError``; both must honour the silent-failure
        # contract documented above.
        _LOGGER.warning("Failed to fetch all shared links: %s", err)
        return {}

    if not isinstance(data, list):
        # E.g. an error object returned with HTTP 200 by a proxy.
        _LOGGER.warning(
            "get_all_shared_links returned unexpected payload type: %s",
            type(data).__name__,
        )
        return {}

    result: dict[str, list[SharedLinkInfo]] = {}
    for link in data:
        if not isinstance(link, dict):
            continue
        album = link.get("album")
        if not isinstance(album, dict) or not link.get("key"):
            continue
        aid = album.get("id")
        if not aid:
            continue
        result.setdefault(aid, []).append(SharedLinkInfo.from_api_response(link))
    return result
|
||||
|
||||
async def get_album(
|
||||
self,
|
||||
album_id: str,
|
||||
users_cache: dict[str, str] | None = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
) -> ImmichAlbumData | None:
|
||||
"""Fetch an album by id, optionally serving from the module-level
|
||||
TTL cache. Pass ``use_cache=False`` from paths that must observe the
|
||||
current server state (e.g. the notification poll loop's full-fetch
|
||||
path, where a stale cached entry would delay asset-removal events).
|
||||
Non-cached fetches still populate the cache for subsequent readers.
|
||||
"""
|
||||
cache_key = (_server_digest(self._url, self._api_key), album_id)
|
||||
if use_cache:
|
||||
entry = _album_cache.get(cache_key)
|
||||
if entry is not None and (time.monotonic() - entry[0]) < _ALBUM_CACHE_TTL_SECONDS:
|
||||
# Rehydrate per-call so ``users_cache`` enrichment is applied
|
||||
# with the caller's dict, not whichever one was live when the
|
||||
# cache was populated.
|
||||
return ImmichAlbumData.from_api_response(entry[1], users_cache)
|
||||
|
||||
# Deliberately fetch without holding a lock so concurrent calls for
|
||||
# *different* album_ids (the common case from asyncio.gather in
|
||||
# fetch_albums_with_links) stay parallel. The worst case is a small
|
||||
# duplicate-fetch stampede when two requests miss the same album at
|
||||
# the same instant — acceptable for our scale.
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/albums/{album_id}",
|
||||
@@ -218,10 +338,132 @@ class ImmichClient:
|
||||
f"Error fetching album {album_id}: HTTP {response.status}"
|
||||
)
|
||||
data = await response.json()
|
||||
return ImmichAlbumData.from_api_response(data, users_cache)
|
||||
except aiohttp.ClientError as err:
|
||||
raise ImmichApiError(f"Error communicating with Immich: {err}") from err
|
||||
|
||||
async with _album_cache_lock:
|
||||
# Evict the oldest entry if we're at the cap — simple FIFO is fine
|
||||
# for our access pattern (commands touch a small working set).
|
||||
if len(_album_cache) >= _ALBUM_CACHE_MAX_ENTRIES and cache_key not in _album_cache:
|
||||
oldest = min(_album_cache.items(), key=lambda kv: kv[1][0])[0]
|
||||
_album_cache.pop(oldest, None)
|
||||
_album_cache[cache_key] = (time.monotonic(), data)
|
||||
return ImmichAlbumData.from_api_response(data, users_cache)
|
||||
|
||||
async def get_album_meta(self, album_id: str) -> ImmichAlbumMeta | None:
    """Fetch an album's metadata without its assets array.

    The request adds Immich's ``withoutAssets=true`` query parameter so
    the (potentially huge) ``assets`` field is left out of the response,
    which keeps the call cheap enough to use as a per-poll
    change-detection probe.

    Returns:
        The parsed ``ImmichAlbumMeta``, or ``None`` when the server
        answers 404 for the album.

    Raises:
        ImmichApiError: On any other non-200 status, or when aiohttp
            reports a client error while talking to the server.
    """
    try:
        async with self._session.get(
            f"{self._url}/api/albums/{album_id}",
            params={"withoutAssets": "true"},
            headers=self._headers,
        ) as resp:
            status = resp.status
            if status == 200:
                return ImmichAlbumMeta.from_api_response(await resp.json())
            if status == 404:
                return None
            raise ImmichApiError(
                f"Error fetching album meta {album_id}: HTTP {status}"
            )
    except aiohttp.ClientError as err:
        raise ImmichApiError(f"Error communicating with Immich: {err}") from err
|
||||
|
||||
async def search_album_assets_updated_after(
    self,
    album_id: str,
    updated_after: str,
    *,
    page_size: int = 1000,
    max_pages: int = 50,
) -> list[dict[str, Any]]:
    """Collect assets of ``album_id`` whose ``updatedAt`` is after ``updated_after``.

    Issues ``POST /api/search/metadata`` with ``albumIds=[album_id]`` and
    ``updatedAfter=<iso>``, walking pages 1..``max_pages``. The page cap
    keeps clock skew or an upstream bug from turning pagination into an
    unbounded loop.

    Args:
        album_id: Album to search in.
        updated_after: Lower bound for ``updatedAt``; a falsy value
            short-circuits to an empty list without any request.
        page_size: Requested page size, clamped to the range 1..1000.
        max_pages: Maximum number of pages to request.

    Returns:
        The raw asset dicts gathered so far. A non-200 response, an
        unexpected payload shape, a transport error, or a parse error is
        logged as a warning and ends pagination; whatever was collected
        before that point is still returned.
    """
    if not updated_after:
        return []

    page_size = max(1, min(page_size, 1000))
    collected: list[dict[str, Any]] = []
    for page_no in range(1, max_pages + 1):
        request_body: dict[str, Any] = {
            "albumIds": [album_id],
            "updatedAfter": updated_after,
            "page": page_no,
            "size": page_size,
            # Ask for EXIF and people so records from this delta path
            # carry the same fields as the ones embedded in a full album.
            "withExif": True,
            "withPeople": True,
        }
        try:
            async with self._session.post(
                f"{self._url}/api/search/metadata",
                headers=self._json_headers,
                json=request_body,
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    _LOGGER.warning(
                        "Immich delta search non-200: HTTP %s body=%s",
                        resp.status, _redact_body(error_text),
                    )
                    break

                parsed = await resp.json()
                assets_field = parsed.get("assets")
                if isinstance(assets_field, list):
                    # Bare-list shape carries no pagination hint.
                    page_items, next_hint = assets_field, None
                elif isinstance(assets_field, dict):
                    page_items = assets_field.get("items", []) or []
                    next_hint = assets_field.get("nextPage")
                else:
                    _LOGGER.warning(
                        "Immich delta search returned unexpected shape: keys=%s",
                        list(parsed.keys())[:5],
                    )
                    break

                collected.extend(page_items)

                # Stop when the server reports no further page (``None``,
                # empty string or 0 — a missing hint looks the same), or
                # when a hinted page came back short.
                if next_hint in (None, "", 0) or len(page_items) < page_size:
                    break
        except aiohttp.ClientError as err:
            _LOGGER.warning("Immich delta search transport error: %s", err)
            break
        except Exception as err:  # noqa: BLE001 — resilience over correctness
            _LOGGER.warning("Immich delta search parse error: %s", err)
            break
    else:
        # Loop ran to completion without a break: the page cap was reached.
        _LOGGER.warning(
            "Immich delta search for album %s hit max_pages=%d cap",
            album_id, max_pages,
        )
    return collected
|
||||
|
||||
async def get_albums(self) -> list[dict[str, Any]]:
|
||||
try:
|
||||
async with self._session.get(
|
||||
|
||||
@@ -146,6 +146,49 @@ class ImmichAssetInfo:
|
||||
return bool(thumbhash)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ImmichAlbumMeta:
    """Album metadata as returned by ``GET /api/albums/{id}?withoutAssets=true``.

    Serves as a cheap change-detection probe: comparing two fingerprints
    tells the poller whether the expensive full-asset fetch is needed at
    all, instead of re-reading the whole asset list on every poll.
    """

    id: str
    name: str
    asset_count: int
    updated_at: str
    shared: bool
    thumbnail_asset_id: str | None = None

    @classmethod
    def from_api_response(cls, data: dict[str, Any]) -> ImmichAlbumMeta:
        """Build an instance from a raw Immich album dict.

        ``id`` is required (a missing key raises ``KeyError``). The other
        fields fall back to neutral defaults: ``"Unnamed"`` when
        ``albumName`` is absent, ``0`` for a missing or falsy
        ``assetCount``, ``""`` for a missing or falsy ``updatedAt``, and
        ``False`` for a missing ``shared``.
        """
        album_id = data["id"]
        album_name = data.get("albumName", "Unnamed")
        raw_count = data.get("assetCount", 0)
        raw_updated = data.get("updatedAt", "")
        return cls(
            id=album_id,
            name=album_name,
            asset_count=int(raw_count) if raw_count else 0,
            updated_at=raw_updated if raw_updated else "",
            shared=bool(data.get("shared", False)),
            thumbnail_asset_id=data.get("albumThumbnailAssetId"),
        )

    def fingerprint(self) -> dict[str, Any]:
        """Return a flat, JSON-friendly dict for persistence and equality checks.

        ``id`` is left out on purpose — the state row the fingerprint is
        stored in already identifies the album. A missing thumbnail id is
        normalised to ``""`` so the value is always a string.
        """
        thumbnail = self.thumbnail_asset_id if self.thumbnail_asset_id else ""
        return dict(
            updated_at=self.updated_at,
            asset_count=self.asset_count,
            shared=self.shared,
            name=self.name,
            thumbnail_asset_id=thumbnail,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImmichAlbumData:
|
||||
"""Full album data from Immich API."""
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -11,13 +14,62 @@ from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.providers.base import ServiceProvider, ServiceProviderType
|
||||
from notify_bridge_core.templates.variables import TemplateVariableDefinition
|
||||
|
||||
from .change_detector import detect_album_changes
|
||||
from .asset_utils import asset_to_media
|
||||
from .change_detector import _MAX_ASSETS_PER_EVENT, detect_album_changes
|
||||
from .client import ImmichClient
|
||||
from .models import ImmichAlbumData
|
||||
from .models import ImmichAlbumData, ImmichAlbumMeta, ImmichAssetInfo
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Module-level users cache shared across ImmichServiceProvider instances.
|
||||
# Users change rarely (new people joining the server, display-name edits), so
|
||||
# refetching on every tracker's ``connect()`` is wasteful — a fleet of 10
|
||||
# trackers on the same Immich server otherwise issues 10 ``GET /api/users``
|
||||
# calls per poll cycle. TTL is conservative (1h) and a hashed key keeps the
|
||||
# raw api_key out of dict keys in case of a memory dump.
|
||||
_USERS_CACHE_TTL_SECONDS = 3600
|
||||
_users_cache_lock = asyncio.Lock()
|
||||
_users_cache: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
|
||||
def _users_cache_key(url: str, api_key: str) -> str:
    """Return the users-cache key for one ``(url, api_key)`` pair.

    The key is the first 32 hex characters of the SHA-256 digest of
    ``"<url>|<api_key>"`` (UTF-8), which keeps the raw API key out of the
    cache dict's keys.
    """
    hasher = hashlib.sha256()
    hasher.update(f"{url}|{api_key}".encode("utf-8"))
    return hasher.hexdigest()[:32]
|
||||
|
||||
|
||||
async def _get_cached_users(
    client: ImmichClient, url: str, api_key: str
) -> dict[str, str]:
    """Return ``{user_id: display_name}`` for the server, with TTL caching.

    Entries younger than ``_USERS_CACHE_TTL_SECONDS`` are served from the
    module-level cache. Misses and stale hits fall through to
    ``client.get_users()`` under a single lock, so concurrent callers do
    not stampede the server.

    An empty mapping is returned to the caller but never cached. A
    reachable server has at least the user owning the API key, so an
    empty result most likely means the fetch did not succeed
    (NOTE(review): confirm ``get_users`` reports failures as ``{}``).
    Caching it would drop owner names for the whole TTL.

    Args:
        client: Client used to fetch the users on a cache miss.
        url: Server URL; part of the cache key.
        api_key: API key; part of the (hashed) cache key.

    Returns:
        The cached or freshly fetched mapping. Cache hits return the
        cached dict itself, so callers must treat it as read-only.
    """
    key = _users_cache_key(url, api_key)

    def _unexpired() -> dict[str, str] | None:
        # Only non-empty mappings are ever stored, so ``None`` is an
        # unambiguous "miss or expired" marker.
        cached = _users_cache.get(key)
        if cached is None:
            return None
        stored_at, users = cached
        if (time.monotonic() - stored_at) < _USERS_CACHE_TTL_SECONDS:
            return users
        return None

    hit = _unexpired()
    if hit is not None:
        return hit

    async with _users_cache_lock:
        # Re-check after acquiring the lock — another coroutine may have
        # refreshed the entry while we waited.
        hit = _unexpired()
        if hit is not None:
            return hit
        fresh = await client.get_users()
        if fresh:
            _users_cache[key] = (time.monotonic(), fresh)
        return fresh
|
||||
|
||||
|
||||
def invalidate_users_cache() -> None:
    """Drop every cached users mapping.

    Meant for callers that change users or provider configuration (and
    for integration tests) and need the next ``connect()`` to fetch the
    user list again. The module-level ``_users_cache`` dict is emptied in
    place.
    """
    _users_cache.clear()
|
||||
|
||||
|
||||
# Immich-specific template variables
|
||||
IMMICH_VARIABLES: list[TemplateVariableDefinition] = [
|
||||
TemplateVariableDefinition(
|
||||
@@ -135,7 +187,9 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
await self._client.get_server_config()
|
||||
if self._external_domain:
|
||||
self._client.external_domain = self._external_domain
|
||||
self._users_cache = await self._client.get_users()
|
||||
self._users_cache = await _get_cached_users(
|
||||
self._client, self._client.url, self._client.api_key,
|
||||
)
|
||||
return ok
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
@@ -150,9 +204,32 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
new_state = dict(tracker_state)
|
||||
external_url = self._client.external_url
|
||||
|
||||
for album_id in collection_ids:
|
||||
album = await self._client.get_album(album_id, self._users_cache)
|
||||
if album is None:
|
||||
# Tick-scoped share-link cache. Populated lazily on first enrichment;
|
||||
# a tracker watching 5 albums with changes now issues 1 ``/api/shared-links``
|
||||
# request per tick instead of 5 (and the endpoint is server-wide — each
|
||||
# call was already fetching all links and discarding most of them).
|
||||
self._tick_shared_links: dict[str, list] | None = None
|
||||
|
||||
# Fan out the cheap meta probes in parallel. For a tracker that
|
||||
# watches 20 albums on the same Immich server this turns a 20-hop
|
||||
# serial wait into ~1 round-trip's worth of latency. aiohttp's
|
||||
# connection pool caps concurrency per host, so this can't stampede.
|
||||
meta_results = await asyncio.gather(
|
||||
*(self._client.get_album_meta(aid) for aid in collection_ids),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for album_id, meta_or_exc in zip(collection_ids, meta_results):
|
||||
if isinstance(meta_or_exc, BaseException):
|
||||
# Transient failure on this album — preserve existing state
|
||||
# and move on. Logging at warning so flaky albums surface in
|
||||
# the log without flooding on hard outages.
|
||||
_LOGGER.warning(
|
||||
"Meta probe failed for album %s: %s", album_id, meta_or_exc,
|
||||
)
|
||||
continue
|
||||
meta = meta_or_exc
|
||||
if meta is None:
|
||||
# Album deleted
|
||||
if album_id in new_state:
|
||||
from notify_bridge_core.models.events import EventType
|
||||
@@ -168,11 +245,80 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
del new_state[album_id]
|
||||
continue
|
||||
|
||||
# Get previous state
|
||||
prev = new_state.get(album_id)
|
||||
prev_fingerprint = prev.get("meta_fingerprint") if prev else None
|
||||
has_pending = bool(prev and prev.get("pending_asset_ids"))
|
||||
|
||||
# 2) Fast-path: fingerprint match and no pending assets → no work.
|
||||
# We still refresh the fingerprint slot (no-op if identical) and
|
||||
# leave asset_ids untouched on disk.
|
||||
if (
|
||||
prev is not None
|
||||
and prev_fingerprint == meta.fingerprint()
|
||||
and not has_pending
|
||||
):
|
||||
continue
|
||||
|
||||
# 3) Decide: delta fetch (cheap, active-album case) or full
|
||||
# fetch (first tick + reconciliation for removals).
|
||||
old_fp = prev.get("meta_fingerprint") if prev else None
|
||||
old_asset_count = (old_fp or {}).get("asset_count", 0)
|
||||
old_updated_at = (old_fp or {}).get("updated_at", "")
|
||||
|
||||
# Gate for the delta path:
|
||||
# - must be tracked already (prev exists, has asset_ids)
|
||||
# - must have a prior timestamp (empty ⇒ migrated DB row)
|
||||
# - asset_count must not have decreased (removals need full fetch)
|
||||
can_delta = (
|
||||
prev is not None
|
||||
and bool(prev.get("asset_ids"))
|
||||
and bool(old_updated_at)
|
||||
and meta.asset_count >= old_asset_count
|
||||
)
|
||||
|
||||
if can_delta:
|
||||
delta_events = await self._poll_delta(
|
||||
album_id=album_id,
|
||||
prev=prev,
|
||||
new_meta=meta,
|
||||
old_updated_at=old_updated_at,
|
||||
)
|
||||
if delta_events is not None:
|
||||
events.extend(delta_events["events"])
|
||||
new_state[album_id] = delta_events["new_state"]
|
||||
continue
|
||||
# delta_events is None ⇒ delta saw more additions than the
|
||||
# net count increase (mixed add+remove) ⇒ fall through to
|
||||
# the full-fetch path so removals get detected.
|
||||
|
||||
# Full fetch: first tick, or count-decreased, or delta-unsafe.
|
||||
# Bypass the module-level album cache — this path runs when we
|
||||
# specifically need the current server state (e.g. to detect
|
||||
# asset removals), so a stale cached entry would silently delay
|
||||
# the event.
|
||||
album = await self._client.get_album(
|
||||
album_id, self._users_cache, use_cache=False,
|
||||
)
|
||||
if album is None:
|
||||
# Album was deleted between meta probe and full fetch — handle
|
||||
# the deletion the same way as above.
|
||||
if album_id in new_state:
|
||||
from notify_bridge_core.models.events import EventType
|
||||
from datetime import datetime, timezone
|
||||
events.append(ServiceEvent(
|
||||
event_type=EventType.COLLECTION_DELETED,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name=self._name,
|
||||
collection_id=album_id,
|
||||
collection_name=new_state.get(album_id, {}).get("name", "Unknown"),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
))
|
||||
del new_state[album_id]
|
||||
continue
|
||||
|
||||
if prev is None:
|
||||
# First time seeing this album — store state, no event
|
||||
new_state[album_id] = _serialize_album_state(album)
|
||||
new_state[album_id] = _serialize_album_state(album, meta)
|
||||
continue
|
||||
|
||||
# Reconstruct previous album data for comparison
|
||||
@@ -184,34 +330,233 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
)
|
||||
|
||||
if detected_events:
|
||||
# Fetch shared links to enrich events with public_url
|
||||
shared_links = await self._client.get_shared_links(album_id)
|
||||
public_link = None
|
||||
protected_link = None
|
||||
for link in shared_links:
|
||||
if link.is_accessible and not link.is_expired:
|
||||
if link.has_password:
|
||||
protected_link = link
|
||||
else:
|
||||
public_link = link
|
||||
break # prefer non-password link
|
||||
|
||||
ext_domain = self._external_domain or self._client.external_url
|
||||
for evt in detected_events:
|
||||
if public_link:
|
||||
evt.extra["public_url"] = f"{ext_domain}/share/{public_link.key}"
|
||||
elif protected_link:
|
||||
evt.extra["protected_url"] = f"{ext_domain}/share/{protected_link.key}"
|
||||
|
||||
await self._enrich_with_shared_links(album_id, detected_events)
|
||||
events.extend(detected_events)
|
||||
|
||||
# Update state
|
||||
state = _serialize_album_state(album)
|
||||
state = _serialize_album_state(album, meta)
|
||||
state["pending_asset_ids"] = list(updated_pending)
|
||||
new_state[album_id] = state
|
||||
|
||||
return events, new_state
|
||||
|
||||
async def _poll_delta(
    self,
    *,
    album_id: str,
    prev: dict[str, Any],
    new_meta: ImmichAlbumMeta,
    old_updated_at: str,
) -> dict[str, Any] | None:
    """Poll one active album through the incremental (delta) search.

    Only the assets returned by ``search_album_assets_updated_after`` for
    ``old_updated_at`` are fetched, instead of the album's full asset list.

    Args:
        album_id: Album being polled.
        prev: State persisted for this album on the previous tick. The keys
            read here are ``asset_ids``, ``pending_asset_ids`` and
            ``meta_fingerprint``.
        new_meta: Album metadata for the current tick (name, shared flag,
            asset count, fingerprint).
        old_updated_at: Lower bound handed to the delta search.

    Returns:
        ``{"events": [...], "new_state": {...}}`` on success, or ``None``
        when the additions found do not match the change in asset count
        and the caller should retry with a full fetch. The delta search
        cannot say *what* was removed, so removals are never emitted from
        this path.
    """
    from datetime import datetime, timezone

    from notify_bridge_core.models.events import EventType

    known_ids: set[str] = set(prev.get("asset_ids", []))
    pending_before: set[str] = set(prev.get("pending_asset_ids", []))

    records = await self._client.search_album_assets_updated_after(
        album_id, old_updated_at
    )

    # Keep unprocessed entries as well: they feed ``pending_asset_ids``.
    parsed: list[ImmichAssetInfo] = []
    for record in records:
        try:
            info = ImmichAssetInfo.from_api_response(record, self._users_cache)
        except Exception as err:  # noqa: BLE001 — one bad record ≠ abort tick
            _LOGGER.warning(
                "Skipping malformed asset record in delta response: %s", err
            )
            continue
        parsed.append(info)

    newly_added: list[ImmichAssetInfo] = [
        item for item in parsed
        if item.is_processed and item.id not in known_ids
    ]
    still_pending: set[str] = {
        item.id for item in parsed if not item.is_processed
    }

    fingerprint_before = prev.get("meta_fingerprint") or {}
    count_delta = new_meta.asset_count - int(
        fingerprint_before.get("asset_count", 0)
    )
    found = len(newly_added)

    # More additions than the count grew by: something was removed at the
    # same time, and only a full fetch can tell which assets went away.
    if 0 <= count_delta < found:
        _LOGGER.info(
            "Delta for album %s found %d additions but net change is %d "
            "— falling back to full fetch for removal reconciliation",
            album_id, found, count_delta,
        )
        return None

    # Fewer additions than the count grew by: the delta missed something,
    # so a full fetch is needed to avoid silently dropping new assets.
    if count_delta > 0 and found < count_delta:
        _LOGGER.info(
            "Delta for album %s found %d additions but net change is %d "
            "— falling back to full fetch to avoid missing assets",
            album_id, found, count_delta,
        )
        return None

    now = datetime.now(timezone.utc)
    public_base = self._external_domain or self._client.external_url

    # Album-level attributes available without scanning assets. Share-link
    # URLs are attached afterwards, and only when events were produced.
    base_extra = {
        "album_url": f"{public_base}/albums/{album_id}",
        "shared": new_meta.shared,
        "asset_count": new_meta.asset_count,
        "photo_count": 0,  # unknown without per-asset scan; templates tolerate 0
        "video_count": 0,
        "people": [],
        "owner": "",
    }

    def build_event(kind, extra, **overrides):
        """Create a ServiceEvent for this album with empty asset defaults."""
        fields = {
            "event_type": kind,
            "provider_type": ServiceProviderType.IMMICH,
            "provider_name": self._name,
            "collection_id": album_id,
            "collection_name": new_meta.name,
            "timestamp": now,
            "added_assets": [],
            "removed_asset_ids": [],
            "added_count": 0,
            "removed_count": 0,
            "extra": extra,
        }
        fields.update(overrides)
        return ServiceEvent(**fields)

    events: list[ServiceEvent] = []

    # Metadata-only events: derived purely from the stored fingerprint.
    previous_name = fingerprint_before.get("name")
    if previous_name and previous_name != new_meta.name:
        events.append(build_event(
            EventType.COLLECTION_RENAMED,
            dict(base_extra),
            old_name=previous_name,
            new_name=new_meta.name,
        ))

    if "shared" in fingerprint_before:
        was_shared = bool(fingerprint_before["shared"])
        is_shared = bool(new_meta.shared)
        if was_shared != is_shared:
            events.append(build_event(
                EventType.SHARING_CHANGED,
                dict(base_extra),
                old_shared=was_shared,
                new_shared=is_shared,
            ))

    if newly_added:
        shown = newly_added[:_MAX_ASSETS_PER_EVENT]
        media_assets = [
            asset_to_media(item, self._client.external_url) for item in shown
        ]
        added_extra = dict(base_extra)
        if found > _MAX_ASSETS_PER_EVENT:
            added_extra["truncated"] = True
            added_extra["shown_count"] = _MAX_ASSETS_PER_EVENT
            _LOGGER.info(
                "Delta-path truncated assets_added event for album %s: %d → %d",
                album_id, found, _MAX_ASSETS_PER_EVENT,
            )
        events.append(build_event(
            EventType.ASSETS_ADDED,
            added_extra,
            added_assets=media_assets,
            added_count=found,
        ))

    if events:
        await self._enrich_with_shared_links(album_id, events)

    # Next state: known ids grow by what was just added; pending is the old
    # pending set plus anything still unprocessed, minus ids that landed.
    added_ids = {item.id for item in newly_added}
    new_asset_ids = known_ids | added_ids
    new_pending = (pending_before | still_pending) - added_ids

    return {
        "events": events,
        "new_state": {
            "name": new_meta.name,
            "asset_ids": list(new_asset_ids),
            "shared": new_meta.shared,
            "pending_asset_ids": list(new_pending),
            "meta_fingerprint": new_meta.fingerprint(),
        },
    }
|
||||
|
||||
async def _enrich_with_shared_links(
    self, album_id: str, events_to_enrich: list[ServiceEvent]
) -> None:
    """Add a share URL for ``album_id`` to each event's ``extra`` dict.

    The per-album link map is fetched once via
    ``get_all_shared_links_by_album`` and kept in ``_tick_shared_links``,
    so later calls reuse it instead of issuing another request.

    A usable link is one that is accessible and not expired. The first
    usable link without a password is written as ``public_url``; if none
    exists, the last usable password-protected link is written as
    ``protected_url``. Events are left untouched when no link is usable.
    """
    if self._tick_shared_links is None:
        self._tick_shared_links = await self._client.get_all_shared_links_by_album()

    open_link = None
    locked_link = None
    for candidate in self._tick_shared_links.get(album_id, []):
        if not candidate.is_accessible or candidate.is_expired:
            continue
        if not candidate.has_password:
            open_link = candidate
            break  # prefer non-password link
        locked_link = candidate

    base = self._external_domain or self._client.external_url

    if open_link:
        extra_key, chosen = "public_url", open_link
    elif locked_link:
        extra_key, chosen = "protected_url", locked_link
    else:
        return

    for evt in events_to_enrich:
        evt.extra[extra_key] = f"{base}/share/{chosen.key}"
|
||||
|
||||
def get_available_variables(self) -> list[TemplateVariableDefinition]:
    """Return a new list holding the entries of ``IMMICH_VARIABLES``."""
    return [*IMMICH_VARIABLES]
|
||||
|
||||
@@ -262,13 +607,33 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
return {"ok": False, "message": "Failed to connect to Immich"}
|
||||
|
||||
|
||||
def _serialize_album_state(album: ImmichAlbumData) -> dict[str, Any]:
|
||||
"""Serialize album state for persistence."""
|
||||
def _serialize_album_state(
    album: ImmichAlbumData,
    meta: ImmichAlbumMeta | None = None,
) -> dict[str, Any]:
    """Build the persisted state dict for one album.

    Args:
        album: Album whose name, asset ids and shared flag are stored.
        meta: Source of the stored fingerprint. When ``None``, a
            fingerprint is assembled from ``album`` itself (``updated_at``,
            number of asset ids, ``shared``, ``name`` and the thumbnail
            asset id, with a missing thumbnail stored as ``""``).

    Returns:
        A dict with ``name``, ``asset_ids``, ``shared``, an empty
        ``pending_asset_ids`` list and ``meta_fingerprint``.
    """
    fingerprint = (
        meta.fingerprint()
        if meta is not None
        else {
            "updated_at": album.updated_at,
            "asset_count": len(album.asset_ids),
            "shared": album.shared,
            "name": album.name,
            "thumbnail_asset_id": album.thumbnail_asset_id or "",
        }
    )
    serialized: dict[str, Any] = {
        "name": album.name,
        "asset_ids": list(album.asset_ids),
        "shared": album.shared,
        "pending_asset_ids": [],
        "meta_fingerprint": fingerprint,
    }
    return serialized
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_PORT = 3493
|
||||
_READ_TIMEOUT = 10.0
|
||||
_WRITE_TIMEOUT = 10.0
|
||||
_CONNECT_TIMEOUT = 5.0
|
||||
|
||||
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
|
||||
@@ -84,14 +85,26 @@ class NutClient:
|
||||
await self._command(f"PASSWORD {self._password}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Send LOGOUT and close the TCP connection."""
|
||||
"""Send LOGOUT and close the TCP connection.
|
||||
|
||||
``drain`` is bounded by ``_WRITE_TIMEOUT`` so a half-closed peer
|
||||
cannot hold the disconnect indefinitely — a tracker tick would
|
||||
otherwise be pinned by a stuck NUT server and block the scheduler
|
||||
slot (``max_instances=1``).
|
||||
"""
|
||||
if self._writer is not None:
|
||||
try:
|
||||
self._writer.write(b"LOGOUT\n")
|
||||
await self._writer.drain()
|
||||
except OSError:
|
||||
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
pass
|
||||
self._writer.close()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._writer.wait_closed(), timeout=_WRITE_TIMEOUT,
|
||||
)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
pass
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
|
||||
@@ -135,7 +148,10 @@ class NutClient:
|
||||
if self._writer is None:
|
||||
raise NutClientError("Not connected")
|
||||
self._writer.write(f"{cmd}\n".encode())
|
||||
await self._writer.drain()
|
||||
try:
|
||||
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise NutClientError("Write timeout") from exc
|
||||
|
||||
async def _readline(self) -> str:
|
||||
"""Read one line from upsd, stripping trailing newline."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||
from notify_bridge_core.providers.base import ServiceProvider, ServiceProviderType
|
||||
@@ -57,6 +58,13 @@ SCHEDULER_VARIABLES: list[TemplateVariableDefinition] = [
|
||||
example="Monday",
|
||||
provider_type=ServiceProviderType.SCHEDULER,
|
||||
),
|
||||
TemplateVariableDefinition(
|
||||
name="timezone",
|
||||
type="string",
|
||||
description="IANA timezone used to compute current_date/time",
|
||||
example="Europe/Warsaw",
|
||||
provider_type=ServiceProviderType.SCHEDULER,
|
||||
),
|
||||
TemplateVariableDefinition(
|
||||
name="custom_vars",
|
||||
type="dict",
|
||||
@@ -83,7 +91,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
custom_variables: dict[str, str] | None = None,
|
||||
date_format: str = "%d.%m.%Y",
|
||||
time_format: str = "%H:%M",
|
||||
datetime_format: str = "%d.%m.%Y, %H:%M UTC",
|
||||
datetime_format: str = "%d.%m.%Y, %H:%M %Z",
|
||||
timezone_name: str | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._tracker_name = tracker_name
|
||||
@@ -91,6 +100,18 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
self._date_format = date_format
|
||||
self._time_format = time_format
|
||||
self._datetime_format = datetime_format
|
||||
# Resolve a timezone for date/time rendering. Falls back to UTC on
|
||||
# invalid IANA names so a typo in app settings doesn't break polls.
|
||||
tz: ZoneInfo
|
||||
if timezone_name:
|
||||
try:
|
||||
tz = ZoneInfo(timezone_name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
_LOGGER.warning("Unknown timezone %r; falling back to UTC", timezone_name)
|
||||
tz = ZoneInfo("UTC")
|
||||
else:
|
||||
tz = ZoneInfo("UTC")
|
||||
self._tz = tz
|
||||
|
||||
async def connect(self) -> bool:
    """Report the provider as connected.

    This is a virtual provider, so it is always connected.
    """
    return True
|
||||
@@ -103,7 +124,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
collection_ids: list[str],
|
||||
tracker_state: dict[str, Any],
|
||||
) -> tuple[list[ServiceEvent], dict[str, Any]]:
|
||||
now = datetime.now(timezone.utc)
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
now = now_utc.astimezone(self._tz)
|
||||
# State uses {collection_id: {dict}} convention like other providers
|
||||
sched_state = tracker_state.get("scheduler", {})
|
||||
fire_count = sched_state.get("fire_count", 0) + 1
|
||||
@@ -115,6 +137,7 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
"current_time": now.strftime(self._time_format),
|
||||
"current_datetime": now.strftime(self._datetime_format),
|
||||
"weekday": _WEEKDAYS[now.weekday()],
|
||||
"timezone": self._tz.key,
|
||||
"custom_vars": dict(self._custom_variables),
|
||||
}
|
||||
# Flatten custom variables at top level for easy template access
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
@@ -19,34 +21,58 @@ class StorageBackend(Protocol):
|
||||
async def remove(self) -> None: ...
|
||||
|
||||
|
||||
def _read_file(path: Path) -> str | None:
    """Return the UTF-8 text stored at ``path``, or ``None`` if it is absent.

    The read is attempted directly instead of checking ``exists()`` first.
    A file deleted between a check and the read would otherwise raise
    ``FileNotFoundError`` for what is simply "no stored data".

    Args:
        path: File to read.

    Returns:
        The file contents, or ``None`` when the file (or one of its parent
        directories) does not exist.

    Raises:
        OSError: For any other I/O failure, such as permission denied or
            ``path`` being a directory.
    """
    try:
        return path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        # NotADirectoryError: a parent component is a regular file, which
        # the previous ``exists()`` check also reported as "not there".
        return None
|
||||
|
||||
|
||||
def _atomic_write(path: Path, payload: str) -> None:
    """Write ``payload`` to ``path`` atomically via a temp file and rename.

    The text is first written to a uniquely named sibling file and then
    moved over ``path`` with ``os.replace``, so readers see either the old
    content or the new content, never a half-written file.

    The temp name carries a random token. With a fixed ``<name>.tmp``, two
    overlapping writers (for example two saves running in worker threads)
    would write into the same staging file and could publish interleaved
    content. If the write or the rename fails, the temp file is removed so
    failures do not leave stray files behind.

    Args:
        path: Destination file. Missing parent directories are created.
        payload: Text to store, encoded as UTF-8.

    Raises:
        OSError: If the directory cannot be created, or the write or
            rename fails.
    """
    import secrets  # function-scope: only needed for the temp-file token

    path.parent.mkdir(parents=True, exist_ok=True)
    # Same directory as the target, so the final rename stays on one filesystem.
    tmp = path.with_name(f"{path.name}.{secrets.token_hex(4)}.tmp")
    try:
        tmp.write_text(payload, encoding="utf-8")
        os.replace(tmp, path)
    except BaseException:
        # Best-effort cleanup; the original failure is what gets reported.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
def _remove_file(path: Path) -> None:
    """Delete ``path`` if present; do nothing when it is already gone.

    The unlink is attempted directly instead of checking ``exists()``
    first, so a file removed by someone else between a check and the
    unlink is treated as success, not as an error. One side effect:
    a dangling symlink at ``path`` is now removed too, whereas the
    ``exists()`` check used to skip it.

    Args:
        path: File to delete.

    Raises:
        OSError: For failures other than the file being absent, such as
            permission denied or ``path`` being a directory.
    """
    try:
        path.unlink()
    except (FileNotFoundError, NotADirectoryError):
        # NotADirectoryError: a parent component is a regular file, so
        # there is nothing at ``path`` to remove.
        return
|
||||
|
||||
|
||||
class JsonFileBackend:
|
||||
"""Simple JSON file storage backend."""
|
||||
"""Simple JSON file storage backend.
|
||||
|
||||
All blocking I/O is wrapped in ``asyncio.to_thread`` so callers can
|
||||
``await load() / save() / remove()`` without stalling the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
|
||||
async def load(self) -> dict[str, Any] | None:
|
||||
if not self._path.exists():
|
||||
try:
|
||||
text = await asyncio.to_thread(_read_file, self._path)
|
||||
except OSError as err:
|
||||
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
||||
return None
|
||||
if text is None:
|
||||
return None
|
||||
try:
|
||||
text = self._path.read_text(encoding="utf-8")
|
||||
return json.loads(text)
|
||||
except (json.JSONDecodeError, OSError) as err:
|
||||
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
||||
except json.JSONDecodeError as err:
|
||||
_LOGGER.warning("Failed to parse %s: %s", self._path, err)
|
||||
return None
|
||||
|
||||
async def save(self, data: dict[str, Any]) -> None:
|
||||
payload = json.dumps(data, default=str)
|
||||
try:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.write_text(
|
||||
json.dumps(data, default=str), encoding="utf-8"
|
||||
)
|
||||
await asyncio.to_thread(_atomic_write, self._path, payload)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to save %s: %s", self._path, err)
|
||||
|
||||
async def remove(self) -> None:
|
||||
try:
|
||||
if self._path.exists():
|
||||
self._path.unlink()
|
||||
await asyncio.to_thread(_remove_file, self._path)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to remove %s: %s", self._path, err)
|
||||
|
||||
@@ -224,6 +224,7 @@ def build_template_context(
|
||||
ctx.setdefault("current_time", event.extra.get("current_time", ""))
|
||||
ctx.setdefault("current_datetime", event.extra.get("current_datetime", ""))
|
||||
ctx.setdefault("weekday", event.extra.get("weekday", ""))
|
||||
ctx.setdefault("timezone", event.extra.get("timezone", "UTC"))
|
||||
ctx.setdefault("custom_vars", event.extra.get("custom_vars", {}))
|
||||
|
||||
return ctx
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
📸 Photos from {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||
🗓️ Scheduled delivery — random photos from {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||
{%- for asset in assets %}
|
||||
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
📸 Фото из {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||
🗓️ Доставка по расписанию — случайные фото из {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||
{%- for asset in assets %}
|
||||
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
||||
{%- endfor %}
|
||||
{%- endfor %}
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "notify-bridge-server"
|
||||
version = "0.2.4"
|
||||
version = "0.5.0"
|
||||
description = "Standalone Notify Bridge server — FastAPI REST API with SQLite database"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -28,6 +28,7 @@ dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"httpx>=0.27",
|
||||
"aioresponses>=0.7",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -35,3 +36,14 @@ notify-bridge = "notify_bridge_server.main:run"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/notify_bridge_server"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
# The default filter doesn't let SQLAlchemy warnings fail the suite, which
|
||||
# matters because our migrations emit a handful of deprecation warnings we
|
||||
# don't want to suppress at source.
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning:passlib",
|
||||
"ignore::DeprecationWarning:bcrypt",
|
||||
]
|
||||
|
||||
@@ -24,6 +24,10 @@ _SETTING_KEYS = {
|
||||
"telegram_asset_cache_max_entries": None, # LRU cap for both caches
|
||||
"supported_locales": None, # comma-separated locale codes
|
||||
"timezone": "NOTIFY_BRIDGE_TIMEZONE", # IANA tz (e.g. "Europe/Warsaw"); empty = UTC
|
||||
# Logging — applied live via apply_log_levels() when changed.
|
||||
"log_level": "NOTIFY_BRIDGE_LOG_LEVEL", # DEBUG/INFO/WARNING/ERROR
|
||||
"log_format": "NOTIFY_BRIDGE_LOG_FORMAT", # text|json (requires restart to switch)
|
||||
"log_levels": "NOTIFY_BRIDGE_LOG_LEVELS", # module=LEVEL,module2=LEVEL
|
||||
}
|
||||
|
||||
_DEFAULTS = {
|
||||
@@ -35,12 +39,20 @@ _DEFAULTS = {
|
||||
"telegram_asset_cache_max_entries": "5000",
|
||||
"supported_locales": "en,ru",
|
||||
"timezone": "UTC",
|
||||
"log_level": "INFO",
|
||||
"log_format": "text",
|
||||
"log_levels": "",
|
||||
}
|
||||
|
||||
# Settings whose changes require dropping in-memory Telegram caches so the
|
||||
# next dispatch rebuilds them with the new parameters. Files are preserved.
|
||||
_CACHE_SETTING_KEYS = {"telegram_cache_ttl_hours", "telegram_asset_cache_max_entries"}
|
||||
|
||||
# Settings that change logging behaviour. ``log_level`` + ``log_levels`` apply
|
||||
# live via apply_log_levels(); ``log_format`` requires a restart because
|
||||
# changing it means swapping the handler formatter entirely.
|
||||
_LOG_SETTING_KEYS = {"log_level", "log_levels", "log_format"}
|
||||
|
||||
|
||||
async def get_setting(session: AsyncSession, key: str) -> str:
|
||||
"""Read a setting from DB, falling back to env var then default."""
|
||||
@@ -56,12 +68,19 @@ async def get_setting(session: AsyncSession, key: str) -> str:
|
||||
|
||||
|
||||
class SettingsUpdate(BaseModel):
|
||||
# Numeric fields declared as int|str so clients can send either form.
|
||||
# Svelte's bind:value on <input type="number"> coerces to a JS number,
|
||||
# so the frontend sends ints for these; older/manual clients may send
|
||||
# strings. We normalize to str before persisting.
|
||||
external_url: str | None = None
|
||||
telegram_webhook_secret: str | None = None
|
||||
telegram_cache_ttl_hours: str | None = None
|
||||
telegram_asset_cache_max_entries: str | None = None
|
||||
telegram_cache_ttl_hours: int | str | None = None
|
||||
telegram_asset_cache_max_entries: int | str | None = None
|
||||
supported_locales: str | None = None
|
||||
timezone: str | None = None
|
||||
log_level: str | None = None
|
||||
log_format: str | None = None
|
||||
log_levels: str | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
@@ -90,16 +109,26 @@ async def update_settings(
|
||||
old_base_url = await get_setting(session, "external_url")
|
||||
old_secret = await get_setting(session, "telegram_webhook_secret")
|
||||
old_cache_values = {k: await get_setting(session, k) for k in _CACHE_SETTING_KEYS}
|
||||
old_timezone = await get_setting(session, "timezone")
|
||||
old_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||
|
||||
for key in _SETTING_KEYS:
|
||||
value = getattr(body, key, None)
|
||||
if value is None:
|
||||
continue
|
||||
value_str = str(value)
|
||||
# GET masks the webhook secret as "***<last4>" so the real value is
|
||||
# never exposed to the frontend. If the client sends the mask back
|
||||
# (which happens on every save, since bind:value holds whatever GET
|
||||
# returned), treat it as "unchanged" — otherwise we'd overwrite the
|
||||
# real secret with its mask, silently breaking webhook HMAC.
|
||||
if key == "telegram_webhook_secret" and value_str.startswith("***"):
|
||||
continue
|
||||
row = await session.get(AppSetting, key)
|
||||
if row:
|
||||
row.value = value
|
||||
row.value = value_str
|
||||
else:
|
||||
row = AppSetting(key=key, value=value)
|
||||
row = AppSetting(key=key, value=value_str)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
@@ -116,6 +145,33 @@ async def update_settings(
|
||||
|
||||
new_base_url = await get_setting(session, "external_url")
|
||||
new_secret = await get_setting(session, "telegram_webhook_secret")
|
||||
new_timezone = await get_setting(session, "timezone")
|
||||
new_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||
|
||||
# Apply live log-level changes (log_format still needs a restart).
|
||||
if (new_log_values["log_level"] != old_log_values["log_level"]
|
||||
or new_log_values["log_levels"] != old_log_values["log_levels"]):
|
||||
from ..logging_setup import apply_log_levels
|
||||
apply_log_levels(
|
||||
level=new_log_values["log_level"] or None,
|
||||
per_module_levels=new_log_values["log_levels"],
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Log levels updated: root=%s overrides=%r",
|
||||
new_log_values["log_level"], new_log_values["log_levels"],
|
||||
)
|
||||
if new_log_values["log_format"] != old_log_values["log_format"]:
|
||||
_LOGGER.warning(
|
||||
"log_format changed from %r to %r — restart the server for it to take effect",
|
||||
old_log_values["log_format"], new_log_values["log_format"],
|
||||
)
|
||||
|
||||
# Cron triggers freeze their timezone at construction time, so a tz change
|
||||
# has no effect until the jobs are rebuilt — do that here, before we
|
||||
# return success, so the UI reflects the actual schedule immediately.
|
||||
if new_timezone != old_timezone:
|
||||
from ..services.scheduler import reschedule_cron_jobs_for_timezone_change
|
||||
await reschedule_cron_jobs_for_timezone_change()
|
||||
|
||||
# Update webhook secret in the webhook handler module
|
||||
if new_secret != old_secret:
|
||||
@@ -178,7 +234,10 @@ async def _reregister_webhooks(
|
||||
if res.get("success"):
|
||||
_LOGGER.info("Re-registered webhook for bot %d (%s)", bot.id, bot.name)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Failed to re-register webhook for bot %d: %s",
|
||||
bot.id, res.get("error"),
|
||||
# Webhook re-register failure means the bot silently stops
|
||||
# delivering updates — this is operational visibility for an
|
||||
# admin, ERROR is appropriate.
|
||||
_LOGGER.error(
|
||||
"Failed to re-register webhook for bot %d (%s): %s",
|
||||
bot.id, bot.name, res.get("error"),
|
||||
)
|
||||
|
||||
@@ -93,7 +93,14 @@ async def update_email_bot(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
# Reject the masked value the GET response returns so the stored password
|
||||
# is preserved if the user saves without retyping it.
|
||||
if "smtp_password" in updates:
|
||||
pw = updates["smtp_password"]
|
||||
if isinstance(pw, str) and pw.startswith("***"):
|
||||
updates.pop("smtp_password")
|
||||
for field, value in updates.items():
|
||||
setattr(bot, field, value)
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
|
||||
@@ -7,6 +7,11 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import (
|
||||
UnsafeURLError,
|
||||
avalidate_outbound_url,
|
||||
)
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import MatrixBot, User
|
||||
@@ -33,6 +38,21 @@ class MatrixBotUpdate(BaseModel):
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
def _is_masked_secret(value: str | None) -> bool:
    """Tell whether ``value`` looks like a masked secret placeholder.

    A value counts as masked when it starts with ``***`` or contains
    ``...``. ``None`` and the empty string are never masked.
    """
    if not value:
        return False
    if value.startswith("***"):
        return True
    return "..." in value
|
||||
|
||||
|
||||
async def _validate_homeserver_url(url: str) -> None:
    """Validate ``url`` as an outbound target, failing with HTTP 400.

    Delegates to ``avalidate_outbound_url``. An ``UnsafeURLError`` from
    that check is re-raised as ``HTTPException(400)``, with the original
    error text in the detail and the original exception chained as cause.
    """
    try:
        await avalidate_outbound_url(url)
    except UnsafeURLError as err:
        detail = f"Invalid homeserver_url: {err}"
        raise HTTPException(status_code=400, detail=detail) from err
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_matrix_bots(
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -50,6 +70,7 @@ async def create_matrix_bot(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
await _validate_homeserver_url(body.homeserver_url)
|
||||
bot = MatrixBot(user_id=user.id, **body.model_dump())
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
@@ -74,7 +95,19 @@ async def update_matrix_bot(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
|
||||
# Re-validate homeserver_url whenever the client supplies a new one so
|
||||
# no private/loopback target can ever be saved, even via update.
|
||||
if "homeserver_url" in updates and updates["homeserver_url"]:
|
||||
await _validate_homeserver_url(updates["homeserver_url"])
|
||||
|
||||
# Never accept the masked placeholder the GET response returns. If the
|
||||
# client echoes it back, keep the stored secret.
|
||||
if "access_token" in updates and _is_masked_secret(updates["access_token"]):
|
||||
updates.pop("access_token")
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(bot, field, value)
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
@@ -108,15 +141,17 @@ async def test_matrix_bot(
|
||||
If room_id is not provided, just verifies the access token by calling /whoami.
|
||||
"""
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
# Defense-in-depth: even though create/update validate the URL, a bot row
|
||||
# written before this guard was added could still point at a blocked host.
|
||||
await _validate_homeserver_url(bot.homeserver_url)
|
||||
|
||||
import aiohttp
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
async with http.get(whoami_url, headers=headers, allow_redirects=False) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
@@ -126,7 +161,6 @@ async def test_matrix_bot(
|
||||
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from ..services.notifier import _get_test_message
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
@@ -148,7 +182,7 @@ def _response(bot: MatrixBot) -> dict:
|
||||
"name": bot.name,
|
||||
"icon": bot.icon,
|
||||
"homeserver_url": bot.homeserver_url,
|
||||
"access_token": f"{bot.access_token[:8]}...{bot.access_token[-4:]}" if len(bot.access_token) > 12 else "***",
|
||||
"access_token": f"***{bot.access_token[-4:]}" if len(bot.access_token) > 4 else "***",
|
||||
"display_name": bot.display_name,
|
||||
"created_at": bot.created_at.isoformat(),
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ from ..database.models import (
|
||||
User,
|
||||
)
|
||||
from ..services.notifier import send_test_notification
|
||||
from ..services.test_dispatch import dispatch_test_notification
|
||||
from ..services.manual_dispatch import dispatch_test_notification
|
||||
from ..services.scheduler import reschedule_immich_dispatch_jobs
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -118,6 +119,7 @@ async def create_notification_tracker_target(
|
||||
session.add(tt)
|
||||
await session.commit()
|
||||
await session.refresh(tt)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return await _tt_response(session, tt)
|
||||
|
||||
|
||||
@@ -164,6 +166,7 @@ async def update_notification_tracker_target(
|
||||
session.add(tt)
|
||||
await session.commit()
|
||||
await session.refresh(tt)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return await _tt_response(session, tt)
|
||||
|
||||
|
||||
@@ -181,6 +184,7 @@ async def delete_notification_tracker_target(
|
||||
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
||||
await session.delete(tt)
|
||||
await session.commit()
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
|
||||
|
||||
@router.post("/{tracker_target_id}/test/{test_type}")
|
||||
|
||||
@@ -11,13 +11,18 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import (
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerState,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
User,
|
||||
)
|
||||
from ..services.scheduler import schedule_tracker, unschedule_tracker
|
||||
from ..services.scheduler import (
|
||||
reschedule_immich_dispatch_jobs,
|
||||
schedule_tracker,
|
||||
unschedule_tracker,
|
||||
)
|
||||
from .helpers import get_owned_entity
|
||||
from .notification_tracker_targets import _tt_response
|
||||
|
||||
@@ -54,11 +59,79 @@ async def list_notification_trackers(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
# Batched loader: pull trackers, then all their tracker-target links in
|
||||
# a single query, then the referenced targets in a single query. Avoids
|
||||
# the old 1 + N + N*M pattern that ran ~60 round-trips for 10 trackers.
|
||||
result = await session.exec(
|
||||
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
||||
)
|
||||
trackers = result.all()
|
||||
return [await _tracker_response(session, t) for t in trackers]
|
||||
trackers = list(result.all())
|
||||
if not trackers:
|
||||
return []
|
||||
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
tt_result = await session.exec(
|
||||
select(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
tt_rows = list(tt_result.all())
|
||||
|
||||
target_ids = {tt.target_id for tt in tt_rows}
|
||||
targets_by_id: dict[int, NotificationTarget] = {}
|
||||
if target_ids:
|
||||
tgt_result = await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.id.in_(target_ids))
|
||||
)
|
||||
targets_by_id = {t.id: t for t in tgt_result.all()}
|
||||
|
||||
tts_by_tracker: dict[int, list[NotificationTrackerTarget]] = {}
|
||||
for tt in tt_rows:
|
||||
tts_by_tracker.setdefault(tt.tracker_id, []).append(tt)
|
||||
|
||||
return [
|
||||
_build_tracker_response(t, tts_by_tracker.get(t.id, []), targets_by_id)
|
||||
for t in trackers
|
||||
]
|
||||
|
||||
|
||||
def _build_tracker_response(
|
||||
t: NotificationTracker,
|
||||
tts: list[NotificationTrackerTarget],
|
||||
targets_by_id: dict[int, NotificationTarget],
|
||||
) -> dict:
|
||||
"""In-memory assembler for a tracker + its pre-loaded links/targets."""
|
||||
tracker_targets = []
|
||||
for tt in tts:
|
||||
target = targets_by_id.get(tt.target_id)
|
||||
tracker_targets.append({
|
||||
"id": tt.id,
|
||||
"tracker_id": tt.tracker_id,
|
||||
"target_id": tt.target_id,
|
||||
"target_name": target.name if target else None,
|
||||
"target_type": target.type if target else None,
|
||||
"target_icon": target.icon if target else None,
|
||||
"tracking_config_id": tt.tracking_config_id,
|
||||
"template_config_id": tt.template_config_id,
|
||||
"enabled": tt.enabled,
|
||||
"quiet_hours_start": tt.quiet_hours_start,
|
||||
"quiet_hours_end": tt.quiet_hours_end,
|
||||
"created_at": tt.created_at.isoformat(),
|
||||
})
|
||||
return {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"icon": t.icon,
|
||||
"provider_id": t.provider_id,
|
||||
"collection_ids": t.collection_ids,
|
||||
"scan_interval": t.scan_interval,
|
||||
"batch_duration": t.batch_duration,
|
||||
"default_tracking_config_id": t.default_tracking_config_id,
|
||||
"default_template_config_id": t.default_template_config_id,
|
||||
"enabled": t.enabled,
|
||||
"tracker_targets": tracker_targets,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
@@ -77,6 +150,7 @@ async def create_notification_tracker(
|
||||
await session.refresh(tracker)
|
||||
if tracker.enabled:
|
||||
await schedule_tracker(tracker.id, tracker.scan_interval)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
@@ -107,6 +181,7 @@ async def update_notification_tracker(
|
||||
await schedule_tracker(tracker.id, tracker.scan_interval)
|
||||
else:
|
||||
await unschedule_tracker(tracker.id)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
@@ -139,6 +214,7 @@ async def delete_notification_tracker(
|
||||
await session.delete(tracker)
|
||||
await session.commit()
|
||||
await unschedule_tracker(tracker_id)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
|
||||
|
||||
@router.post("/{tracker_id}/trigger")
|
||||
|
||||
@@ -306,16 +306,31 @@ async def update_provider(
|
||||
if body.icon is not None:
|
||||
provider.icon = body.icon
|
||||
|
||||
config_changed = body.config is not None and body.config != provider.config
|
||||
if body.config is not None:
|
||||
_validate_provider_config(provider.type, body.config)
|
||||
provider.config = body.config
|
||||
# Merge rather than replace so the masked secrets the frontend
|
||||
# receives on GET cannot silently nuke the stored values when the
|
||||
# user saves without re-entering them. Any field that still carries
|
||||
# our mask placeholder ("***…") is dropped from the incoming body.
|
||||
incoming = dict(body.config)
|
||||
for secret_field in (
|
||||
"api_key", "api_token", "webhook_secret", "password",
|
||||
"client_secret", "refresh_token",
|
||||
):
|
||||
value = incoming.get(secret_field)
|
||||
if isinstance(value, str) and value.startswith("***"):
|
||||
incoming.pop(secret_field, None)
|
||||
new_config = {**provider.config, **incoming}
|
||||
_validate_provider_config(provider.type, new_config)
|
||||
config_changed = new_config != provider.config
|
||||
provider.config = new_config
|
||||
|
||||
# Re-validate connection when config changes for known provider types
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {
|
||||
**provider.config,
|
||||
"external_domain": test_result["external_domain"],
|
||||
}
|
||||
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
|
||||
@@ -242,6 +242,8 @@ async def get_template_variables(
|
||||
"current_date": "Current date (formatted)",
|
||||
"current_time": "Current time (formatted)",
|
||||
"current_datetime": "Current date and time (formatted)",
|
||||
"weekday": "Day of the week (Monday..Sunday)",
|
||||
"timezone": "IANA timezone used for current_date/time",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import TrackingConfig, User
|
||||
from ..services.scheduler import reschedule_immich_dispatch_jobs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -127,6 +128,8 @@ async def create_config(
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
if config.provider_type == "immich":
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return _response(config)
|
||||
|
||||
|
||||
@@ -152,6 +155,8 @@ async def update_config(
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
if config.provider_type == "immich":
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
return _response(config)
|
||||
|
||||
|
||||
@@ -164,8 +169,11 @@ async def delete_config(
|
||||
from .delete_protection import check_tracking_config, raise_if_used
|
||||
config = await _get(session, config_id, user.id)
|
||||
raise_if_used(await check_tracking_config(session, config.id), config.name)
|
||||
provider_type = config.provider_type
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
if provider_type == "immich":
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
|
||||
|
||||
def _response(c: TrackingConfig) -> dict:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""User management API routes (admin only)."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -14,6 +15,15 @@ from ..auth.dependencies import require_admin
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import User
|
||||
|
||||
|
||||
async def _hash_password(password: str) -> str:
    """Hash ``password`` with bcrypt in a worker thread.

    The hashing runs via ``asyncio.to_thread`` so it does not execute on the
    event loop. Mirrors the helper of the same name in auth/routes.py.

    Returns:
        The bcrypt hash decoded to ``str``.
    """

    def _hash_in_thread() -> str:
        raw = password.encode()
        digest = bcrypt.hashpw(raw, bcrypt.gensalt())
        return digest.decode()

    hashed = await asyncio.to_thread(_hash_in_thread)
    return hashed
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/users", tags=["users"])
|
||||
@@ -36,8 +46,12 @@ async def list_users(
|
||||
admin: User = Depends(require_admin),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all users (admin only)."""
|
||||
result = await session.exec(select(User))
|
||||
"""List all users (admin only).
|
||||
|
||||
Excludes the internal ``__system__`` placeholder (id=0) used as the
|
||||
owner of default templates/configs — it is never a real account.
|
||||
"""
|
||||
result = await session.exec(select(User).where(User.id != 0))
|
||||
return [
|
||||
{"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()}
|
||||
for u in result.all()
|
||||
@@ -61,7 +75,7 @@ async def create_user(
|
||||
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode(),
|
||||
hashed_password=await _hash_password(body.password),
|
||||
role=body.role if body.role in ("admin", "user") else "user",
|
||||
)
|
||||
session.add(user)
|
||||
@@ -162,7 +176,7 @@ async def reset_user_password(
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
|
||||
user.hashed_password = await _hash_password(body.new_password)
|
||||
# Invalidate all prior JWTs issued for this user — matches the self-serve
|
||||
# password-change path in auth/routes.py.
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
|
||||
@@ -37,6 +37,42 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
|
||||
|
||||
# Hard cap on inbound webhook body size (1 MB, i.e. 1,000,000 bytes, is far larger than anything
|
||||
# legitimate providers send and keeps the worst-case memory footprint bounded
|
||||
# when a malicious peer lies about Content-Length or streams slowly).
|
||||
_MAX_WEBHOOK_BODY_BYTES = 1_000_000
|
||||
|
||||
|
||||
async def _read_bounded_body(request: Request, limit: int = _MAX_WEBHOOK_BODY_BYTES) -> bytes:
    """Read the request body, refusing anything larger than ``limit`` bytes.

    Two independent checks are applied:

    1. If a ``Content-Length`` header is present, its declared size is
       compared against ``limit`` before any body bytes are read.
    2. The body is then consumed from ``request.stream()`` chunk by chunk and
       the running total is checked, so a peer whose header understates the
       real size is still rejected.

    Args:
        request: The incoming request whose body should be read.
        limit: Maximum number of body bytes accepted.

    Returns:
        The complete body as ``bytes``.

    Raises:
        HTTPException: 413 when the declared or received size exceeds
            ``limit``; 400 when ``Content-Length`` is not an integer.
    """
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            declared_size = int(content_length)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid Content-Length")
        if declared_size > limit:
            raise HTTPException(
                status_code=413,
                detail=f"Payload too large (max {limit} bytes)",
            )

    received = bytearray()
    async for piece in request.stream():
        # Check the total before buffering the chunk that would exceed it.
        if len(received) + len(piece) > limit:
            raise HTTPException(
                status_code=413,
                detail=f"Payload too large (max {limit} bytes)",
            )
        received += piece
    return bytes(received)
|
||||
|
||||
|
||||
async def _get_provider_by_token(
|
||||
session: AsyncSession, token: str, expected_type: str,
|
||||
@@ -169,7 +205,8 @@ async def _dispatch_webhook_event(
|
||||
))
|
||||
|
||||
# Dispatch to targets
|
||||
dispatcher = NotificationDispatcher()
|
||||
from ..services.http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
target_configs = _build_target_configs(event, link_data, provider_config, app_tz)
|
||||
if target_configs:
|
||||
results = await dispatcher.dispatch(event, target_configs)
|
||||
@@ -203,7 +240,7 @@ async def gitea_webhook(token: str, request: Request):
|
||||
webhook_secret = (provider.config or {}).get("webhook_secret", "")
|
||||
|
||||
# Read raw body for HMAC check
|
||||
raw_body = await request.body()
|
||||
raw_body = await _read_bounded_body(request)
|
||||
|
||||
if not webhook_secret:
|
||||
raise HTTPException(
|
||||
@@ -221,8 +258,8 @@ async def gitea_webhook(token: str, request: Request):
|
||||
return {"ok": True, "skipped": "no event header"}
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||
|
||||
event = parse_gitea_webhook(event_header, payload, provider.name)
|
||||
@@ -280,10 +317,10 @@ async def planka_webhook(token: str, request: Request):
|
||||
if not _verify_planka_token(webhook_secret, request):
|
||||
raise HTTPException(status_code=403, detail="Invalid token")
|
||||
|
||||
# Parse payload
|
||||
# Parse payload from the bounded raw_body we already read.
|
||||
try:
|
||||
payload = await request.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||
|
||||
event_type = payload.get("type", "")
|
||||
@@ -446,23 +483,22 @@ async def generic_webhook(token: str, request: Request):
|
||||
store_payloads = provider_config.get("store_payloads", True)
|
||||
max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100)
|
||||
|
||||
raw_body = await request.body()
|
||||
raw_body = await _read_bounded_body(request)
|
||||
|
||||
# Enforce payload size limit BEFORE parsing JSON
|
||||
if len(raw_body) > 1_000_000:
|
||||
raise HTTPException(status_code=413, detail="Payload too large (max 1 MB)")
|
||||
# Bounded read above already enforces the size cap; no need to re-check.
|
||||
|
||||
if not _verify_generic_webhook_auth(provider_config, request, raw_body):
|
||||
raise HTTPException(status_code=403, detail="Authentication failed")
|
||||
|
||||
safe_headers = _filter_headers(dict(request.headers))
|
||||
|
||||
# Parse JSON payload
|
||||
# Parse JSON payload from the already-bounded raw_body (request.body()
|
||||
# has been consumed, so request.json() is no longer usable here).
|
||||
try:
|
||||
payload = await request.json()
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Payload must be a JSON object")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
if store_payloads:
|
||||
async with AsyncSession(get_engine()) as log_session:
|
||||
await _save_webhook_log(
|
||||
|
||||
@@ -7,30 +7,51 @@ import jwt
|
||||
from ..config import settings
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
_LEEWAY_SECONDS = 10
|
||||
|
||||
|
||||
def _now_utc() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
    """Issue a signed access JWT for ``user_id``.

    Args:
        user_id: Stored in the ``sub`` claim as a string.
        role: Stored in the ``role`` claim.
        token_version: Stored in the ``ver`` claim.

    Returns:
        The encoded token, signed with ``settings.secret_key`` using
        ``ALGORITHM``.
    """
    # One timestamp feeds both ``iat`` and ``exp`` so the token lifetime is
    # exactly ``access_token_expire_minutes``.
    now = _now_utc()
    payload = {
        "iss": settings.jwt_issuer,
        "aud": settings.jwt_audience,
        "sub": str(user_id),
        "role": role,
        "type": "access",
        "ver": token_version,
        "iat": now,
        "exp": now + timedelta(minutes=settings.access_token_expire_minutes),
    }
    return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
    """Issue a signed refresh JWT for ``user_id``.

    Args:
        user_id: Stored in the ``sub`` claim as a string.
        token_version: Stored in the ``ver`` claim.

    Returns:
        The encoded token, signed with ``settings.secret_key`` using
        ``ALGORITHM``.
    """
    # One timestamp feeds both ``iat`` and ``exp`` so the token lifetime is
    # exactly ``refresh_token_expire_days``.
    now = _now_utc()
    payload = {
        "iss": settings.jwt_issuer,
        "aud": settings.jwt_audience,
        "sub": str(user_id),
        "type": "refresh",
        "ver": token_version,
        "iat": now,
        "exp": now + timedelta(days=settings.refresh_token_expire_days),
    }
    return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
    """Verify and decode a JWT produced by this module.

    The signature is checked with ``settings.secret_key`` and ``ALGORITHM``.
    The ``aud`` and ``iss`` claims must match ``settings.jwt_audience`` and
    ``settings.jwt_issuer``. The claims ``exp``, ``sub``, ``iss``, ``aud``
    and ``type`` must all be present. ``_LEEWAY_SECONDS`` of clock skew is
    tolerated for time-based claims.

    Returns:
        The decoded claims.

    Raises:
        Whatever ``jwt.decode`` raises for an invalid, expired or incomplete
        token. NOTE(review): presumably ``jwt.InvalidTokenError`` subclasses;
        confirm against the callers' except clauses.
    """
    # This must be the only return. An earlier unvalidated
    # ``jwt.decode(token, key, algorithms=...)`` return would bypass the
    # audience/issuer/required-claim checks below.
    return jwt.decode(
        token,
        settings.secret_key,
        algorithms=[ALGORITHM],
        audience=settings.jwt_audience,
        issuer=settings.jwt_issuer,
        leeway=_LEEWAY_SECONDS,
        options={"require": ["exp", "sub", "iss", "aud", "type"]},
    )
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Authentication API routes."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
@@ -16,7 +18,9 @@ from .jwt import create_access_token, create_refresh_token, decode_token
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
# Default rate limit applied by SlowAPIMiddleware to every route that does NOT
|
||||
# specify its own @limiter.limit(...) — protects against blanket abuse.
|
||||
limiter = Limiter(key_func=get_remote_address, default_limits=["600/minute"])
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
|
||||
@@ -45,27 +49,52 @@ class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
async def _hash_password(password: str) -> str:
    """Hash ``password`` with bcrypt in a worker thread.

    ``bcrypt.hashpw`` is CPU-bound (the original note puts it at roughly
    200-500 ms), so it is dispatched through ``asyncio.to_thread`` and never
    runs on the event loop. Callers must ``await`` the result.

    Returns:
        The bcrypt hash decoded to ``str``.
    """

    def _work() -> str:
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

    return await asyncio.to_thread(_work)
|
||||
|
||||
|
||||
async def _verify_password(password: str, hashed: str) -> bool:
    """Check ``password`` against a bcrypt ``hashed`` value in a worker thread.

    The comparison runs through ``asyncio.to_thread`` so it does not execute
    on the event loop. Callers must ``await`` the result.

    Returns:
        ``True`` when the password matches. ``False`` when it does not, or
        when ``bcrypt.checkpw`` raises ``ValueError`` for a malformed hash.
    """

    def _work() -> bool:
        try:
            return bcrypt.checkpw(password.encode(), hashed.encode())
        except ValueError:
            # Malformed hash in DB — treat as mismatch, never raise to caller.
            return False

    return await asyncio.to_thread(_work)
|
||||
|
||||
|
||||
@router.post("/setup", response_model=TokenResponse)
|
||||
@limiter.limit("3/minute")
|
||||
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
|
||||
result = await session.exec(select(func.count()).select_from(User))
|
||||
count = result.one()
|
||||
if count > 0:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.")
|
||||
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
# Compute hash BEFORE opening the transaction so we don't hold a writer lock
|
||||
# during the CPU-bound bcrypt work.
|
||||
hashed = await _hash_password(body.password)
|
||||
|
||||
# Serialize setup via an INSERT-inside-transaction-with-count-guard.
|
||||
# SQLite's writer lock plus the count check inside the transaction closes
|
||||
# the TOCTOU window between two concurrent POSTs. We ignore id=0 — that's
|
||||
# the internal "__system__" placeholder used for ownership of default
|
||||
# templates, never a real admin.
|
||||
async with session.begin():
|
||||
result = await session.exec(
|
||||
select(func.count()).select_from(User).where(User.id != 0)
|
||||
)
|
||||
count = result.one()
|
||||
if count > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Setup already completed.",
|
||||
)
|
||||
user = User(username=body.username, hashed_password=hashed, role="admin")
|
||||
session.add(user)
|
||||
await session.refresh(user)
|
||||
|
||||
return TokenResponse(
|
||||
@@ -79,7 +108,13 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
|
||||
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
|
||||
result = await session.exec(select(User).where(User.username == body.username))
|
||||
user = result.first()
|
||||
if not user or not _verify_password(body.password, user.hashed_password):
|
||||
# Always run a bcrypt verification to keep the response time constant,
|
||||
# preventing username-enumeration via timing side channel.
|
||||
password_ok = await _verify_password(
|
||||
body.password,
|
||||
user.hashed_password if user else "$2b$12$" + "a" * 53,
|
||||
)
|
||||
if not user or not password_ok:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
||||
|
||||
return TokenResponse(
|
||||
@@ -124,16 +159,18 @@ class PasswordChangeRequest(BaseModel):
|
||||
|
||||
|
||||
@router.put("/password")
|
||||
@limiter.limit("10/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
body: PasswordChangeRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
if not _verify_password(body.current_password, user.hashed_password):
|
||||
if not await _verify_password(body.current_password, user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
||||
user.hashed_password = _hash_password(body.new_password)
|
||||
user.hashed_password = await _hash_password(body.new_password)
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -141,7 +178,12 @@ async def change_password(
|
||||
|
||||
|
||||
@router.get("/needs-setup")
@limiter.limit("30/minute")
async def needs_setup(request: Request, session: AsyncSession = Depends(get_session)):
    """Report whether first-run setup is still required.

    Args:
        request: The incoming request. NOTE(review): unused in the body;
            presumably needed by the ``limiter.limit`` decorator — confirm.
        session: Database session used for the user count.

    Returns:
        ``{"needs_setup": True}`` when no user other than id 0 exists,
        otherwise ``{"needs_setup": False}``.
    """
    # Exclude the internal __system__ placeholder (id=0) from the count so
    # a fresh install still reports needs_setup=True.
    result = await session.exec(
        select(func.count()).select_from(User).where(User.id != 0)
    )
    count = result.one()
    return {"needs_setup": count == 0}
|
||||
|
||||
@@ -41,6 +41,36 @@ _rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
|
||||
# Maximum responses per command to avoid Telegram rate limits
|
||||
_MAX_RESPONSES_PER_COMMAND = 5
|
||||
|
||||
# Commands that fetch assets from the service provider and usually reply
|
||||
# with media — "uploading photo" is the accurate UX hint while we wait on
|
||||
# the provider API + Telegram upload.
|
||||
_UPLOAD_PHOTO_COMMANDS = frozenset({
|
||||
"latest", "random", "favorites", "memory",
|
||||
"search", "find", "person", "place",
|
||||
})
|
||||
|
||||
# Commands that fetch from the provider but reply with text only.
|
||||
# "typing" is accurate; we still want an indicator because the fetch is slow.
|
||||
_TYPING_COMMANDS = frozenset({"summary"})
|
||||
|
||||
|
||||
def classify_command_chat_action(text: str) -> str | None:
    """Pick the Telegram chat-action hint for the command in ``text``.

    Only the command name returned by ``parse_command`` is considered.
    Commands listed in ``_UPLOAD_PHOTO_COMMANDS`` map to ``"upload_photo"``
    and those in ``_TYPING_COMMANDS`` map to ``"typing"``. Everything else,
    including text with no command and the fast DB-only commands such as
    ``/status``, ``/albums``, ``/events`` and ``/people``, yields ``None``
    so no indicator is shown.
    """
    command, _, _ = parse_command(text)
    if command:
        if command in _UPLOAD_PHOTO_COMMANDS:
            return "upload_photo"
        if command in _TYPING_COMMANDS:
            return "typing"
    return None
|
||||
|
||||
|
||||
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
|
||||
"""Check rate limit. Returns seconds to wait, or None if OK."""
|
||||
@@ -78,13 +108,18 @@ def _render_cmd_template(
|
||||
"""Render a locale-aware command template. Falls back to 'en'."""
|
||||
template_str = _resolve_template(templates, slot_name, locale)
|
||||
if not template_str:
|
||||
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||
# Missing template = user sees "[No template: X]" — this is an ERROR,
|
||||
# not a warning. Broken replies must stand out in production logs.
|
||||
_LOGGER.error("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||
return f"[No template: {slot_name}]"
|
||||
try:
|
||||
tmpl = _compile_template(template_str)
|
||||
return tmpl.render(**context)
|
||||
except Exception as e:
|
||||
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
|
||||
except Exception:
|
||||
_LOGGER.error(
|
||||
"Failed to render command template '%s' locale=%s — user will see a broken reply",
|
||||
slot_name, locale, exc_info=True,
|
||||
)
|
||||
return f"[Template error: {slot_name}]"
|
||||
|
||||
|
||||
@@ -266,6 +301,10 @@ async def handle_command(
|
||||
# Rate limit check (once per command, shared across all trackers)
|
||||
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
||||
if wait is not None:
|
||||
_LOGGER.info(
|
||||
"Rate-limited /%s for bot=%d chat=%s — %ds cooldown remaining",
|
||||
cmd, bot.id, chat_id, wait,
|
||||
)
|
||||
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
|
||||
return [CommandResponse(text=text_resp)]
|
||||
|
||||
@@ -292,8 +331,8 @@ async def handle_command(
|
||||
for tracker, config, provider, listener in ctx_tuples:
|
||||
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
|
||||
_LOGGER.warning(
|
||||
"Truncated command responses at %d for bot %d cmd /%s",
|
||||
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
|
||||
"Truncated command responses at %d for bot=%d chat=%s cmd=/%s (listener context size=%d)",
|
||||
_MAX_RESPONSES_PER_COMMAND, bot.id, chat_id, cmd, len(ctx_tuples),
|
||||
)
|
||||
break
|
||||
|
||||
@@ -367,25 +406,33 @@ async def send_reply(
|
||||
bot_token: str, chat_id: str, text: str, reply_to_message_id: int | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
"""Send a text reply via TelegramClient.
|
||||
"""Send a text reply to a chat.
|
||||
|
||||
Command responses are listings (albums, people, events, ...) that embed
|
||||
multiple links; Telegram's default behavior of rendering a preview of
|
||||
the first URL is almost never what the user wants and clashes with the
|
||||
"Disable link previews" toggle operators set on their Telegram target.
|
||||
We always pass ``disable_web_page_preview=True`` here.
|
||||
Thin wrapper that goes through the single ``services.telegram_send``
|
||||
entry point so commands and notifications share one routine — same
|
||||
HTTP session pool, same file_id caches.
|
||||
|
||||
Command responses are listings (albums, people, events, ...) that
|
||||
embed multiple links; Telegram's default behavior of rendering a
|
||||
preview of the first URL is almost never what the user wants and
|
||||
clashes with the "Disable link previews" toggle operators set on
|
||||
their Telegram target. We always pass
|
||||
``disable_web_page_preview=True`` here.
|
||||
"""
|
||||
if session is None:
|
||||
from ..services.http_session import get_http_session
|
||||
session = await get_http_session()
|
||||
client = TelegramClient(session, bot_token)
|
||||
result = await client.send_message(
|
||||
chat_id, text,
|
||||
from ..services.telegram_send import send_telegram_message
|
||||
|
||||
result = await send_telegram_message(
|
||||
bot_token, chat_id, text,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
disable_web_page_preview=True,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
|
||||
# User-visible failure: the bot's reply never reached the chat.
|
||||
_LOGGER.error(
|
||||
"Telegram reply failed (chat=%s reply_to=%s len=%d): code=%s error=%r",
|
||||
chat_id, reply_to_message_id, len(text or ""),
|
||||
result.get("error_code"), result.get("error"),
|
||||
)
|
||||
|
||||
|
||||
async def send_media_group(
|
||||
@@ -393,35 +440,47 @@ async def send_media_group(
|
||||
reply_to_message_id: int | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
"""Send media items via TelegramClient.send_notification."""
|
||||
"""Send media items via the shared Telegram routine.
|
||||
|
||||
``media_items`` must already be in TelegramClient asset format — each
|
||||
entry contains ``type`` (``"photo"``/``"video"``/``"document"``),
|
||||
``url``, optional ``cache_key``, and optional ``headers``. Provider
|
||||
command handlers build this format via
|
||||
``build_telegram_asset_entry`` — the same helper the notification
|
||||
dispatcher uses — so videos keep their ``"video"`` type and point at
|
||||
a real video URL instead of a still thumbnail.
|
||||
|
||||
Uses ``services.telegram_send.send_telegram_media`` so the URL cache
|
||||
and asset cache are wired in exactly like the notification path.
|
||||
Repeated ``/latest`` / ``/random`` commands that match previously-sent
|
||||
assets hit the cache and skip the re-upload.
|
||||
"""
|
||||
if not media_items:
|
||||
# This is what happened in the /random blind spot: the text reply
|
||||
# was sent, but the media follow-up was silently skipped because
|
||||
# the caller passed an empty media list. Surface it so we can see
|
||||
# it in the log and correlate with the text message.
|
||||
_LOGGER.warning(
|
||||
"send_media_group called with 0 items (chat=%s reply_to=%s) — no media will be delivered",
|
||||
chat_id, reply_to_message_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Convert command handler media format to TelegramClient asset format
|
||||
assets = []
|
||||
for item in media_items:
|
||||
assets.append({
|
||||
"type": "photo",
|
||||
"url": item.get("thumbnail_url", ""),
|
||||
"cache_key": item.get("asset_id", ""),
|
||||
"headers": {"x-api-key": item.get("api_key", "")},
|
||||
})
|
||||
from ..services.telegram_send import send_telegram_media
|
||||
|
||||
# Build caption from first item
|
||||
captions = [item.get("caption", "") for item in media_items if item.get("caption")]
|
||||
caption = "\n".join(captions) if captions else None
|
||||
|
||||
if session is None:
|
||||
from ..services.http_session import get_http_session
|
||||
session = await get_http_session()
|
||||
client = TelegramClient(session, bot_token)
|
||||
result = await client.send_notification(
|
||||
chat_id, assets=assets, caption=caption,
|
||||
result = await send_telegram_media(
|
||||
bot_token, chat_id, media_items,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
chat_action=None,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
|
||||
# User-visible failure: media promised by the text reply never arrived.
|
||||
_LOGGER.error(
|
||||
"Telegram media group failed (chat=%s items=%d reply_to=%s): code=%s error=%r failed_at_chunk=%s",
|
||||
chat_id, len(media_items), reply_to_message_id,
|
||||
result.get("error_code"), result.get("error"),
|
||||
result.get("failed_at_chunk"),
|
||||
)
|
||||
|
||||
|
||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
|
||||
@@ -6,7 +6,11 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from notify_bridge_core.providers.immich.asset_utils import get_public_url
|
||||
from notify_bridge_core.notifications.telegram.media import build_telegram_asset_entry
|
||||
from notify_bridge_core.providers.immich.asset_utils import (
|
||||
build_asset_media_urls,
|
||||
get_public_url,
|
||||
)
|
||||
|
||||
from ..handler import _render_cmd_template
|
||||
|
||||
@@ -74,13 +78,16 @@ def build_asset_dict(
|
||||
) -> dict[str, Any]:
|
||||
"""Build a rich asset dict for command templates from an ImmichAssetInfo or raw dict."""
|
||||
if isinstance(asset, dict):
|
||||
# Immich raw search responses nest geo under exifInfo — pull it out so
|
||||
# templates can use flat asset.city / asset.country.
|
||||
exif = asset.get("exifInfo") or {}
|
||||
d = {
|
||||
"id": asset.get("id", ""),
|
||||
"originalFileName": asset.get("originalFileName", asset.get("filename", "")),
|
||||
"type": asset.get("type", "IMAGE"),
|
||||
"createdAt": asset.get("createdAt", asset.get("created_at", asset.get("fileCreatedAt", ""))),
|
||||
"city": asset.get("city", ""),
|
||||
"country": asset.get("country", ""),
|
||||
"city": asset.get("city") or exif.get("city") or "",
|
||||
"country": asset.get("country") or exif.get("country") or "",
|
||||
"is_favorite": asset.get("is_favorite", asset.get("isFavorite", False)),
|
||||
"public_url": asset.get("public_url", public_url),
|
||||
}
|
||||
@@ -123,16 +130,47 @@ def _format_assets(
|
||||
})
|
||||
|
||||
if response_mode == "media":
|
||||
# Reuse the same URL rule (build_asset_media_urls) and entry builder
|
||||
# (build_telegram_asset_entry) as the notification dispatcher so both
|
||||
# paths agree on video → /video/playback and photo → thumbnail. When
|
||||
# these diverged, Telegram rendered a still JPEG for each video in
|
||||
# the media group instead of the real clip.
|
||||
#
|
||||
# We deliberately do NOT pass ``cache_key`` here. TelegramClient
|
||||
# derives it from the URL as ``<host>:<uuid>`` — identical to what
|
||||
# the notification dispatcher produces via extract_asset_id_from_url.
|
||||
# Passing the bare UUID would put command writes in a separate
|
||||
# namespace from notification writes, so neither path could hit the
|
||||
# other's cached file_ids (which is what made the cache look empty
|
||||
# from the WebUI after running /random).
|
||||
media_items: list[dict[str, Any]] = []
|
||||
dropped = 0
|
||||
for asset in assets:
|
||||
asset_id = asset.get("id", "")
|
||||
media_items.append({
|
||||
"type": "photo",
|
||||
"asset_id": asset_id,
|
||||
"caption": "",
|
||||
"thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview",
|
||||
"api_key": client.api_key,
|
||||
})
|
||||
asset_type = (asset.get("type") or "").upper()
|
||||
preview_url, _ = build_asset_media_urls(client.url, asset_id, asset_type)
|
||||
entry = build_telegram_asset_entry(
|
||||
url=preview_url,
|
||||
media_type="video" if asset_type == "VIDEO" else "image",
|
||||
api_key=client.api_key,
|
||||
internal_url=client.url,
|
||||
)
|
||||
if entry is not None:
|
||||
media_items.append(entry)
|
||||
else:
|
||||
dropped += 1
|
||||
_LOGGER.warning(
|
||||
"Dropped asset from /%s media payload: id=%s type=%s (empty preview URL)",
|
||||
cmd, asset_id, asset_type,
|
||||
)
|
||||
if not media_items and assets:
|
||||
# All assets were filtered out before reaching Telegram. The user
|
||||
# will see the text reply but no media — surface it here so the
|
||||
# log shows WHY the media group ended up empty.
|
||||
_LOGGER.warning(
|
||||
"/%s media payload empty: %d asset(s) in, 0 out (all dropped)",
|
||||
cmd, len(assets),
|
||||
)
|
||||
# Return text message + media items — text is sent first, media as reply
|
||||
return {"text": text, "media": media_items}
|
||||
|
||||
|
||||
@@ -143,7 +143,16 @@ async def _cmd_immich(
|
||||
# chat). ``None`` = no filter (rare); empty set = show nothing (common
|
||||
# when the chat has no tracker routing).
|
||||
if allowed_album_ids is not None:
|
||||
before = len(all_album_ids)
|
||||
all_album_ids = [aid for aid in all_album_ids if aid in allowed_album_ids]
|
||||
if not all_album_ids:
|
||||
# A command that sees zero albums is a routing/tracker config issue
|
||||
# the operator needs to notice — otherwise the user gets
|
||||
# "no results" with no hint at why.
|
||||
_LOGGER.info(
|
||||
"Command /%s has empty album scope for provider=%d (had %d trackers, chat scope allowed %d)",
|
||||
cmd, provider.id, before, len(allowed_album_ids),
|
||||
)
|
||||
|
||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||
|
||||
|
||||
@@ -5,17 +5,14 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from ..handler import _render_cmd_template
|
||||
from .common import _format_assets
|
||||
from .common import _format_assets, build_asset_dict
|
||||
|
||||
|
||||
def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]:
|
||||
"""Add public_url to assets from the pre-built map. Returns new list without mutating inputs."""
|
||||
if not asset_public_urls:
|
||||
return assets
|
||||
"""Normalize raw Immich assets and attach public_url from the pre-built map."""
|
||||
pub = asset_public_urls or {}
|
||||
return [
|
||||
{**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")}
|
||||
if asset.get("id", "") in asset_public_urls and not asset.get("public_url")
|
||||
else asset
|
||||
build_asset_dict(asset, public_url=pub.get(asset.get("id", ""), ""))
|
||||
for asset in assets
|
||||
]
|
||||
|
||||
|
||||
@@ -4,19 +4,23 @@ from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import TelegramBot, TelegramChat
|
||||
from ..services.telegram import save_chat_from_webhook
|
||||
from ..services.telegram_send import telegram_chat_action
|
||||
from .base import CommandResponse
|
||||
from .handler import handle_command, send_media_group, send_reply
|
||||
from .handler import classify_command_chat_action, handle_command, send_media_group, send_reply
|
||||
from .parser import parse_command
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -92,17 +96,62 @@ async def telegram_webhook(
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
_LOGGER.info(
|
||||
"Command ignored — commands disabled for bot=%s chat=%s text=%r",
|
||||
bot_id, chat_id, text[:64],
|
||||
)
|
||||
return {"ok": True, "skipped": "commands_disabled"}
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
return {"ok": True}
|
||||
|
||||
cmd_name, _, _ = parse_command(text)
|
||||
update_id = update.get("update_id")
|
||||
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||
|
||||
with bind_log_context(
|
||||
request_id=request_id,
|
||||
command=cmd_name or "-",
|
||||
chat_id=chat_id,
|
||||
bot_id=bot_id,
|
||||
):
|
||||
started = time.monotonic()
|
||||
_LOGGER.info("Command received: /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||
try:
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
):
|
||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if not responses:
|
||||
_LOGGER.info(
|
||||
"Command produced no response (cmd=%r) after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
return {"ok": True, "skipped": "no_response"}
|
||||
text_count = sum(1 for r in responses if r.text)
|
||||
media_count = sum(len(r.media or []) for r in responses)
|
||||
_LOGGER.info(
|
||||
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||
len(responses), text_count, media_count,
|
||||
)
|
||||
for idx, resp in enumerate(responses):
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
_LOGGER.info(
|
||||
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
len(responses), media_count,
|
||||
)
|
||||
return {"ok": True}
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Command /%s raised after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
# Return 200 so Telegram doesn't retry the same update — we
|
||||
# already logged the failure and can't usefully reprocess.
|
||||
return {"ok": True, "error": "handler_exception"}
|
||||
|
||||
return {"ok": True, "skipped": "not_a_command"}
|
||||
|
||||
|
||||
@@ -2,8 +2,20 @@
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
# Secret keys we will actively refuse. These cover the default template value
|
||||
# and dev-only literals that have appeared in scripts or documentation.
|
||||
_FORBIDDEN_SECRETS: frozenset[str] = frozenset(
|
||||
{
|
||||
"change-me-in-production",
|
||||
"test-secret-key-minimum-32-chars",
|
||||
"dev-secret-key-not-for-production",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
@@ -13,29 +25,25 @@ class Settings(BaseSettings):
|
||||
|
||||
secret_key: str = "change-me-in-production"
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.secret_key == "change-me-in-production":
|
||||
raise ValueError(
|
||||
"SECURITY: Refusing to start with the default secret_key. "
|
||||
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
||||
"before starting the server (debug mode included)."
|
||||
)
|
||||
if len(self.secret_key) < 32:
|
||||
raise ValueError(
|
||||
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
||||
)
|
||||
if "*" in self.cors_allowed_origins.split(","):
|
||||
raise ValueError(
|
||||
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
||||
)
|
||||
|
||||
access_token_expire_minutes: int = 60
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 30
|
||||
|
||||
jwt_issuer: str = "notify-bridge"
|
||||
jwt_audience: str = "notify-bridge-api"
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8420
|
||||
debug: bool = False
|
||||
|
||||
# Comma-separated list of trusted proxy IPs uvicorn will honor for
|
||||
# X-Forwarded-For / X-Forwarded-Proto. Use "*" ONLY when you trust the
|
||||
# network (never directly on the internet). Default matches uvicorn.
|
||||
forwarded_allow_ips: str = "127.0.0.1"
|
||||
|
||||
# How long to wait for in-flight requests / scheduler jobs before force
|
||||
# killing on SIGTERM.
|
||||
graceful_shutdown_seconds: int = 60
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
ai_model: str = "claude-sonnet-4-20250514"
|
||||
ai_max_tokens: int = 1024
|
||||
@@ -48,8 +56,61 @@ class Settings(BaseSettings):
|
||||
static_dir: str = ""
|
||||
"""Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker)."""
|
||||
|
||||
# --- Logging ---
|
||||
log_level: str = "INFO"
|
||||
"""Root log level for the app loggers (``DEBUG``/``INFO``/``WARNING``/``ERROR``)."""
|
||||
|
||||
log_format: str = "text"
|
||||
"""Log output format: ``text`` (human-readable) or ``json`` (one object per line)."""
|
||||
|
||||
log_levels: str = ""
|
||||
"""Comma-separated per-module overrides, e.g. ``notify_bridge_core.notifications.telegram.client=DEBUG,sqlalchemy.engine=INFO``."""
|
||||
|
||||
# --- Retention ---
|
||||
event_log_retention_days: int = 30
|
||||
"""Days of event_log history to retain. 0 disables the retention job."""
|
||||
|
||||
pre_migrate_snapshot_keep: int = 5
|
||||
"""Number of pre-migration DB snapshots to keep in ``data_dir/backups/``.
|
||||
0 disables snapshotting entirely. Each snapshot is produced at boot
|
||||
before migrations run using SQLite's ``VACUUM INTO`` (atomic, consistent).
|
||||
"""
|
||||
|
||||
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
    """Validate security-sensitive and numeric settings after model construction.

    Checks run in a fixed order and the first failure aborts start-up:
    secret key (known/default value, then minimum length), CORS origins
    (wildcard, then per-origin scheme/host), then numeric ranges.

    Raises:
        ValueError: If any setting fails its check.
    """
    secret = self.secret_key
    if secret in _FORBIDDEN_SECRETS:
        raise ValueError(
            "SECURITY: Refusing to start with a known/default secret_key. "
            "Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
            "before starting the server."
        )
    if len(secret) < 32:
        raise ValueError(
            "SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
        )

    # Split the comma-separated list, ignoring blanks and surrounding spaces.
    origins = []
    for raw_origin in self.cors_allowed_origins.split(","):
        cleaned = raw_origin.strip()
        if cleaned:
            origins.append(cleaned)

    if "*" in origins:
        raise ValueError(
            "SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
        )
    for origin in origins:
        parts = urlparse(origin)
        scheme_ok = parts.scheme in {"http", "https"}
        if not (scheme_ok and parts.netloc):
            raise ValueError(
                f"CORS origin {origin!r} is invalid — must include scheme (http/https) and host."
            )

    # Numeric range checks.
    if self.access_token_expire_minutes <= 0:
        raise ValueError("access_token_expire_minutes must be > 0")
    if self.refresh_token_expire_days <= 0:
        raise ValueError("refresh_token_expire_days must be > 0")
    if self.port < 1 or self.port > 65535:
        raise ValueError("port must be in range 1..65535")
    if self.event_log_retention_days < 0:
        raise ValueError("event_log_retention_days must be >= 0")
    if self.pre_migrate_snapshot_keep < 0:
        raise ValueError("pre_migrate_snapshot_keep must be >= 0")
|
||||
|
||||
@property
|
||||
def effective_database_url(self) -> str:
|
||||
if self.database_url:
|
||||
|
||||
@@ -1,23 +1,59 @@
|
||||
"""Database engine and session management."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import logging
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
from ..config import settings
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
|
||||
def _install_sqlite_pragmas(engine: AsyncEngine) -> None:
    """Register a connect hook that applies SQLite PRAGMAs to each new connection.

    The statements, in order: WAL journal mode (readers and writers do not
    block each other), ``synchronous=NORMAL`` (the usual durability
    trade-off under WAL), ``foreign_keys=ON`` (SQLite leaves FK enforcement
    off by default), a 10000 ms ``busy_timeout`` (contending writers wait
    instead of failing immediately with SQLITE_BUSY), and in-memory temp
    storage.

    Args:
        engine: Async engine whose underlying sync engine receives the
            ``"connect"`` listener.
    """
    pragma_statements = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA foreign_keys=ON",
        "PRAGMA busy_timeout=10000",
        "PRAGMA temp_store=MEMORY",
    )

    def _pragmas(dbapi_conn, _record):  # pragma: no cover — driver hook
        cursor = dbapi_conn.cursor()
        try:
            for statement in pragma_statements:
                cursor.execute(statement)
        finally:
            # Always release the cursor, even if a PRAGMA is rejected.
            cursor.close()

    event.listen(engine.sync_engine, "connect", _pragmas)
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine:
    """Return the module-wide async engine, creating it on first call.

    For SQLite URLs a 30-unit ``timeout`` is passed through ``connect_args``
    to the driver (presumably seconds, as in ``sqlite3.connect`` — confirm
    for the driver in use) and the PRAGMA connect hook is installed.

    Returns:
        The cached ``AsyncEngine`` instance.
    """
    global _engine
    if _engine is None:
        url = settings.effective_database_url
        is_sqlite = url.startswith("sqlite")

        connect_args: dict = {}
        if is_sqlite:
            connect_args["timeout"] = 30

        # The URL is passed exactly once; a second positional argument
        # would be rejected by create_async_engine.
        _engine = create_async_engine(
            url,
            echo=settings.debug,
            pool_pre_ping=True,
            connect_args=connect_args,
        )
        if is_sqlite:
            _install_sqlite_pragmas(_engine)

        # Log only the scheme so paths/credentials in the URL stay out of logs.
        _LOGGER.info("Database engine initialized: %s", url.split("://", 1)[0])
    return _engine
|
||||
|
||||
|
||||
@@ -31,3 +67,11 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def dispose_engine() -> None:
    """Dispose the cached engine's connection pool, if one was created.

    Intended for graceful shutdown. After a successful dispose the cached
    reference is cleared. Does nothing when no engine is cached.
    """
    global _engine
    if _engine is None:
        return
    await _engine.dispose()
    # Cleared only after dispose() succeeds, mirroring the original order.
    _engine = None
|
||||
|
||||
@@ -309,6 +309,14 @@ async def migrate_schema(engine: AsyncEngine) -> None:
|
||||
text(f"ALTER TABLE {state_table} ADD COLUMN shared INTEGER DEFAULT 0")
|
||||
)
|
||||
logger.info("Added shared column to %s table", state_table)
|
||||
# meta_fingerprint — small JSON blob captured from the provider's
|
||||
# cheap meta probe. An empty default means "unknown, do a full
|
||||
# fetch next tick" so existing rows don't wrongly skip detection.
|
||||
if not await _has_column(conn, state_table, "meta_fingerprint"):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE {state_table} ADD COLUMN meta_fingerprint TEXT DEFAULT '{{}}'")
|
||||
)
|
||||
logger.info("Added meta_fingerprint column to %s table", state_table)
|
||||
|
||||
# Add language_code to telegram_chat if missing
|
||||
if await _has_table(conn, "telegram_chat"):
|
||||
@@ -1274,3 +1282,141 @@ async def migrate_user_token_version(engine: AsyncEngine) -> None:
|
||||
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
|
||||
)
|
||||
logger.info("Added token_version column to user table")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Performance indexes — covers every FK / owner column the list endpoints
|
||||
# and the webhook hot-path filter on. All use CREATE INDEX IF NOT EXISTS so
|
||||
# they are safe to re-run on every boot.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INDEXES: list[tuple[str, str, str]] = [
|
||||
# (index_name, table, columns)
|
||||
("ix_service_provider_user_id", "service_provider", "user_id"),
|
||||
("ix_telegram_bot_user_id", "telegram_bot", "user_id"),
|
||||
("ix_matrix_bot_user_id", "matrix_bot", "user_id"),
|
||||
("ix_email_bot_user_id", "email_bot", "user_id"),
|
||||
("ix_telegram_chat_bot_id", "telegram_chat", "bot_id"),
|
||||
("ix_tracking_config_user_id", "tracking_config", "user_id"),
|
||||
("ix_tracking_config_provider_type", "tracking_config", "provider_type"),
|
||||
("ix_notification_target_user_id", "notification_target", "user_id"),
|
||||
("ix_notification_target_type", "notification_target", "type"),
|
||||
("ix_notification_tracker_user_id", "notification_tracker", "user_id"),
|
||||
("ix_notification_tracker_provider_id", "notification_tracker", "provider_id"),
|
||||
# Composite for the webhook hot path: WHERE provider_id = ? AND enabled = true
|
||||
(
|
||||
"ix_notification_tracker_provider_enabled",
|
||||
"notification_tracker",
|
||||
"provider_id, enabled",
|
||||
),
|
||||
("ix_command_config_user_id", "command_config", "user_id"),
|
||||
("ix_command_template_config_user_id", "command_template_config", "user_id"),
|
||||
("ix_command_tracker_user_id", "command_tracker", "user_id"),
|
||||
("ix_command_tracker_provider_id", "command_tracker", "provider_id"),
|
||||
("ix_action_user_id", "action", "user_id"),
|
||||
("ix_action_provider_id", "action", "provider_id"),
|
||||
# Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
|
||||
("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
|
||||
("ix_event_log_provider_id", "event_log", "provider_id"),
|
||||
("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
|
||||
("ix_event_log_action_id", "event_log", "action_id"),
|
||||
# Webhook log hot path: WHERE provider_id = ? ORDER BY created_at DESC
|
||||
(
|
||||
"ix_webhook_payload_log_provider_created",
|
||||
"webhook_payload_log",
|
||||
"provider_id, created_at DESC",
|
||||
),
|
||||
# Notification tracker join tables
|
||||
(
|
||||
"ix_notification_tracker_target_notification_tracker_id",
|
||||
"notification_tracker_target",
|
||||
"notification_tracker_id",
|
||||
),
|
||||
(
|
||||
"ix_notification_tracker_target_target_id",
|
||||
"notification_tracker_target",
|
||||
"target_id",
|
||||
),
|
||||
("ix_target_receiver_target_id", "target_receiver", "target_id"),
|
||||
("ix_template_slot_config_id", "template_slot", "config_id"),
|
||||
("ix_command_template_slot_config_id", "command_template_slot", "config_id"),
|
||||
("ix_action_rule_action_id", "action_rule", "action_id"),
|
||||
("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"),
|
||||
]
|
||||
|
||||
|
||||
async def migrate_performance_indexes(engine: AsyncEngine) -> None:
    """Create any missing performance indexes listed in ``_INDEXES``.

    Each statement uses ``CREATE INDEX IF NOT EXISTS``, so replaying the
    migration on every boot is harmless. Tables that do not exist yet are
    skipped. A failing statement is logged and does not stop the loop.

    Args:
        engine: Async engine; all statements run inside one
            ``engine.begin()`` block.
    """
    async with engine.begin() as conn:
        for index_name, table_name, column_spec in _INDEXES:
            _assert_ident(index_name, "index")
            _assert_ident(table_name, "table")
            # column_spec is a trusted literal from _INDEXES — never user input.
            table_present = await _has_table(conn, table_name)
            if not table_present:
                continue
            ddl = f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_spec})"
            try:
                await conn.execute(text(ddl))
            except Exception:  # pragma: no cover — log and continue
                logger.warning(
                    "Failed to create index %s on %s(%s)",
                    index_name, table_name, column_spec, exc_info=True,
                )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema version tracking — lightweight alternative to Alembic while the
|
||||
# hand-rolled idempotent migrations remain the source of truth. Gives
|
||||
# operators a single-row answer to "what schema is this DB at" and lets
|
||||
# future upgrades short-circuit migrations that already ran.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CURRENT_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
async def migrate_schema_version(engine: AsyncEngine) -> None:
    """Ensure the single-row ``schema_version`` table records the current version.

    Creates the table if needed, inserts the row when absent, and updates
    it when the stored version is lower than ``CURRENT_SCHEMA_VERSION``.
    A stored version equal to or higher than the current one is left alone.

    Args:
        engine: Async engine; all statements share one ``engine.begin()``
            block.
    """
    from datetime import datetime, timezone

    create_table = text(
        "CREATE TABLE IF NOT EXISTS schema_version ("
        "  id INTEGER PRIMARY KEY CHECK (id = 1),"
        "  version INTEGER NOT NULL,"
        "  applied_at TEXT NOT NULL"
        ")"
    )

    def _read_version(sync_conn):
        # Runs on the sync connection via run_sync; yields a row or None.
        return sync_conn.execute(
            text("SELECT version FROM schema_version WHERE id = 1")
        ).fetchone()

    async with engine.begin() as conn:
        await conn.execute(create_table)
        existing = await conn.run_sync(_read_version)
        params = {
            "v": CURRENT_SCHEMA_VERSION,
            "t": datetime.now(timezone.utc).isoformat(),
        }

        if existing is None:
            await conn.execute(
                text(
                    "INSERT INTO schema_version (id, version, applied_at) "
                    "VALUES (1, :v, :t)"
                ),
                params,
            )
            logger.info("Initialized schema_version at %d", CURRENT_SCHEMA_VERSION)
        elif int(existing[0]) < CURRENT_SCHEMA_VERSION:
            await conn.execute(
                text(
                    "UPDATE schema_version SET version = :v, applied_at = :t "
                    "WHERE id = 1"
                ),
                params,
            )
            logger.info(
                "Bumped schema_version from %s to %d",
                existing[0], CURRENT_SCHEMA_VERSION,
            )
|
||||
|
||||
@@ -376,6 +376,13 @@ class NotificationTrackerState(SQLModel, table=True):
|
||||
shared: bool = Field(default=False)
|
||||
asset_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
pending_asset_ids: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
# Lightweight fingerprint ({updated_at, asset_count, shared, name, ...})
|
||||
# captured from the provider's cheap meta probe. Letting this differ from
|
||||
# the current provider response is what tells the watcher a full fetch is
|
||||
# actually required — letting it match lets the watcher skip the big read.
|
||||
meta_fingerprint: dict[str, Any] = Field(
|
||||
default_factory=dict, sa_column=Column(JSON)
|
||||
)
|
||||
last_updated: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
|
||||
@@ -394,8 +394,37 @@ async def _seed_default_command_configs() -> None:
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _ensure_system_user() -> None:
    """Insert a placeholder ``user`` row with ``id=0`` if none exists.

    ``user_id=0`` has been used as the sentinel owner of system defaults
    (tracking configs, templates, ...). With ``PRAGMA foreign_keys=ON``
    enabled at connect time, rows referencing that id would fail with
    ``FOREIGN KEY constraint failed`` unless a matching user row exists.
    """
    # INSERT OR IGNORE keeps repeated seeding cheap and idempotent.
    insert_system_user = text(
        "INSERT OR IGNORE INTO user "
        "(id, username, hashed_password, role, token_version, created_at) "
        "VALUES (0, :u, :p, :r, 1, :t)"
    )
    async with get_engine().begin() as conn:
        params = {
            "u": "__system__",
            # Deliberately not a valid bcrypt hash, so no password is meant
            # to verify against it and nobody can log in as this user.
            "p": "!disabled!",
            "r": "system",
            "t": datetime.now(timezone.utc).isoformat(),
        }
        await conn.execute(insert_system_user, params)
|
||||
|
||||
|
||||
async def seed_all() -> None:
|
||||
"""Run all seed functions in order."""
|
||||
await _ensure_system_user()
|
||||
await _seed_default_templates()
|
||||
await _seed_default_command_templates()
|
||||
await _seed_default_tracking_configs()
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Pre-migration database snapshots.
|
||||
|
||||
Runs at lifespan startup BEFORE migrations execute. Produces a consistent
|
||||
point-in-time copy of the SQLite database using ``VACUUM INTO`` (atomic,
|
||||
cannot tear against concurrent activity, works with WAL).
|
||||
|
||||
The snapshot is the operator's fallback if a future migration corrupts the
|
||||
schema — restore is a single ``mv`` / ``docker cp``. We keep the N most
|
||||
recent files (default 5) and never fail startup if the snapshot itself
|
||||
fails: a snapshot is best-effort safety net, not a gate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SNAPSHOT_GLOB = "pre-migrate-*.db"
|
||||
_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9._+\-:]+$")
|
||||
|
||||
|
||||
def _sqlite_path_from_url(url: str) -> Path | None:
|
||||
"""Extract the filesystem path from a ``sqlite+aiosqlite:///...`` URL."""
|
||||
if not url.startswith("sqlite"):
|
||||
return None
|
||||
# e.g. "sqlite+aiosqlite:///C:/data/notify_bridge.db"
|
||||
prefix, _, rest = url.partition(":///")
|
||||
if not rest:
|
||||
return None
|
||||
return Path(rest)
|
||||
|
||||
|
||||
async def snapshot_database(
    engine: AsyncEngine,
    target_dir: Path,
    *,
    label: str = "pre-migrate",
) -> Path | None:
    """Write a consistent copy of the SQLite DB to ``target_dir``.

    Uses ``VACUUM INTO``, which SQLite executes against a read snapshot, so
    the copy is safe under WAL and cannot be torn.

    Args:
        engine: Async engine whose URL identifies the database file.
        target_dir: Directory receiving the snapshot; created if missing.
        label: File-name prefix; must match ``_SNAPSHOT_NAME_RE``.

    Returns:
        The snapshot path on success. ``None`` when skipped (unsafe label,
        non-SQLite engine, DB file absent, destination already present) or
        on a non-fatal failure. Never raises: callers treat a missing
        snapshot as acceptable, the main DB stays the source of truth.
    """
    if not _SNAPSHOT_NAME_RE.match(label):
        _LOGGER.warning("Snapshot label %r contains unsafe characters; skipping", label)
        return None

    src = _sqlite_path_from_url(str(engine.url))
    if src is None:
        _LOGGER.debug("Non-SQLite engine; skipping snapshot")
        return None
    if not src.exists():
        _LOGGER.debug("DB file %s does not exist yet (fresh install); skipping snapshot", src)
        return None

    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    dest = target_dir / f"{label}-{ts}.db"

    # VACUUM INTO accepts a string literal, not a bind parameter. The dest
    # path is built from our own label + timestamp, but the directory part
    # comes from configuration — refuse anything containing a single quote
    # rather than attempt escaping.
    dest_str = str(dest)
    if "'" in dest_str:
        _LOGGER.warning("Refusing to snapshot to path containing a single quote: %s", dest_str)
        return None

    # Only files produced by THIS call may be removed on failure. Without
    # this flag a same-second collision would delete an earlier, valid
    # snapshot during cleanup.
    vacuum_started = False
    try:
        # Inside the try so a permissions/disk error on mkdir honours the
        # "never raises" contract.
        target_dir.mkdir(parents=True, exist_ok=True)
        if dest.exists():
            _LOGGER.warning("Snapshot %s already exists; leaving it untouched and skipping", dest)
            return None
        vacuum_started = True
        async with engine.connect() as conn:
            # VACUUM cannot run inside an explicit transaction; use the
            # plain connection without begin().
            await conn.execute(text(f"VACUUM INTO '{dest_str}'"))
        _LOGGER.info("Database snapshot written: %s (%.1f KiB)", dest, dest.stat().st_size / 1024)
        return dest
    except Exception:
        _LOGGER.warning(
            "Pre-migration snapshot failed — continuing with startup. "
            "Check disk space in %s.",
            target_dir,
            exc_info=True,
        )
        if vacuum_started:
            # A partial file can linger if VACUUM INTO aborted mid-write.
            try:
                if dest.exists():
                    dest.unlink()
            except OSError:
                pass
        return None
|
||||
|
||||
|
||||
def prune_old_snapshots(target_dir: Path, keep: int) -> list[Path]:
    """Delete all but the ``keep`` newest pre-migrate snapshots in ``target_dir``.

    Snapshots are the files matching ``_SNAPSHOT_GLOB``; "newest" is decided
    by file modification time.

    Args:
        target_dir: Directory holding the snapshot files. A missing
            directory is treated as "nothing to prune".
        keep: Number of most recent snapshots to retain. ``0`` removes
            every snapshot.

    Returns:
        The paths that were actually deleted (files that could not be
        removed are logged at DEBUG and left out).

    Raises:
        ValueError: If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError("keep must be >= 0")
    if not target_dir.is_dir():
        return []

    def _mtime(path: Path) -> float:
        return path.stat().st_mtime

    try:
        newest_first = list(target_dir.glob(_SNAPSHOT_GLOB))
        newest_first.sort(key=_mtime, reverse=True)
    except OSError:
        # Listing or stat-ing failed; give up without deleting anything.
        return []

    removed: list[Path] = []
    for stale in newest_first[keep:]:
        try:
            stale.unlink()
        except OSError:
            _LOGGER.debug("Could not delete old snapshot %s", stale, exc_info=True)
            continue
        removed.append(stale)

    if removed:
        kept = min(keep, len(newest_first))
        _LOGGER.info(
            "Pruned %d old pre-migrate snapshot(s); kept %d most recent",
            len(removed), kept,
        )
    return removed
|
||||
|
||||
|
||||
async def snapshot_and_prune(
    engine: AsyncEngine,
    target_dir: Path,
    *,
    keep: int,
) -> Path | None:
    """Snapshot the database, then trim older snapshots down to ``keep``.

    Used by the lifespan startup path.

    Args:
        engine: Engine handed to ``snapshot_database``.
        target_dir: Directory the snapshot is written to and pruned in.
        keep: How many snapshots to retain. Any value ``<= 0`` disables
            snapshotting entirely (nothing is written or pruned).

    Returns:
        Whatever ``snapshot_database`` returned for this run, or ``None``
        when snapshotting is disabled.
    """
    if keep > 0:
        written = await snapshot_database(engine, target_dir)
        # Prune regardless of whether this run's snapshot succeeded: files
        # left by earlier boots still take up disk space.
        await asyncio.to_thread(prune_old_snapshots, target_dir, keep)
        return written
    return None
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Production-grade logging configuration.
|
||||
|
||||
Installs one ``dictConfig`` layout with:
|
||||
|
||||
* A ``LogRecordFactory`` that pulls request-scoped identifiers from
|
||||
``notify_bridge_core.log_context`` onto every record, so logs can be
|
||||
filtered/correlated by ``request_id``, ``command``, ``chat_id``,
|
||||
``bot_id``, ``dispatch_id`` without each call site passing them.
|
||||
* A ``SecretMaskingFilter`` that redacts Telegram bot tokens and common
|
||||
``Authorization`` / ``x-api-key`` headers so an accidental ``repr`` or
|
||||
dumped request doesn't leak credentials into the log aggregator.
|
||||
* A text formatter (default) or a JSON formatter (one line per record)
|
||||
selectable via ``NOTIFY_BRIDGE_LOG_FORMAT`` / app setting.
|
||||
|
||||
Levels are configurable three ways (later wins):
|
||||
|
||||
1. ``NOTIFY_BRIDGE_LOG_LEVEL`` env var (root) plus
|
||||
``NOTIFY_BRIDGE_LOG_LEVELS`` (``mod=LEVEL,mod2=LEVEL``).
|
||||
2. DB ``AppSetting`` rows ``log_level`` / ``log_levels`` / ``log_format``,
|
||||
applied after migrations during startup.
|
||||
3. Live edits via the settings API — ``apply_log_levels()`` updates
|
||||
existing loggers in place without a server restart.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import logging.config
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from notify_bridge_core.log_context import (
|
||||
bot_id_var,
|
||||
chat_id_var,
|
||||
command_var,
|
||||
dispatch_id_var,
|
||||
request_id_var,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Secret masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Telegram bot tokens: /bot<digits>:<alnum with dashes/underscores>
_TELEGRAM_TOKEN_RE = re.compile(r"/bot\d+:[A-Za-z0-9_-]{20,}")

# Header-style secrets: ``Authorization: Bearer xxx``, ``x-api-key=xxx``, etc.
# Group 1 is the key, group 2 the separator plus an optional auth scheme
# (``Bearer`` / ``Basic`` / ``Token``), group 3 the secret itself. Without the
# scheme alternative the scheme word (too short, followed by a space) stops
# the token match and the credential after it is never redacted.
# Only matches reasonably long tokens so short legitimate values don't trip.
_HEADER_SECRET_RE = re.compile(
    r"(?i)(authorization|x-api-key|api[_-]?key|password|secret|access[_-]?token|refresh[_-]?token)"
    r"([\"']?\s*[:=]\s*[\"']?(?:(?:bearer|basic|token)\s+)?)"
    r"([A-Za-z0-9._+/=\-]{12,})"
)


def _mask(text: str) -> str:
    """Return ``text`` with Telegram bot tokens and header-style secrets redacted.

    The key, separator and auth scheme are kept so the log line stays
    readable; only the secret value is replaced with ``***``.
    """
    redacted = _TELEGRAM_TOKEN_RE.sub("/bot***", text)
    redacted = _HEADER_SECRET_RE.sub(r"\1\2***", redacted)
    return redacted


class SecretMaskingFilter(logging.Filter):
    """Redact likely secrets from every log message before it's emitted.

    Covers three surfaces where a leaked token can end up in the log:
    the formatted message, a cached exception traceback (``exc_text``),
    and a cached stack frame dump (``stack_info``). The formatter still
    expands ``exc_info`` for us when ``exc_text`` is None, so we also
    pre-render + mask on first emission.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Mask secrets on ``record`` in place; never drops the record."""
        try:
            msg = record.getMessage()
        except Exception:
            # A broken format string / args pair is the formatter's problem
            # to report; masking must not swallow the record.
            return True
        redacted = _mask(msg)
        if redacted != msg:
            # Replace the formatted message and drop args so the handler
            # doesn't re-format with the original values.
            record.msg = redacted
            record.args = ()

        if record.exc_info and not record.exc_text:
            # Pre-render so we can mask before the formatter caches it.
            fmt = logging.Formatter()
            record.exc_text = fmt.formatException(record.exc_info)
        if record.exc_text:
            record.exc_text = _mask(record.exc_text)
        if record.stack_info:
            record.stack_info = _mask(record.stack_info)
        return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Record factory — injects context identifiers onto every record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONTEXT_FIELDS = ("request_id", "command", "chat_id", "bot_id", "dispatch_id")
|
||||
_PLACEHOLDER = "-"
|
||||
|
||||
_original_factory = logging.getLogRecordFactory()
|
||||
|
||||
|
||||
def _context_record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
    """Build a log record and stamp the current context identifiers onto it.

    Delegates record creation to the factory captured in
    ``_original_factory``, then sets ``request_id``, ``command``,
    ``chat_id``, ``bot_id`` and ``dispatch_id`` from the context variables.
    Unset (falsy) values become ``_PLACEHOLDER``.
    """
    record = _original_factory(*args, **kwargs)

    plain_fields = (
        ("request_id", request_id_var),
        ("command", command_var),
        ("chat_id", chat_id_var),
        ("dispatch_id", dispatch_id_var),
    )
    for attr_name, ctx_var in plain_fields:
        setattr(record, attr_name, ctx_var.get() or _PLACEHOLDER)

    # bot_id is stringified and only ``None`` counts as "unset", so a falsy
    # but real value is preserved.
    current_bot = bot_id_var.get()
    record.bot_id = _PLACEHOLDER if current_bot is None else str(current_bot)
    return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
    """Formatter that renders each log record as a single-line JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        """Serialize ``record`` to JSON.

        Always emits ``ts``, ``level``, ``logger``, ``module``, ``line`` and
        ``msg``; context fields are added only when set to something other
        than the placeholder; ``exc`` / ``stack`` only when present.
        """
        stamp = self.formatTime(record, "%Y-%m-%dT%H:%M:%S")
        doc: dict[str, Any] = {
            "ts": stamp + f".{int(record.msecs):03d}",
            "level": record.levelname,
            "logger": record.name,
            "module": record.module,
            "line": record.lineno,
            "msg": record.getMessage(),
        }

        for name in _CONTEXT_FIELDS:
            value = getattr(record, name, None)
            if not value or value == _PLACEHOLDER:
                continue
            doc[name] = value

        # A cached exc_text takes priority over exc_info: it may already
        # have been masked by an upstream filter, and re-formatting from
        # exc_info would bypass that.
        if record.exc_text:
            doc["exc"] = record.exc_text
        elif record.exc_info:
            doc["exc"] = self.formatException(record.exc_info)

        if record.stack_info:
            doc["stack"] = record.stack_info

        return json.dumps(doc, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Keeps all context fields on one line so grep-by-field works. Empty values
|
||||
# are rendered as "-" by the record factory to avoid KeyError if a record
|
||||
# arrives without the filter.
|
||||
_TEXT_FORMAT = (
|
||||
"%(asctime)s %(levelname)-7s %(name)s:%(lineno)d "
|
||||
"[req=%(request_id)s cmd=%(command)s bot=%(bot_id)s chat=%(chat_id)s disp=%(dispatch_id)s] "
|
||||
"%(message)s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Level override parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "NOTSET"}


def parse_level_overrides(raw: str) -> dict[str, str]:
    """Turn ``module=LEVEL,module2=LEVEL`` into a ``{module: LEVEL}`` mapping.

    Whitespace around entries, names and levels is ignored and levels are
    upper-cased. Entries without ``=``, with an empty module name, or with a
    level outside ``_VALID_LEVELS`` are silently skipped, so a malformed env
    var or DB setting cannot crash boot. A later entry for the same module
    replaces an earlier one.
    """
    parsed: dict[str, str] = {}
    if not raw:
        return parsed

    for entry in raw.split(","):
        name, sep, level = entry.strip().partition("=")
        if not sep:
            # No "=" at all (this also covers empty chunks).
            continue
        name = name.strip()
        level = level.strip().upper()
        if name and level in _VALID_LEVELS:
            parsed[name] = level
    return parsed
|
||||
|
||||
|
||||
def _normalize_level(level: str | None, default: str = "INFO") -> str:
    """Return ``level`` stripped and upper-cased if it is a known level name.

    Falls back to ``default`` when ``level`` is empty/``None`` or is not in
    ``_VALID_LEVELS``.
    """
    if level:
        candidate = level.strip().upper()
        if candidate in _VALID_LEVELS:
            return candidate
    return default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup + live apply
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Libraries we quiet by default — noisy at DEBUG and almost always irrelevant
|
||||
# to a service issue. Override via LOG_LEVELS=sqlalchemy.engine=DEBUG when
|
||||
# actually debugging.
|
||||
_NOISY_LIBRARY_DEFAULTS: dict[str, str] = {
|
||||
"sqlalchemy": "WARNING",
|
||||
"sqlalchemy.engine": "WARNING",
|
||||
"sqlalchemy.pool": "WARNING",
|
||||
"aiohttp": "WARNING",
|
||||
"aiohttp.access": "WARNING",
|
||||
"aiohttp.client": "WARNING",
|
||||
"aiohttp.server": "WARNING",
|
||||
"apscheduler": "WARNING",
|
||||
"apscheduler.scheduler": "WARNING",
|
||||
"apscheduler.executors.default": "WARNING",
|
||||
"urllib3": "WARNING",
|
||||
"asyncio": "WARNING",
|
||||
"httpx": "WARNING",
|
||||
"httpcore": "WARNING",
|
||||
"PIL": "WARNING",
|
||||
"uvicorn.access": "WARNING",
|
||||
}
|
||||
|
||||
|
||||
def setup_logging(
    *,
    level: str = "INFO",
    fmt: str = "text",
    per_module_levels: str = "",
) -> None:
    """Install the logging configuration. Safe to call more than once.

    Args:
        level: Level for the root logger and the ``notify_bridge_server`` /
            ``notify_bridge_core`` loggers. Unknown values fall back to
            ``INFO``.
        fmt: ``"text"`` (default) or ``"json"``. Matched case-insensitively
            with surrounding whitespace ignored; anything other than
            ``json`` selects the text format.
        per_module_levels: ``mod=LEVEL,mod2=LEVEL`` overrides. Wins over the
            root level and the noisy-library defaults for the listed loggers.
    """
    root_level = _normalize_level(level, "INFO")
    overrides = parse_level_overrides(per_module_levels)

    # Install the context-aware record factory (idempotent — setting the same
    # factory twice is fine because ``_original_factory`` is captured at
    # import time).
    logging.setLogRecordFactory(_context_record_factory)

    # Normalize the format name the same way level names are normalized, so
    # an env value such as ``JSON`` or `` json `` doesn't silently select the
    # text formatter.
    fmt_name = (fmt or "text").strip().lower()
    if fmt_name == "json":
        formatters = {"default": {"()": f"{__name__}.JsonFormatter"}}
    else:
        formatters = {
            "default": {
                "format": _TEXT_FORMAT,
                "datefmt": "%Y-%m-%d %H:%M:%S",
            }
        }

    # Start with noisy-library defaults, then layer user overrides on top so
    # the user can raise them to DEBUG when actually debugging.
    loggers: dict[str, dict[str, Any]] = {}
    for mod, lvl in _NOISY_LIBRARY_DEFAULTS.items():
        loggers[mod] = {"level": lvl, "propagate": True}
    # App loggers follow the root level unless overridden.
    loggers["notify_bridge_server"] = {"level": root_level, "propagate": True}
    loggers["notify_bridge_core"] = {"level": root_level, "propagate": True}
    # User overrides win.
    for mod, lvl in overrides.items():
        loggers[mod] = {"level": lvl, "propagate": True}

    config: dict[str, Any] = {
        "version": 1,
        # Keep loggers created before this call (library/module-level
        # loggers) working instead of silencing them.
        "disable_existing_loggers": False,
        "filters": {
            "mask_secrets": {"()": f"{__name__}.SecretMaskingFilter"},
        },
        "formatters": formatters,
        "handlers": {
            "stderr": {
                "class": "logging.StreamHandler",
                "stream": sys.stderr,
                "formatter": "default",
                "filters": ["mask_secrets"],
            },
        },
        "root": {
            "level": root_level,
            "handlers": ["stderr"],
        },
        "loggers": loggers,
    }
    logging.config.dictConfig(config)
|
||||
|
||||
|
||||
def apply_log_levels(
    *,
    level: str | None,
    per_module_levels: str | None,
) -> None:
    """Update existing logger levels in-place without rebuilding handlers.

    Called when an admin changes the log settings at runtime.

    Args:
        level: ``None`` / empty leaves the root and app loggers untouched.
            Any other value is normalized (unknown names fall back to
            ``INFO``) and applied to ``notify_bridge_server``,
            ``notify_bridge_core`` and the root logger.
        per_module_levels: ``None`` leaves per-module levels untouched.
            Otherwise the string is treated as an exclusive set:

            * listed loggers get their level;
            * noisy-library loggers that are not listed go back to their
              defaults from ``_NOISY_LIBRARY_DEFAULTS``;
            * any other logger that a previous call to this function
              overrode, and that is no longer listed, is reset to
              ``NOTSET`` so it inherits from its parent again and the
              removed override actually takes effect.

    Only overrides applied through this function are tracked; overrides
    installed by the initial boot configuration are not reset here.
    """
    if level:
        lvl = _normalize_level(level, "INFO")
        logging.getLogger("notify_bridge_server").setLevel(lvl)
        logging.getLogger("notify_bridge_core").setLevel(lvl)
        # Root is only touched when the caller explicitly supplied a level.
        logging.getLogger().setLevel(lvl)

    if per_module_levels is None:
        return

    overrides = parse_level_overrides(per_module_levels)
    # Names overridden by earlier calls; kept on the function object so the
    # bookkeeping lives with the only code that reads it.
    previously_applied: set[str] = getattr(apply_log_levels, "_applied_overrides", set())

    # Apply new overrides
    for mod, lvl in overrides.items():
        logging.getLogger(mod).setLevel(lvl)

    # Reset noisy libs that aren't in the new overrides back to defaults
    for mod, default_lvl in _NOISY_LIBRARY_DEFAULTS.items():
        if mod not in overrides:
            logging.getLogger(mod).setLevel(default_lvl)

    # Drop overrides that were removed from the setting. NOTSET makes the
    # logger defer to its parent chain again.
    root_logger = logging.getLogger()
    for mod in previously_applied.difference(overrides, _NOISY_LIBRARY_DEFAULTS):
        stale_logger = logging.getLogger(mod)
        if stale_logger is root_logger:
            # Never leave the root at NOTSET — that would let every record
            # through.
            continue
        stale_logger.setLevel(logging.NOTSET)

    apply_log_levels._applied_overrides = set(overrides)
|
||||
@@ -9,13 +9,18 @@ from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
# Ensure app-level loggers are visible
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
from .config import settings as _log_cfg
|
||||
_log_level = logging.DEBUG if _log_cfg.debug else logging.INFO
|
||||
logging.getLogger("notify_bridge_server").setLevel(_log_level)
|
||||
logging.getLogger("notify_bridge_core").setLevel(_log_level)
|
||||
from .logging_setup import setup_logging
|
||||
|
||||
# Boot logging from env-based config. DB-backed AppSetting rows (``log_level`` /
|
||||
# ``log_levels`` / ``log_format``) override this after migrations — see the
|
||||
# lifespan block below.
|
||||
setup_logging(
|
||||
level="DEBUG" if _log_cfg.debug else _log_cfg.log_level,
|
||||
fmt=_log_cfg.log_format,
|
||||
per_module_levels=_log_cfg.log_levels,
|
||||
)
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
from .database.engine import init_db
|
||||
from .database.models import * # noqa: F401,F403 — ensure all models registered
|
||||
@@ -47,13 +52,41 @@ from .api.webhook_logs import router as webhook_logs_router
|
||||
from .api.backup import router as backup_router
|
||||
|
||||
|
||||
# Readiness flag — flipped to True once the scheduler has started and the
|
||||
# app is fully initialized. Exposed via /api/ready for orchestrators.
|
||||
_READY: bool = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global _READY
|
||||
await init_db()
|
||||
# Run data migrations (idempotent)
|
||||
from .database.engine import get_engine
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
|
||||
from .database.migrations import (
|
||||
migrate_schema,
|
||||
migrate_tracker_targets,
|
||||
migrate_entity_refactor,
|
||||
migrate_template_slots,
|
||||
migrate_target_receivers,
|
||||
migrate_template_locale,
|
||||
migrate_receivers_from_config,
|
||||
migrate_command_slot_locale,
|
||||
migrate_notification_slot_locale,
|
||||
migrate_user_token_version,
|
||||
migrate_performance_indexes,
|
||||
migrate_schema_version,
|
||||
)
|
||||
from .database.snapshot import snapshot_and_prune
|
||||
engine = get_engine()
|
||||
# Take a consistent DB snapshot BEFORE migrations run, so operators can
|
||||
# roll back a bad upgrade by restoring one file. Best-effort — failures
|
||||
# are logged, not raised.
|
||||
await snapshot_and_prune(
|
||||
engine,
|
||||
_log_cfg.data_dir / "backups",
|
||||
keep=_log_cfg.pre_migrate_snapshot_keep,
|
||||
)
|
||||
await migrate_schema(engine)
|
||||
await migrate_tracker_targets(engine)
|
||||
await migrate_entity_refactor(engine)
|
||||
@@ -64,8 +97,28 @@ async def lifespan(app: FastAPI):
|
||||
await migrate_command_slot_locale(engine)
|
||||
await migrate_notification_slot_locale(engine)
|
||||
await migrate_user_token_version(engine)
|
||||
await migrate_performance_indexes(engine)
|
||||
await migrate_schema_version(engine)
|
||||
from .database.seeds import seed_all
|
||||
await seed_all()
|
||||
# Apply DB-backed logging settings (override env-based boot config).
|
||||
# log_format still needs a restart — changing it means swapping the
|
||||
# handler formatter entirely.
|
||||
try:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession as _AS_log
|
||||
from .api.app_settings import get_setting as _get_log_setting
|
||||
from .logging_setup import apply_log_levels
|
||||
async with _AS_log(engine) as _log_session:
|
||||
db_level = await _get_log_setting(_log_session, "log_level")
|
||||
db_levels = await _get_log_setting(_log_session, "log_levels")
|
||||
apply_log_levels(level=db_level or None, per_module_levels=db_levels)
|
||||
_LOGGER.info(
|
||||
"Logging initialized: level=%s overrides=%r format=%s",
|
||||
db_level or _log_cfg.log_level, db_levels or _log_cfg.log_levels,
|
||||
_log_cfg.log_format,
|
||||
)
|
||||
except Exception: # pragma: no cover — never let logging setup abort boot
|
||||
_LOGGER.exception("Failed to apply DB-backed log settings; keeping env-based levels")
|
||||
# Apply any pending restore staged via /api/backup/prepare-restore
|
||||
from .services.pending_restore import apply_pending_restore_if_any
|
||||
await apply_pending_restore_if_any()
|
||||
@@ -77,16 +130,28 @@ async def lifespan(app: FastAPI):
|
||||
set_webhook_secret(_secret or None)
|
||||
from .services.scheduler import start_scheduler, get_scheduler
|
||||
await start_scheduler()
|
||||
_READY = True
|
||||
yield
|
||||
# Graceful shutdown
|
||||
from .services.http_session import close_http_session
|
||||
await close_http_session()
|
||||
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
|
||||
# before we close their HTTP session. Then close the shared session and
|
||||
# dispose the DB engine.
|
||||
_READY = False
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.running:
|
||||
scheduler.shutdown()
|
||||
scheduler.shutdown(wait=True)
|
||||
from .services.http_session import close_http_session
|
||||
await close_http_session()
|
||||
from .database.engine import dispose_engine
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
|
||||
try:
|
||||
from importlib.metadata import version as _pkg_version
|
||||
_APP_VERSION = _pkg_version("notify-bridge-server")
|
||||
except Exception: # pragma: no cover — editable install edge cases
|
||||
_APP_VERSION = "0.0.0+unknown"
|
||||
|
||||
app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan)
|
||||
|
||||
# --- Security headers ---
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -94,6 +159,24 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
|
||||
_CSP = (
|
||||
"default-src 'self'; "
|
||||
"img-src 'self' data: blob: https:; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
# SvelteKit's static adapter emits an inline bootstrap <script> with the
|
||||
# hydration payload, so 'self' alone blocks the SPA from starting.
|
||||
# 'unsafe-inline' re-enables it; the app's primary XSS protection still
|
||||
# comes from Svelte's template auto-escaping and frontend/sanitize.ts
|
||||
# for the few {@html} paths that render user-controlled content.
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"connect-src 'self'; "
|
||||
"font-src 'self' data:; "
|
||||
"base-uri 'self'; "
|
||||
"form-action 'self'; "
|
||||
"frame-ancestors 'none'"
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: StarletteRequest, call_next):
|
||||
response: StarletteResponse = await call_next(request)
|
||||
@@ -101,6 +184,14 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers.setdefault("Content-Security-Policy", _CSP)
|
||||
# HSTS only makes sense over HTTPS; set when the edge terminates TLS
|
||||
# and forwards X-Forwarded-Proto=https.
|
||||
if request.headers.get("x-forwarded-proto") == "https":
|
||||
response.headers.setdefault(
|
||||
"Strict-Transport-Security",
|
||||
"max-age=31536000; includeSubDomains",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@@ -153,7 +244,22 @@ app.include_router(backup_router)
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
"""Liveness: process is up and responding. Always returns 200 once the
|
||||
ASGI app has started. Keep this endpoint anonymous and trivially cheap."""
|
||||
return {"status": "ok", "version": _APP_VERSION}
|
||||
|
||||
|
||||
@app.get("/api/ready")
async def ready():
    """Readiness probe for orchestrators (Docker, Kubernetes).

    Responds with HTTP 503 and ``{"status": "starting"}`` while the
    module-level ``_READY`` flag is false; once it is true, returns
    ``{"status": "ready"}`` together with the application version.
    """
    if _READY:
        return {"status": "ready", "version": _APP_VERSION}

    # Imported lazily: only needed on the not-ready path.
    from starlette.responses import JSONResponse

    return JSONResponse({"status": "starting"}, status_code=503)
|
||||
|
||||
|
||||
# --- Serve frontend static files (production) ---
|
||||
@@ -186,4 +292,12 @@ if _cfg.static_dir and Path(_cfg.static_dir).is_dir():
|
||||
|
||||
def run():
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=_cfg.host, port=_cfg.port)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=_cfg.host,
|
||||
port=_cfg.port,
|
||||
proxy_headers=True,
|
||||
forwarded_allow_ips=_cfg.forwarded_allow_ips or "127.0.0.1",
|
||||
timeout_graceful_shutdown=_cfg.graceful_shutdown_seconds,
|
||||
access_log=not _cfg.debug,
|
||||
)
|
||||
|
||||
@@ -387,10 +387,9 @@ async def export_backup_to_file(
|
||||
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
||||
filename = f"backup-{ts}.json"
|
||||
filepath = backup_dir / filename
|
||||
filepath.write_text(
|
||||
json.dumps(backup.model_dump(), indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
import asyncio as _asyncio
|
||||
payload = json.dumps(backup.model_dump(), indent=2, ensure_ascii=False)
|
||||
await _asyncio.to_thread(filepath.write_text, payload, encoding="utf-8")
|
||||
_LOGGER.info("Scheduled backup saved: %s", filepath)
|
||||
return filepath
|
||||
|
||||
@@ -399,7 +398,13 @@ def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
|
||||
"""Delete oldest backup files exceeding `keep` count. Returns deleted filenames."""
|
||||
if not backup_dir.is_dir():
|
||||
return []
|
||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
||||
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||
# timestamp format, which could change later without updating this code.
|
||||
files = sorted(
|
||||
backup_dir.glob("backup-*.json"),
|
||||
key=lambda f: f.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
deleted = []
|
||||
for old in files[keep:]:
|
||||
old.unlink()
|
||||
@@ -413,7 +418,13 @@ def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
|
||||
"""List backup files in the directory with metadata."""
|
||||
if not backup_dir.is_dir():
|
||||
return []
|
||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
||||
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||
# timestamp format, which could change later without updating this code.
|
||||
files = sorted(
|
||||
backup_dir.glob("backup-*.json"),
|
||||
key=lambda f: f.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for f in files:
|
||||
stat = f.stat()
|
||||
|
||||
@@ -11,23 +11,36 @@ Call ``close_http_session()`` once during application shutdown.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
|
||||
_session: aiohttp.ClientSession | None = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_http_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared HTTP session."""
|
||||
"""Get or create the shared HTTP session.
|
||||
|
||||
Concurrent first-callers are serialized through ``_lock`` so we never
|
||||
leak a second ClientSession / connector pair. Once established, hot
|
||||
callers skip the lock via the fast-path check.
|
||||
"""
|
||||
global _session
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
if _session is not None and not _session.closed:
|
||||
return _session
|
||||
async with _lock:
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_http_session() -> None:
|
||||
"""Close the shared HTTP session (call on app shutdown)."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
async with _lock:
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
|
||||
@@ -188,6 +188,7 @@ _SAMPLE_CONTEXT = {
|
||||
"current_time": "09:00",
|
||||
"current_datetime": "22.03.2026, 09:00 UTC",
|
||||
"weekday": "Monday",
|
||||
"timezone": "UTC",
|
||||
"custom_vars": {"team": "Engineering", "message": "Time for standup!"},
|
||||
"team": "Engineering",
|
||||
"message": "Time for standup!",
|
||||
|
||||
@@ -0,0 +1,242 @@
|
||||
"""Cron-fired scheduled / periodic / memory dispatch for Immich trackers.
|
||||
|
||||
The Immich provider exposes three notification slots that fire on a wall-clock
|
||||
schedule rather than in response to album changes:
|
||||
|
||||
* ``scheduled_assets_message`` — random asset selection at fixed times of day
|
||||
* ``periodic_summary_message`` — album stats summary at fixed times of day
|
||||
* ``memory_mode_message`` — "On This Day" memories at fixed times of day
|
||||
|
||||
The fire times live on the tracker's default ``TrackingConfig`` as comma-
|
||||
separated ``HH:MM`` strings (``scheduled_times`` / ``periodic_times`` /
|
||||
``memory_times``) interpreted in the app-level IANA timezone
|
||||
(``AppSetting.timezone``). The scheduler module wires the cron jobs; this
|
||||
module owns the dispatch flow once a job fires.
|
||||
|
||||
Note on per-link tracking-config overrides: schedule *times* come from the
|
||||
tracker's default config — a per-link override may disable the slot for that
|
||||
link (via ``{kind}_enabled``) but cannot shift its fire time. Consistent with
|
||||
the test-dispatch path in ``manual_dispatch``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.models.events import EventType
|
||||
from notify_bridge_core.notifications.dispatcher import (
|
||||
NotificationDispatcher,
|
||||
TargetConfig,
|
||||
)
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
EventLog,
|
||||
NotificationTracker,
|
||||
ServiceProvider,
|
||||
TemplateSlot,
|
||||
TrackingConfig,
|
||||
)
|
||||
from .dispatch_helpers import (
|
||||
event_allowed_by_config,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
from .manual_dispatch import _build_immich_event, _build_immich_periodic_event
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ScheduledKind = Literal["scheduled", "periodic", "memory"]
|
||||
|
||||
# Maps the dispatch kind to the DB slot name that holds its template.
|
||||
# The dispatcher keys templates by ``event.event_type.value`` (always
|
||||
# ``scheduled_message`` here), so we read the right ``TemplateSlot`` row and
|
||||
# inject it under that single event-type key — same pattern as the test path.
|
||||
_SLOT_MAP: dict[ScheduledKind, str] = {
|
||||
"scheduled": "scheduled_assets_message",
|
||||
"periodic": "periodic_summary_message",
|
||||
"memory": "memory_mode_message",
|
||||
}
|
||||
|
||||
|
||||
async def dispatch_scheduled_for_tracker(
    tracker_id: int, kind: ScheduledKind
) -> None:
    """Run one cron-fired ``kind`` dispatch for an Immich tracker.

    Loads the tracker, its provider and its default tracking config, builds
    the event for ``kind``, fans it out to every link that survives
    filtering, and records the outcome in ``EventLog``.

    Returns without dispatching when any of the following holds:

    * the tracker is missing or disabled, or its provider is not Immich;
    * the default tracking config is missing or has ``<kind>_enabled`` off;
    * ``load_link_data`` returns nothing for the tracker;
    * the event builder returns ``None``;
    * ``kind`` is ``"scheduled"`` or ``"memory"`` and the event carries no
      added assets;
    * no link passes the per-link checks (kind flag, event filter, a
      template config, and a ``TemplateSlot`` row for this slot).

    Args:
        tracker_id: Primary key of the ``NotificationTracker`` to fire for.
        kind: Which slot fired; selects the event builder and, through
            ``_SLOT_MAP``, the ``TemplateSlot`` name to read.
    """
    engine = get_engine()
    enabled_flag = f"{kind}_enabled"

    async with AsyncSession(engine) as session:
        tracker = await session.get(NotificationTracker, tracker_id)
        if not tracker or not tracker.enabled:
            return
        provider = await session.get(ServiceProvider, tracker.provider_id)
        if not provider or provider.type != "immich":
            return

        default_tc: TrackingConfig | None = None
        if tracker.default_tracking_config_id:
            default_tc = await session.get(
                TrackingConfig, tracker.default_tracking_config_id
            )
        # The schedule rebuild only adds jobs while the flag is set, but a
        # stale job from an earlier DB state can still fire once before it
        # is invalidated — so re-check the flag here.
        kind_is_on = default_tc is not None and getattr(
            default_tc, enabled_flag, False
        )
        if not kind_is_on:
            _LOGGER.debug(
                "Scheduled %s skipped for tracker %d: kind disabled on default config",
                kind, tracker_id,
            )
            return

        # Copy out everything needed later: once this session closes the
        # ORM instances are detached and must not be lazy-loaded from.
        provider_id = provider.id
        provider_config = dict(provider.config)
        provider_name = provider.name or provider.type
        tracker_user_id = tracker.user_id
        tracker_name = tracker.name or ""
        collection_ids = list(tracker.collection_ids or [])

        app_tz = await get_app_timezone(session)
        link_data = await load_link_data(session, tracker_id)

    if not link_data:
        _LOGGER.info(
            "Scheduled %s for tracker %d: no enabled links, skipping",
            kind, tracker_id,
        )
        return

    # Arguments shared by both event builders.
    builder_kwargs = {
        "provider_config": provider_config,
        "provider_name": provider_name,
        "tracker_name": tracker_name,
        "collection_ids": collection_ids,
    }
    if kind == "periodic":
        event = await _build_immich_periodic_event(**builder_kwargs)
    else:
        event = await _build_immich_event(
            **builder_kwargs,
            test_type=kind,
            tracking_config=default_tc,
        )
    if event is None:
        _LOGGER.warning(
            "Scheduled %s for tracker %d: provider returned no event",
            kind, tracker_id,
        )
        return

    # Asset-bearing kinds with nothing to show would post a bare header
    # every day; ``periodic`` is a stats summary and is still meaningful
    # with zero assets, so it is allowed through.
    if kind in ("scheduled", "memory") and not event.added_assets:
        _LOGGER.info(
            "Scheduled %s for tracker %d: 0 assets matched, skipping dispatch",
            kind, tracker_id,
        )
        return

    slot_name = _SLOT_MAP[kind]
    target_configs: list[TargetConfig] = []
    async with AsyncSession(engine) as session:
        for link in link_data:
            link_tc = link["tracking_config"] or default_tc
            template_cfg = link["template_config"]
            if link_tc is not None:
                # A per-link override may switch this kind off even though
                # the default config has it on.
                if not getattr(link_tc, enabled_flag, True):
                    continue
                if not event_allowed_by_config(event, link_tc, app_tz):
                    continue
            if template_cfg is None:
                continue

            slot_query = select(TemplateSlot).where(
                TemplateSlot.config_id == template_cfg.id,
                TemplateSlot.slot_name == slot_name,
            )
            slot_rows = (await session.exec(slot_query)).all()
            if not slot_rows:
                continue
            # The dispatcher looks templates up by event-type value, so the
            # slot's per-locale templates are filed under that single key.
            per_locale = {row.locale: row.template for row in slot_rows}
            template_slots = {EventType.SCHEDULED_MESSAGE.value: per_locale}

            target_configs.append(TargetConfig(
                type=link["target_type"],
                config=link["target_config"],
                template_slots=template_slots,
                date_format=template_cfg.date_format,
                date_only_format=(
                    template_cfg.date_only_format or "%d.%m.%Y"
                ),
                provider_api_key=provider_config.get("api_key"),
                provider_internal_url=provider_config.get("url", ""),
                provider_external_url=provider_config.get("external_domain", ""),
                receivers=link["receivers"],
            ))

    if not target_configs:
        _LOGGER.info(
            "Scheduled %s for tracker %d: no targets after filtering",
            kind, tracker_id,
        )
        return

    # Imported lazily to break the watcher↔scheduler↔scheduled_dispatch cycle.
    from .watcher import _get_telegram_caches
    from .http_session import get_http_session

    url_cache, asset_cache = await _get_telegram_caches()
    http_session = await get_http_session()
    dispatcher = NotificationDispatcher(
        url_cache=url_cache, asset_cache=asset_cache, session=http_session,
    )
    _LOGGER.info(
        "Dispatching scheduled %s for tracker %d to %d link(s)",
        kind, tracker_id, len(target_configs),
    )
    results = await dispatcher.dispatch(event, target_configs)

    # Record the fire in EventLog so the activity feed can show why a
    # notification arrived.
    succeeded = [
        outcome for outcome in results
        if isinstance(outcome, dict) and outcome.get("success")
    ]
    successes = len(succeeded)
    async with AsyncSession(engine) as session:
        session.add(EventLog(
            user_id=tracker_user_id,
            tracker_id=tracker_id,
            tracker_name=tracker_name,
            provider_id=provider_id,
            provider_name=provider_name,
            event_type=event.event_type.value,
            collection_id=event.collection_id,
            collection_name=event.collection_name,
            assets_count=event.added_count or 0,
            details={
                "kind": kind,
                "slot": slot_name,
                "trigger": "cron",
                "timezone": app_tz,
                "targets_dispatched": len(target_configs),
                "targets_succeeded": successes,
            },
        ))
        await session.commit()
|
||||
@@ -3,18 +3,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
|
||||
"""Resolve an IANA tz string to a ZoneInfo, falling back to UTC on any error.
|
||||
|
||||
Kept local to avoid importing from api/dispatch layers inside the scheduler
|
||||
module (which is loaded at startup, before the API routers).
|
||||
"""
|
||||
if not tz_name:
|
||||
return ZoneInfo("UTC")
|
||||
try:
|
||||
return ZoneInfo(tz_name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
_LOGGER.warning("Unknown timezone %r; falling back to UTC", tz_name)
|
||||
return ZoneInfo("UTC")
|
||||
|
||||
|
||||
async def _load_app_timezone() -> ZoneInfo:
    """Return the app timezone stored under the ``timezone`` AppSetting.

    Reads the setting in a short-lived session and resolves it with
    ``_resolve_zoneinfo``, which substitutes UTC when the value is missing
    or unusable.
    """
    from sqlmodel.ext.asyncio.session import AsyncSession

    from ..api.app_settings import get_setting
    from ..database.engine import get_engine

    engine = get_engine()
    async with AsyncSession(engine) as db:
        configured_name = await get_setting(db, "timezone")
    return _resolve_zoneinfo(configured_name)
|
||||
|
||||
_scheduler: AsyncIOScheduler | None = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adaptive polling (Tier 6 of the big-album optimization plan).
|
||||
#
|
||||
# We don't touch the user-configured ``scan_interval`` — that's still the
|
||||
# authoritative cadence. Instead, we *skip* a growing fraction of scheduled
|
||||
# ticks when a tracker is idle, and reset to 1:1 as soon as it detects
|
||||
# anything. The scheduler keeps running on the user's chosen period, so
|
||||
# response time to the *first* change after an idle stretch is never worse
|
||||
# than one tick — but the steady-state HTTP cost for a fleet of idle
|
||||
# trackers drops by ~75%.
|
||||
#
|
||||
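# Thresholds are intentionally conservative, and they count *real* polls
# (skipped ticks never reach ``_adaptive_update``): a tracker polling every
# 30 s needs 10 empty polls (5 min of silence) before we halve its effective
# rate, and 30 empty polls in total before we quarter it — about 25 min,
# because polls 11–30 already run at the halved rate.
# NOTE(review): an ``adaptive=False`` opt-out in the tracker filters dict is
# described as being checked in ``_poll_tracker``, but ``_poll_tracker`` does
# not read such a flag — confirm where (or whether) it is honoured.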
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_ADAPTIVE_HALVE_THRESHOLD = 10 # consecutive empty ticks → 1-in-2
|
||||
_ADAPTIVE_QUARTER_THRESHOLD = 30 # consecutive empty ticks → 1-in-4
|
||||
_ADAPTIVE_MAX_SKIP = 4 # hard cap on skip factor
|
||||
|
||||
# Per-tracker adaptive state, keyed by tracker_id. Rebuilt on process
|
||||
# restart — a short warmup period is fine and avoids persisting what is
|
||||
# effectively a performance heuristic.
|
||||
_adaptive_state: dict[int, dict[str, int]] = {}
|
||||
|
||||
|
||||
def _compute_jitter(interval_seconds: int) -> int:
|
||||
"""Return a jitter bound (in seconds) suitable for an IntervalTrigger.
|
||||
|
||||
Without jitter, a fleet of N trackers all on ``scan_interval=60`` wake up
|
||||
at the same wall-clock second every minute — that creates a thundering-
|
||||
herd on the upstream Immich/Gitea/etc. server. APScheduler's ``jitter``
|
||||
randomizes each tick's firing time by ±jitter seconds.
|
||||
|
||||
We use a quarter of the interval up to a 30 s cap. For short intervals
|
||||
(≤8 s) jitter would round to 0 — that's fine, at those cadences a
|
||||
bursty pattern is what the user implicitly opted into.
|
||||
"""
|
||||
if interval_seconds <= 0:
|
||||
return 0
|
||||
return min(interval_seconds // 4, 30)
|
||||
|
||||
|
||||
def get_scheduler() -> AsyncIOScheduler:
    """Return the process-wide scheduler, creating it on first use.

    The instance is cached in the module-level ``_scheduler`` so every
    caller shares one scheduler.
    """
    global _scheduler
    if _scheduler is None:
        # Build the scheduler exactly once, directly with the job defaults
        # (no throwaway bare instance first). Defaults applied to every job
        # unless overridden:
        #   * coalesce — collapse a queue of missed runs into one firing
        #     after a restart / pause, instead of bursting to catch up.
        #   * misfire_grace_time — accept firings up to 5 min late without
        #     dropping them silently.
        #   * max_instances=1 — never run two copies of the same job
        #     concurrently.
        _scheduler = AsyncIOScheduler(
            job_defaults={
                "coalesce": True,
                "misfire_grace_time": 300,
                "max_instances": 1,
            },
        )
    return _scheduler
|
||||
|
||||
|
||||
@@ -26,6 +111,7 @@ async def start_scheduler() -> None:
|
||||
|
||||
await _load_tracker_jobs()
|
||||
await _load_action_jobs()
|
||||
await _load_immich_dispatch_jobs()
|
||||
|
||||
# Start Telegram bot polling for bots with active command listeners
|
||||
from .telegram_poller import start_command_listener_polling
|
||||
@@ -208,21 +294,38 @@ async def _refresh_telegram_chat_titles() -> None:
|
||||
|
||||
|
||||
async def _cleanup_old_events() -> None:
    """Delete EventLog / WebhookPayloadLog / ActionExecution rows older than the
    configured retention window. A retention of 0 disables the job.

    The window is ``settings.event_log_retention_days``; any value <= 0
    skips the cleanup. ``EventLog`` and ``WebhookPayloadLog`` are pruned by
    ``created_at``, ``ActionExecution`` by ``started_at``, all in a single
    transaction.
    """
    from datetime import datetime, timedelta, timezone

    from sqlmodel import delete
    from sqlmodel.ext.asyncio.session import AsyncSession

    from ..config import settings
    from ..database.engine import get_engine
    from ..database.models import ActionExecution, EventLog, WebhookPayloadLog

    days = settings.event_log_retention_days
    if days <= 0:
        _LOGGER.debug("Event log retention disabled (days=0); skipping cleanup")
        return

    # The cutoff comes only from the configured retention; the earlier
    # hard-coded 90-day cutoff is gone.
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    engine = get_engine()
    async with AsyncSession(engine) as session:
        await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
        await session.exec(
            delete(WebhookPayloadLog).where(WebhookPayloadLog.created_at < cutoff)
        )
        await session.exec(
            delete(ActionExecution).where(ActionExecution.started_at < cutoff)
        )
        await session.commit()
    _LOGGER.info(
        "Cleaned event_log / webhook_payload_log / action_execution older than %s",
        cutoff.date(),
    )
|
||||
|
||||
|
||||
async def _load_tracker_jobs() -> None:
|
||||
@@ -250,6 +353,8 @@ async def _load_tracker_jobs() -> None:
|
||||
)
|
||||
provider_types = {p.id: p.type for p in provider_result.all()}
|
||||
|
||||
tz = await _load_app_timezone()
|
||||
|
||||
for tracker in trackers:
|
||||
job_id = f"tracker_{tracker.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -263,7 +368,7 @@ async def _load_tracker_jobs() -> None:
|
||||
cron_expr = filters.get("cron_expression", "")
|
||||
if cron_expr:
|
||||
try:
|
||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name)
|
||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
|
||||
continue
|
||||
except Exception as e:
|
||||
_LOGGER.error(
|
||||
@@ -271,16 +376,21 @@ async def _load_tracker_jobs() -> None:
|
||||
tracker.id, tracker.name, e,
|
||||
)
|
||||
|
||||
jitter = _compute_jitter(tracker.scan_interval)
|
||||
scheduler.add_job(
|
||||
_poll_tracker,
|
||||
"interval",
|
||||
seconds=tracker.scan_interval,
|
||||
jitter=jitter or None,
|
||||
id=job_id,
|
||||
args=[tracker.id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval)
|
||||
_LOGGER.info(
|
||||
"Scheduled tracker %d (%s) every %ds (jitter ±%ds)",
|
||||
tracker.id, tracker.name, tracker.scan_interval, jitter,
|
||||
)
|
||||
|
||||
|
||||
def _add_cron_job(
|
||||
@@ -289,10 +399,18 @@ def _add_cron_job(
|
||||
tracker_id: int,
|
||||
cron_expression: str,
|
||||
tracker_name: str,
|
||||
tz: ZoneInfo,
|
||||
) -> None:
|
||||
"""Add a cron-triggered job for a scheduler-type tracker."""
|
||||
"""Add a cron-triggered job for a scheduler-type tracker.
|
||||
|
||||
``tz`` is the user-configured app timezone; without it APScheduler
|
||||
interprets the crontab in the host's local timezone, which surfaces as
|
||||
events firing at the "wrong" wall-clock time for operators in a non-UTC
|
||||
zone (see the companion fix in ``update_settings`` which reschedules on
|
||||
timezone changes).
|
||||
"""
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(cron_expression)
|
||||
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_poll_tracker,
|
||||
trigger,
|
||||
@@ -301,7 +419,10 @@ def _add_cron_job(
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
|
||||
_LOGGER.info(
|
||||
"Scheduled tracker %d (%s) with cron: %s [tz=%s]",
|
||||
tracker_id, tracker_name, cron_expression, tz.key,
|
||||
)
|
||||
|
||||
|
||||
async def schedule_tracker(
|
||||
@@ -313,44 +434,129 @@ async def schedule_tracker(
|
||||
scheduler = get_scheduler()
|
||||
job_id = f"tracker_{tracker_id}"
|
||||
|
||||
# A reschedule typically follows a config edit or enable/disable flip —
|
||||
# drop adaptive back-off so the first tick after the change runs promptly.
|
||||
reset_adaptive_state(tracker_id)
|
||||
|
||||
# Remove existing job first to allow trigger type changes
|
||||
if scheduler.get_job(job_id):
|
||||
scheduler.remove_job(job_id)
|
||||
|
||||
if cron_expression:
|
||||
try:
|
||||
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}")
|
||||
tz = await _load_app_timezone()
|
||||
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}", tz)
|
||||
return
|
||||
except Exception as e:
|
||||
_LOGGER.error("Invalid cron for tracker %d: %s — using interval", tracker_id, e)
|
||||
|
||||
jitter = _compute_jitter(interval)
|
||||
scheduler.add_job(
|
||||
_poll_tracker,
|
||||
"interval",
|
||||
seconds=interval,
|
||||
jitter=jitter or None,
|
||||
id=job_id,
|
||||
args=[tracker_id],
|
||||
replace_existing=True,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d every %ds", tracker_id, interval)
|
||||
_LOGGER.info(
|
||||
"Scheduled tracker %d every %ds (jitter ±%ds)", tracker_id, interval, jitter,
|
||||
)
|
||||
|
||||
|
||||
async def unschedule_tracker(tracker_id: int) -> None:
    """Drop the tracker's polling job, if one is registered.

    The tracker's adaptive back-off state is cleared whether or not a job
    exists.
    """
    reset_adaptive_state(tracker_id)
    scheduler = get_scheduler()
    job_id = f"tracker_{tracker_id}"
    if not scheduler.get_job(job_id):
        return
    scheduler.remove_job(job_id)
    _LOGGER.info("Unscheduled tracker %d", tracker_id)
|
||||
|
||||
|
||||
def _adaptive_should_skip(tracker_id: int) -> bool:
    """Tell the caller whether the adaptive heuristic wants this tick skipped.

    A tracker with no adaptive state, or with ``skip_every`` <= 1, is never
    skipped and its counter is left untouched. Otherwise the tracker's
    ``tick_counter`` is incremented and only ticks where the counter is a
    multiple of ``skip_every`` run — i.e. (K-1) ticks are skipped between
    real polls in 1-in-K mode.
    """
    entry = _adaptive_state.get(tracker_id)
    if not entry:
        return False
    period = entry.get("skip_every", 1)
    if period <= 1:
        return False
    ticks = entry.get("tick_counter", 0) + 1
    entry["tick_counter"] = ticks
    # Run when the counter lands on a multiple of the period; skip otherwise.
    return ticks % period != 0
|
||||
|
||||
|
||||
def _adaptive_update(tracker_id: int, events_detected: int) -> None:
    """Fold the outcome of a real (non-skipped) poll into the adaptive state.

    Any detected event resets the tracker to the base rate. An empty poll
    increments ``empty_count``; reaching ``_ADAPTIVE_HALVE_THRESHOLD`` moves
    the tracker to 1-in-2, and reaching ``_ADAPTIVE_QUARTER_THRESHOLD`` moves
    it to ``_ADAPTIVE_MAX_SKIP``.
    """
    entry = _adaptive_state.setdefault(
        tracker_id, {"empty_count": 0, "skip_every": 1, "tick_counter": 0}
    )

    if events_detected > 0:
        was_backed_off = entry["skip_every"] > 1
        if was_backed_off:
            _LOGGER.info(
                "Adaptive polling: tracker %d saw activity, restoring base rate",
                tracker_id,
            )
        entry.update(empty_count=0, skip_every=1, tick_counter=0)
        return

    idle_polls = entry.get("empty_count", 0) + 1
    entry["empty_count"] = idle_polls
    current_period = entry["skip_every"]

    # Check the deeper back-off first so a long-idle tracker goes straight to
    # the maximum skip factor.
    if idle_polls >= _ADAPTIVE_QUARTER_THRESHOLD and current_period < _ADAPTIVE_MAX_SKIP:
        entry["skip_every"] = _ADAPTIVE_MAX_SKIP
        _LOGGER.info(
            "Adaptive polling: tracker %d idle for %d ticks, skipping 3 of 4",
            tracker_id, entry["empty_count"],
        )
    elif idle_polls >= _ADAPTIVE_HALVE_THRESHOLD and current_period < 2:
        entry["skip_every"] = 2
        _LOGGER.info(
            "Adaptive polling: tracker %d idle for %d ticks, skipping every other",
            tracker_id, entry["empty_count"],
        )
|
||||
|
||||
|
||||
def reset_adaptive_state(tracker_id: int) -> None:
    """Forget the adaptive back-off counters kept for ``tracker_id``.

    Intended for callers whose change should make the tracker poll on its
    next scheduled tick (enable/disable, config edits, manual "check now").
    A tracker with no recorded state is left alone.
    """
    if tracker_id in _adaptive_state:
        del _adaptive_state[tracker_id]
|
||||
|
||||
|
||||
async def _poll_tracker(tracker_id: int) -> None:
    """Poll a tracker for changes (scheduler job entry point).

    Skips the tick when the adaptive heuristic says so. Otherwise runs
    ``check_tracker`` exactly once and, when it reports ``status == "ok"``,
    feeds the detected-event count back into the adaptive state. Exceptions
    from ``check_tracker`` are logged and swallowed so the job keeps running.
    """
    from .watcher import check_tracker

    if _adaptive_should_skip(tracker_id):
        return

    try:
        # One call per tick; its result drives the adaptive update below.
        result = await check_tracker(tracker_id)
    except Exception as e:
        _LOGGER.error("Error polling tracker %d: %s", tracker_id, e)
        return

    # Treat the "error" / "skipped" statuses as inconclusive — don't let
    # a transient upstream failure trick the heuristic into backing off.
    if isinstance(result, dict) and result.get("status") == "ok":
        _adaptive_update(tracker_id, int(result.get("events_detected", 0) or 0))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -374,6 +580,8 @@ async def _load_action_jobs() -> None:
|
||||
)
|
||||
actions = result.all()
|
||||
|
||||
tz = await _load_app_timezone()
|
||||
|
||||
for action in actions:
|
||||
job_id = f"action_{action.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -382,7 +590,7 @@ async def _load_action_jobs() -> None:
|
||||
if action.schedule_type == "cron" and action.schedule_cron:
|
||||
try:
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(action.schedule_cron)
|
||||
trigger = CronTrigger.from_crontab(action.schedule_cron, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_run_action,
|
||||
trigger,
|
||||
@@ -391,8 +599,8 @@ async def _load_action_jobs() -> None:
|
||||
replace_existing=True,
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Scheduled action %d (%s) with cron: %s",
|
||||
action.id, action.name, action.schedule_cron,
|
||||
"Scheduled action %d (%s) with cron: %s [tz=%s]",
|
||||
action.id, action.name, action.schedule_cron, tz.key,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
@@ -431,7 +639,8 @@ async def schedule_action(
|
||||
if schedule_type == "cron" and cron_expression:
|
||||
try:
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(cron_expression)
|
||||
tz = await _load_app_timezone()
|
||||
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_run_action,
|
||||
trigger,
|
||||
@@ -439,7 +648,10 @@ async def schedule_action(
|
||||
args=[action_id],
|
||||
replace_existing=True,
|
||||
)
|
||||
_LOGGER.info("Scheduled action %d with cron: %s", action_id, cron_expression)
|
||||
_LOGGER.info(
|
||||
"Scheduled action %d with cron: %s [tz=%s]",
|
||||
action_id, cron_expression, tz.key,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
_LOGGER.error("Invalid cron for action %d: %s — using interval", action_id, e)
|
||||
@@ -464,6 +676,96 @@ async def unschedule_action(action_id: int) -> None:
|
||||
_LOGGER.info("Unscheduled action %d", action_id)
|
||||
|
||||
|
||||
async def reschedule_cron_jobs_for_timezone_change() -> None:
    """Re-add every cron-triggered tracker/action job under the new app timezone.

    Called by the admin settings endpoint after the ``timezone`` AppSetting is
    updated. APScheduler's ``CronTrigger`` freezes its timezone at construction
    time, so a timezone change has no effect on jobs already in the scheduler
    — we have to rebuild those jobs. Interval-triggered jobs are tz-agnostic
    and are left alone.

    Jobs are swapped in place via ``replace_existing=True`` rather than
    removed first: if a stored crontab fails to parse, the job that is
    currently registered (for example an interval fallback) stays scheduled
    instead of disappearing. Failures are logged per tracker/action and do
    not stop the remaining jobs from being rebuilt.
    """
    from apscheduler.triggers.cron import CronTrigger
    from sqlmodel import select
    from sqlmodel.ext.asyncio.session import AsyncSession

    from ..database.engine import get_engine
    from ..database.models import Action, NotificationTracker, ServiceProvider as ServiceProviderModel

    engine = get_engine()
    scheduler = get_scheduler()
    tz = await _load_app_timezone()
    rescheduled = 0

    async with AsyncSession(engine) as session:
        # Trackers with cron scheduling (scheduler provider + schedule_type=cron).
        trackers = (await session.exec(
            select(NotificationTracker).where(NotificationTracker.enabled == True)  # noqa: E712
        )).all()
        provider_ids = list({t.provider_id for t in trackers})
        provider_types: dict[int, str] = {}
        if provider_ids:
            rows = await session.exec(
                select(ServiceProviderModel).where(ServiceProviderModel.id.in_(provider_ids))
            )
            provider_types = {p.id: p.type for p in rows.all()}

        for tracker in trackers:
            if provider_types.get(tracker.provider_id) != "scheduler":
                continue
            filters = tracker.filters or {}
            if filters.get("schedule_type") != "cron":
                continue
            cron_expr = filters.get("cron_expression", "")
            if not cron_expr:
                continue
            job_id = f"tracker_{tracker.id}"
            try:
                # _add_cron_job parses the crontab before calling add_job
                # with replace_existing=True, so a parse failure leaves the
                # currently registered job untouched.
                _add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
            except Exception as e:  # noqa: BLE001
                _LOGGER.error(
                    "Failed to re-apply cron for tracker %d on tz change: %s",
                    tracker.id, e,
                )
            else:
                rescheduled += 1

        # Actions with cron schedules.
        actions = (await session.exec(
            select(Action).where(Action.enabled == True)  # noqa: E712
        )).all()

        for action in actions:
            if action.schedule_type != "cron" or not action.schedule_cron:
                continue
            job_id = f"action_{action.id}"
            try:
                # Build the trigger first; only a valid trigger replaces the
                # job that is already scheduled.
                trigger = CronTrigger.from_crontab(action.schedule_cron, timezone=tz)
                scheduler.add_job(
                    _run_action,
                    trigger,
                    id=job_id,
                    args=[action.id],
                    replace_existing=True,
                )
            except Exception as e:  # noqa: BLE001
                _LOGGER.error(
                    "Failed to re-apply cron for action %d on tz change: %s",
                    action.id, e,
                )
            else:
                rescheduled += 1

    _LOGGER.info(
        "Rescheduled %d cron job(s) for new app timezone %s", rescheduled, tz.key,
    )

    # Immich scheduled/periodic/memory jobs are also CronTrigger-based and
    # carry the same frozen-tz problem — rebuild them under the new tz.
    await reschedule_immich_dispatch_jobs()
|
||||
|
||||
|
||||
async def _run_action(action_id: int) -> None:
|
||||
"""Run an action (called by APScheduler)."""
|
||||
from .action_runner import run_action
|
||||
@@ -473,6 +775,155 @@ async def _run_action(action_id: int) -> None:
|
||||
_LOGGER.error("Error running action %d: %s", action_id, e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Immich scheduled / periodic / memory dispatch (cron-fired)
|
||||
#
|
||||
# These three slots fire on wall-clock schedules taken from the tracker's
|
||||
# default ``TrackingConfig`` (``scheduled_times``, ``periodic_times``,
|
||||
# ``memory_times`` — comma-separated ``HH:MM`` strings) interpreted in the
|
||||
# app-level IANA timezone. The dispatch flow lives in
|
||||
# ``services.scheduled_dispatch``; this section just owns scheduling.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMMICH_DISPATCH_KINDS = ("scheduled", "periodic", "memory")
|
||||
_IMMICH_DISPATCH_PREFIX = "immich_dispatch_"
|
||||
|
||||
|
||||
def _parse_hhmm_list(raw: str) -> list[tuple[int, int]]:
|
||||
"""Parse ``"09:00,18:30"`` → ``[(9, 0), (18, 30)]``, skipping bad entries.
|
||||
|
||||
A typo in one slot must not prevent the others from scheduling — we log
|
||||
and move on rather than raising.
|
||||
"""
|
||||
out: list[tuple[int, int]] = []
|
||||
for part in (raw or "").split(","):
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
try:
|
||||
h_str, m_str = part.split(":", 1)
|
||||
hour, minute = int(h_str), int(m_str)
|
||||
except ValueError:
|
||||
_LOGGER.warning("Skipping invalid time literal %r", part)
|
||||
continue
|
||||
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||
_LOGGER.warning("Skipping out-of-range time %r", part)
|
||||
continue
|
||||
out.append((hour, minute))
|
||||
return out
|
||||
|
||||
|
||||
async def _run_immich_dispatch(tracker_id: int, kind: str) -> None:
    """Scheduler job body for one Immich (tracker, kind) fire.

    Delegates to ``dispatch_scheduled_for_tracker`` and logs any exception
    it raises instead of letting it propagate out of the job.
    """
    from .scheduled_dispatch import dispatch_scheduled_for_tracker

    try:
        await dispatch_scheduled_for_tracker(tracker_id, kind)  # type: ignore[arg-type]
    except Exception as exc:  # noqa: BLE001
        _LOGGER.error(
            "Immich %s dispatch for tracker %d failed: %s", kind, tracker_id, exc,
        )
|
||||
|
||||
|
||||
async def _load_immich_dispatch_jobs() -> None:
    """Register a cron job for each (tracker, kind, time) whose kind is enabled.

    Only the *default* tracking config of each enabled Immich tracker is
    consulted: its ``<kind>_enabled`` flag and ``<kind>_times`` string decide
    which jobs exist. Per-link overrides play no part in the fire schedule.
    Triggers are built in the app timezone and job ids take the form
    ``immich_dispatch_<kind>_<tracker id>_<HHMM>``.
    """
    from sqlmodel import select
    from sqlmodel.ext.asyncio.session import AsyncSession

    from apscheduler.triggers.cron import CronTrigger

    from ..database.engine import get_engine
    from ..database.models import (
        NotificationTracker,
        ServiceProvider as ServiceProviderModel,
        TrackingConfig,
    )

    engine = get_engine()
    scheduler = get_scheduler()
    tz = await _load_app_timezone()

    async with AsyncSession(engine) as session:
        tracker_query = select(NotificationTracker).where(
            NotificationTracker.enabled == True  # noqa: E712
        )
        trackers = (await session.exec(tracker_query)).all()
        if not trackers:
            return

        # Resolve provider types for all trackers with one query.
        provider_types: dict[int, str] = {}
        provider_ids = list({t.provider_id for t in trackers})
        if provider_ids:
            provider_rows = await session.exec(
                select(ServiceProviderModel).where(
                    ServiceProviderModel.id.in_(provider_ids)
                )
            )
            provider_types = {p.id: p.type for p in provider_rows.all()}

        # Likewise fetch every referenced default tracking config at once.
        tc_map: dict[int, TrackingConfig] = {}
        tc_ids = list({
            t.default_tracking_config_id for t in trackers
            if t.default_tracking_config_id
        })
        if tc_ids:
            tc_rows = await session.exec(
                select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
            )
            tc_map = {tc.id: tc for tc in tc_rows.all()}

        scheduled = 0
        for tracker in trackers:
            if provider_types.get(tracker.provider_id) != "immich":
                continue
            default_tc_id = tracker.default_tracking_config_id
            tc = tc_map.get(default_tc_id) if default_tc_id else None
            if tc is None:
                continue

            for kind in _IMMICH_DISPATCH_KINDS:
                if not getattr(tc, f"{kind}_enabled", False):
                    continue
                times_raw = getattr(tc, f"{kind}_times", "") or ""
                for hour, minute in _parse_hhmm_list(times_raw):
                    job_id = f"{_IMMICH_DISPATCH_PREFIX}{kind}_{tracker.id}_{hour:02d}{minute:02d}"
                    trigger = CronTrigger(hour=hour, minute=minute, timezone=tz)
                    scheduler.add_job(
                        _run_immich_dispatch,
                        trigger,
                        id=job_id,
                        args=[tracker.id, kind],
                        replace_existing=True,
                        max_instances=1,
                    )
                    scheduled += 1
                    _LOGGER.info(
                        "Scheduled Immich %s for tracker %d at %02d:%02d [tz=%s]",
                        kind, tracker.id, hour, minute, tz.key,
                    )

    if scheduled:
        _LOGGER.info(
            "Loaded %d Immich scheduled/periodic/memory job(s) [tz=%s]",
            scheduled, tz.key,
        )
|
||||
|
||||
|
||||
async def reschedule_immich_dispatch_jobs() -> None:
    """Remove every Immich dispatch job and load them again from the database.

    Jobs are recognised by the ``_IMMICH_DISPATCH_PREFIX`` on their id. After
    they are removed, ``_load_immich_dispatch_jobs`` recreates whatever the
    current tracker and tracking-config rows call for.
    """
    scheduler = get_scheduler()
    # Collect the ids up front so removal does not disturb the iteration.
    stale_ids = [
        job.id for job in scheduler.get_jobs()
        if job.id.startswith(_IMMICH_DISPATCH_PREFIX)
    ]
    for stale_id in stale_ids:
        scheduler.remove_job(stale_id)
    await _load_immich_dispatch_jobs()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scheduled backup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -11,11 +11,13 @@ CommandTrackerListeners with enabled CommandTrackers.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..database.engine import get_engine
|
||||
@@ -125,7 +127,14 @@ async def stop_bot_if_unused(bot_id: int) -> None:
|
||||
|
||||
|
||||
def schedule_bot_polling(bot_id: int) -> None:
|
||||
"""Add a polling job for a bot (idempotent)."""
|
||||
"""Add a polling job for a bot (idempotent).
|
||||
|
||||
We schedule at a 30 s interval, but each tick calls ``getUpdates`` with
|
||||
``timeout=25`` — Telegram holds the connection open until either an
|
||||
update arrives or the timeout elapses, so in practice the bot streams
|
||||
updates with sub-second latency while consuming ~2 API calls / minute
|
||||
per bot (down from 20 under the old 3 s short-poll).
|
||||
"""
|
||||
scheduler = get_scheduler()
|
||||
job_id = f"telegram_poll_{bot_id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -133,13 +142,13 @@ def schedule_bot_polling(bot_id: int) -> None:
|
||||
scheduler.add_job(
|
||||
_poll_bot,
|
||||
"interval",
|
||||
seconds=3,
|
||||
seconds=30,
|
||||
id=job_id,
|
||||
args=[bot_id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Started polling for bot %d", bot_id)
|
||||
_LOGGER.info("Started polling for bot %d (long-poll, 25s timeout)", bot_id)
|
||||
|
||||
|
||||
def unschedule_bot_polling(bot_id: int) -> None:
|
||||
@@ -231,8 +240,10 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
# Long-poll: hold connection open until an update arrives or 25 s
|
||||
# elapse. Drastically cuts API calls vs. 3 s short-poll.
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
offset=offset + 1 if offset else None, limit=50, timeout=25,
|
||||
)
|
||||
if not result.get("success"):
|
||||
err_text = str(result.get("error") or "")
|
||||
@@ -257,7 +268,13 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
_last_update_id[bot_id] = updates[-1]["update_id"]
|
||||
|
||||
# Process each update
|
||||
from ..commands.handler import handle_command, send_media_group, send_reply
|
||||
from ..commands.handler import (
|
||||
classify_command_chat_action,
|
||||
handle_command,
|
||||
send_media_group,
|
||||
send_reply,
|
||||
)
|
||||
from .telegram_send import telegram_chat_action
|
||||
|
||||
for update in updates:
|
||||
message = update.get("message")
|
||||
@@ -283,26 +300,64 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
|
||||
# Dispatch commands (only if chat has commands enabled)
|
||||
if text and text.startswith("/"):
|
||||
try:
|
||||
async with AsyncSession(engine) as cmd_session:
|
||||
chat_row = (await cmd_session.exec(
|
||||
select(TelegramChat).where(
|
||||
TelegramChat.bot_id == bot_obj.id,
|
||||
TelegramChat.chat_id == chat_id,
|
||||
from ..commands.parser import parse_command
|
||||
cmd_name, _, _ = parse_command(text)
|
||||
update_id = update.get("update_id")
|
||||
message_id = message.get("message_id")
|
||||
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||
with bind_log_context(
|
||||
request_id=request_id,
|
||||
command=cmd_name or "-",
|
||||
chat_id=chat_id,
|
||||
bot_id=bot_obj.id,
|
||||
):
|
||||
started = time.monotonic()
|
||||
try:
|
||||
async with AsyncSession(engine) as cmd_session:
|
||||
chat_row = (await cmd_session.exec(
|
||||
select(TelegramChat).where(
|
||||
TelegramChat.bot_id == bot_obj.id,
|
||||
TelegramChat.chat_id == chat_id,
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
_LOGGER.info(
|
||||
"Command ignored — commands disabled (poll) for bot=%s chat=%s",
|
||||
bot_obj.id, chat_id,
|
||||
)
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
_LOGGER.info("Command received (poll): /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
):
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if not responses:
|
||||
_LOGGER.info(
|
||||
"Command produced no response (cmd=%r, poll) after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
continue
|
||||
text_count = sum(1 for r in responses if r.text)
|
||||
media_count = sum(len(r.media or []) for r in responses)
|
||||
_LOGGER.info(
|
||||
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||
len(responses), text_count, media_count,
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
except Exception:
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
_LOGGER.info(
|
||||
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
len(responses), media_count,
|
||||
)
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Error handling command /%s from bot %d after %.0f ms",
|
||||
cmd_name, bot_id, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Single entry point for all Telegram send operations.
|
||||
|
||||
Both the notification dispatcher (event-driven) and the bot command
|
||||
handlers (user-driven) funnel their Telegram API calls through this
|
||||
module. Keeping construction in one place means:
|
||||
|
||||
* The shared aiohttp session is always reused (one TCP pool for the
|
||||
whole process).
|
||||
* The Telegram file_id caches (URL cache + asset cache) are always
|
||||
wired in, so repeated sends — whether from a scheduled tracker or
|
||||
a ``/latest`` command — reuse cached file_ids instead of re-uploading
|
||||
the same bytes.
|
||||
* Future cross-cutting concerns (rate limiting, telemetry, retries)
|
||||
have exactly one place to live.
|
||||
|
||||
The actual Telegram API routine is still ``TelegramClient`` in core —
|
||||
this module just guarantees every caller gets a properly-wired client.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any, AsyncIterator, Callable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.notifications.telegram.client import (
|
||||
NotificationResult,
|
||||
TelegramClient,
|
||||
)
|
||||
|
||||
from .http_session import get_http_session
|
||||
from .watcher import _get_telegram_caches
|
||||
|
||||
|
||||
async def get_telegram_client(
    bot_token: str,
    *,
    session: aiohttp.ClientSession | None = None,
    thumbhash_resolver: Callable[[str], str | None] | None = None,
) -> TelegramClient:
    """Build a ``TelegramClient`` bound to the shared session and caches.

    Callers that need to talk to Telegram should obtain their client here so
    the URL cache and asset cache returned by ``_get_telegram_caches`` are
    always passed to the client.

    Args:
        bot_token: The bot's API token.
        session: aiohttp session to use. When omitted, the session returned
            by ``get_http_session`` is used instead.
        thumbhash_resolver: Optional asset-id -> thumbhash lookup, forwarded
            unchanged to ``TelegramClient``. Per the original note, the
            notification dispatcher supplies one while the command path
            leaves it unset.

    Returns:
        A ``TelegramClient`` constructed with the session, token, both caches
        and the resolver.
    """
    http = session if session is not None else await get_http_session()
    caches = await _get_telegram_caches()
    url_cache, asset_cache = caches
    wiring = {
        "url_cache": url_cache,
        "asset_cache": asset_cache,
        "thumbhash_resolver": thumbhash_resolver,
    }
    return TelegramClient(http, bot_token, **wiring)
|
||||
|
||||
|
||||
async def send_telegram_message(
    bot_token: str,
    chat_id: str,
    text: str,
    *,
    reply_to_message_id: int | None = None,
    disable_web_page_preview: bool = True,
    parse_mode: str = "HTML",
) -> NotificationResult:
    """Send a text message through a client from ``get_telegram_client``.

    Args:
        bot_token: The bot's API token.
        chat_id: Destination chat.
        text: Message body.
        reply_to_message_id: Optional message to reply to.
        disable_web_page_preview: Forwarded to ``send_message``.
        parse_mode: Forwarded to ``send_message``.

    Returns:
        Whatever ``TelegramClient.send_message`` returns.
    """
    tg = await get_telegram_client(bot_token)
    options = {
        "reply_to_message_id": reply_to_message_id,
        "disable_web_page_preview": disable_web_page_preview,
        "parse_mode": parse_mode,
    }
    outcome = await tg.send_message(chat_id, text, **options)
    return outcome
|
||||
|
||||
|
||||
async def send_telegram_media(
    bot_token: str,
    chat_id: str,
    assets: list[dict[str, Any]],
    *,
    caption: str | None = None,
    reply_to_message_id: int | None = None,
    max_group_size: int = 10,
    chunk_delay: int = 0,
    max_asset_data_size: int | None = None,
    send_large_photos_as_documents: bool = False,
    chat_action: str | None = "typing",
    thumbhash_resolver: Callable[[str], str | None] | None = None,
) -> NotificationResult:
    """Send one or more media assets through a cache-wired client.

    The client comes from ``get_telegram_client`` (with ``thumbhash_resolver``
    passed along); every other option is handed to
    ``TelegramClient.send_notification`` unchanged.

    ``assets`` must already be in ``TelegramClient`` format. The original
    note points at
    ``notify_bridge_core.notifications.telegram.media.build_telegram_asset_entry``
    as the canonical builder.

    Returns:
        Whatever ``TelegramClient.send_notification`` returns.
    """
    tg = await get_telegram_client(
        bot_token,
        thumbhash_resolver=thumbhash_resolver,
    )
    send_options = {
        "assets": assets,
        "caption": caption,
        "reply_to_message_id": reply_to_message_id,
        "max_group_size": max_group_size,
        "chunk_delay": chunk_delay,
        "max_asset_data_size": max_asset_data_size,
        "send_large_photos_as_documents": send_large_photos_as_documents,
        "chat_action": chat_action,
    }
    outcome = await tg.send_notification(chat_id, **send_options)
    return outcome
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
async def telegram_chat_action(
    bot_token: str,
    chat_id: str,
    action: str | None,
) -> AsyncIterator[None]:
    """Keep a Telegram chat action (e.g. ``upload_photo``) alive for the block.

    While the ``async with`` body runs, the task returned by
    ``start_chat_action_keepalive`` is left running; on exit it is cancelled
    and awaited. Per the original note, the command path uses this so the
    indicator covers both fetching assets and uploading them.

    A falsy ``action`` (``None`` or ``""``) turns the context manager into a
    no-op, so callers do not have to branch.
    """
    if action:
        client = await get_telegram_client(bot_token)
        keepalive = client.start_chat_action_keepalive(chat_id, action)
        try:
            yield
        finally:
            keepalive.cancel()
            # NOTE(review): only CancelledError is swallowed here. If the
            # keep-alive task can finish early with another exception,
            # awaiting it re-raises that error from this block — confirm
            # against the keep-alive implementation.
            try:
                await keepalive
            except asyncio.CancelledError:
                pass
    else:
        yield
|
||||
@@ -187,8 +187,17 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
"asset_ids": s.asset_ids,
|
||||
"pending_asset_ids": s.pending_asset_ids,
|
||||
"shared": bool(s.shared),
|
||||
"meta_fingerprint": s.meta_fingerprint or {},
|
||||
}
|
||||
|
||||
# Snapshot the original fingerprint per collection so we can skip the
|
||||
# (expensive) asset_ids rewrite when nothing changed. For a 200k-asset
|
||||
# album this avoids a ~7 MB JSON write to the state row every tick.
|
||||
original_fingerprints: dict[str, dict[str, Any]] = {
|
||||
cid: dict(cstate.get("meta_fingerprint") or {})
|
||||
for cid, cstate in state_dict.items()
|
||||
}
|
||||
|
||||
# Load tracker-target links
|
||||
link_data = await load_link_data(session, tracker_id)
|
||||
|
||||
@@ -237,6 +246,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
name=provider_name,
|
||||
tracker_name=tracker_name,
|
||||
custom_variables=custom_vars,
|
||||
timezone_name=app_tz,
|
||||
)
|
||||
events, new_state = await sched.poll(collection_ids, state_dict)
|
||||
elif provider_type == "nut":
|
||||
@@ -279,11 +289,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
existing = s
|
||||
break
|
||||
|
||||
current_fingerprint = dict(cstate.get("meta_fingerprint") or {})
|
||||
prior_fingerprint = original_fingerprints.get(cid, {})
|
||||
# Skip the DB update when the provider reported no meaningful
|
||||
# change. ``existing`` is None on first-ever fetch for a
|
||||
# collection — that path always writes so the row gets created.
|
||||
if existing is not None and current_fingerprint == prior_fingerprint:
|
||||
continue
|
||||
|
||||
if existing:
|
||||
existing.asset_ids = cstate.get("asset_ids", [])
|
||||
existing.pending_asset_ids = cstate.get("pending_asset_ids", [])
|
||||
existing.collection_name = cstate.get("name", "")
|
||||
existing.shared = cstate.get("shared", False)
|
||||
existing.meta_fingerprint = current_fingerprint
|
||||
session.add(existing)
|
||||
else:
|
||||
new_ts = NotificationTrackerState(
|
||||
@@ -293,11 +312,32 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
shared=cstate.get("shared", False),
|
||||
asset_ids=cstate.get("asset_ids", []),
|
||||
pending_asset_ids=cstate.get("pending_asset_ids", []),
|
||||
meta_fingerprint=current_fingerprint,
|
||||
)
|
||||
session.add(new_ts)
|
||||
|
||||
for event in events:
|
||||
assets_count = event.added_count or event.removed_count or 0
|
||||
details: dict[str, Any] = {
|
||||
"added_count": event.added_count,
|
||||
"removed_count": event.removed_count,
|
||||
"provider_type": event.provider_type.value,
|
||||
}
|
||||
# Scheduler/periodic events carry the schedule context in ``extra``
|
||||
# (cron expression, interval, timezone, fire count). Surface that
|
||||
# in the event log so the dashboard and audit queries can show
|
||||
# *why* the event fired, not just that it did.
|
||||
if event.event_type.value == "scheduled_message":
|
||||
sched_type = tracker_filters.get("schedule_type", "interval")
|
||||
details["schedule_type"] = sched_type
|
||||
if sched_type == "cron":
|
||||
details["cron_expression"] = tracker_filters.get("cron_expression", "")
|
||||
else:
|
||||
details["interval_seconds"] = tracker.scan_interval
|
||||
details["timezone"] = app_tz
|
||||
fire_count = event.extra.get("fire_count") if event.extra else None
|
||||
if fire_count is not None:
|
||||
details["fire_count"] = fire_count
|
||||
log = EventLog(
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker_id,
|
||||
@@ -308,11 +348,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
collection_id=event.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=assets_count,
|
||||
details={
|
||||
"added_count": event.added_count,
|
||||
"removed_count": event.removed_count,
|
||||
"provider_type": event.provider_type.value,
|
||||
},
|
||||
details=details,
|
||||
)
|
||||
session.add(log)
|
||||
|
||||
@@ -333,7 +369,13 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
if events and link_data:
|
||||
url_cache, asset_cache = await _get_telegram_caches()
|
||||
dispatcher = NotificationDispatcher(url_cache=url_cache, asset_cache=asset_cache)
|
||||
from .http_session import get_http_session
|
||||
shared_session = await get_http_session()
|
||||
dispatcher = NotificationDispatcher(
|
||||
url_cache=url_cache,
|
||||
asset_cache=asset_cache,
|
||||
session=shared_session,
|
||||
)
|
||||
for event in events:
|
||||
_LOGGER.info(
|
||||
"Dispatching event %s for %s (added=%d removed=%d)",
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Shared pytest fixtures.
|
||||
|
||||
We set the required env vars before any ``notify_bridge_server`` module is
|
||||
imported so ``Settings()`` passes its startup validation and opens the DB
|
||||
in a writable temp directory instead of the production ``/data`` default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Per the original note, Settings() is built when the server package is
# imported, so these environment overrides have to be applied here in
# conftest.py, before that import happens.
_TMP = Path(tempfile.mkdtemp(prefix="notify-bridge-tests-"))

# The data dir is always forced to the fresh temp directory.
os.environ["NOTIFY_BRIDGE_DATA_DIR"] = str(_TMP)

# The remaining settings only fill gaps: values already present in the
# environment are left untouched.
for _name, _value in (
    ("NOTIFY_BRIDGE_SECRET_KEY", "pytest-secret-key-" + "x" * 40),
    ("NOTIFY_BRIDGE_DEBUG", "false"),
    ("NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS", "http://localhost:8420"),
):
    os.environ.setdefault(_name, _value)
del _name, _value
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tmp_data_dir() -> Path:
    """Return the temp data directory created when this conftest was imported.

    The directory itself is provisioned at module level; this fixture only
    hands its path to tests that ask for it.
    """
    data_dir = _TMP
    return data_dir
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Unit tests for the Settings validator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_server.config import Settings, _FORBIDDEN_SECRETS
|
||||
|
||||
|
||||
_GOOD = "a" * 40
|
||||
|
||||
|
||||
def _make(**kwargs) -> Settings:
    """Construct ``Settings`` from known-good defaults, overridden by kwargs.

    The defaults are a 40-character secret key (``_GOOD``) and a single
    localhost CORS origin; any keyword passed in replaces or extends them.
    """
    merged = {
        "secret_key": _GOOD,
        "cors_allowed_origins": "http://localhost:8420",
        **kwargs,
    }
    return Settings(**merged)
|
||||
|
||||
|
||||
class TestSecretKey:
    """Checks on how ``Settings`` validates ``secret_key``."""

    def test_accepts_long_random_key(self) -> None:
        """The default 40-character key constructs without raising."""
        _make()

    def test_rejects_default(self) -> None:
        """The placeholder key is refused with a SECURITY message."""
        placeholder = "change-me-in-production"
        with pytest.raises(ValueError, match="SECURITY"):
            _make(secret_key=placeholder)

    def test_rejects_known_dev_keys(self) -> None:
        """Every entry in ``_FORBIDDEN_SECRETS`` is refused."""
        for forbidden in _FORBIDDEN_SECRETS:
            with pytest.raises(ValueError):
                _make(secret_key=forbidden)

    def test_rejects_short_key(self) -> None:
        """A too-short key is refused with a message naming 32 characters."""
        too_short = "short"
        with pytest.raises(ValueError, match="32 characters"):
            _make(secret_key=too_short)
|
||||
|
||||
|
||||
class TestCors:
    """Checks on how ``Settings`` validates ``cors_allowed_origins``."""

    def test_rejects_wildcard(self) -> None:
        """A bare ``*`` origin is refused with a message mentioning wildcard."""
        with pytest.raises(ValueError, match="wildcard"):
            _make(cors_allowed_origins="*")

    def test_rejects_missing_scheme(self) -> None:
        """An origin without a URL scheme is refused."""
        schemeless = "example.com"
        with pytest.raises(ValueError, match="scheme"):
            _make(cors_allowed_origins=schemeless)

    def test_accepts_multiple(self) -> None:
        """A comma-separated origin list is accepted and keeps its entries."""
        origins = "http://localhost:8420,https://example.com"
        settings_obj = _make(cors_allowed_origins=origins)
        assert "http://localhost:8420" in settings_obj.cors_allowed_origins
|
||||
|
||||
|
||||
class TestNumericValidation:
    """Checks that out-of-range numeric settings are refused."""

    def test_rejects_zero_access_token_expiry(self) -> None:
        """An access-token lifetime of zero minutes raises ``ValueError``."""
        with pytest.raises(ValueError):
            _make(access_token_expire_minutes=0)

    def test_rejects_invalid_port(self) -> None:
        """Ports 0 and 70000 both raise ``ValueError``."""
        for bad_port in (0, 70000):
            with pytest.raises(ValueError):
                _make(port=bad_port)

    def test_rejects_negative_retention(self) -> None:
        """A negative event-log retention raises ``ValueError``."""
        with pytest.raises(ValueError):
            _make(event_log_retention_days=-1)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Discord client 429-retry bounding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from aioresponses import aioresponses
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
|
||||
|
||||
WEBHOOK = "https://discord.com/api/webhooks/123/abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bounded_retries_on_persistent_429() -> None:
    """A webhook that always answers 429 must end in a failed result.

    The mocked endpoint repeats the 429 indefinitely, so the call only
    returns if the client bounds its retries (``_MAX_RETRIES`` per the
    original note).
    """
    with aioresponses() as http_mock:
        http_mock.post(
            WEBHOOK,
            status=429,
            headers={"Retry-After": "0.001"},
            repeat=True,
        )

        async with aiohttp.ClientSession() as session:
            outcome = await DiscordClient(session).send(WEBHOOK, "hello")

    assert outcome["success"] is False
    # Either the custom "Rate limited" text or the bare HTTP 429 from the
    # last attempt is acceptable; both show the retries stopped.
    error_text = outcome["error"]
    assert any(marker in error_text for marker in ("429", "Rate limited"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_caps_retry_after() -> None:
    """An absurd ``Retry-After: 99999`` must not stall the send for hours."""
    with aioresponses() as http_mock:
        # Response 1 carries the huge Retry-After; response 2 succeeds.
        http_mock.post(WEBHOOK, status=429, headers={"Retry-After": "99999"})
        http_mock.post(WEBHOOK, status=204)

        async with aiohttp.ClientSession() as session:
            discord = DiscordClient(session)
            # Shrink the cap on this instance so the test finishes quickly.
            discord._MAX_RETRY_AFTER = 0.001  # type: ignore[attr-defined]
            outcome = await discord.send(WEBHOOK, "hello")

    assert outcome["success"] is True
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Smoke test: app imports, /api/health returns 200, version string present."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_health_endpoint(tmp_data_dir) -> None:  # noqa: ARG001 — fixture applies env
    """``/api/health`` answers 200 with an ok status and a real version."""
    from notify_bridge_server.main import app

    # Entering TestClient runs the app lifespan, so (per the original note)
    # migrations execute once against the temp data dir.
    with TestClient(app) as http:
        response = http.get("/api/health")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "ok"
        version = payload["version"]
        assert version != "0.0.0+unknown"
        assert "." in version  # looks like a real version
|
||||
|
||||
|
||||
def test_ready_endpoint(tmp_data_dir) -> None:  # noqa: ARG001
    """``/api/ready`` answers 200 with status ``ready`` once the app is up."""
    from notify_bridge_server.main import app

    with TestClient(app) as http:
        response = http.get("/api/ready")
        # Lifespan startup has finished by the time TestClient yields.
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "ready"
|
||||
|
||||
|
||||
def test_health_is_anonymous(tmp_data_dir) -> None:  # noqa: ARG001
    """``/api/health`` must answer without credentials.

    Per the original note, the Docker healthcheck depends on this.
    """
    from notify_bridge_server.main import app

    with TestClient(app) as http:
        # No auth headers are attached to this request.
        response = http.get("/api/health")
        assert response.status_code == 200
|
||||
@@ -0,0 +1,74 @@
|
||||
"""JWT encode/decode round-trips."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
|
||||
from notify_bridge_server.auth.jwt import (
|
||||
ALGORITHM,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
)
|
||||
from notify_bridge_server.config import settings
|
||||
|
||||
|
||||
def test_access_token_round_trip() -> None:
    """An access token decodes back to the claims it was created with."""
    claims = decode_token(
        create_access_token(user_id=1, role="admin", token_version=3)
    )
    expected = {
        "sub": "1",
        "type": "access",
        "role": "admin",
        "ver": 3,
        "iss": settings.jwt_issuer,
        "aud": settings.jwt_audience,
    }
    for claim, value in expected.items():
        assert claims[claim] == value
|
||||
|
||||
|
||||
def test_refresh_token_round_trip() -> None:
    """A refresh token decodes with type ``refresh`` and the user id as ``sub``."""
    claims = decode_token(create_refresh_token(user_id=7, token_version=2))
    assert claims["type"] == "refresh"
    assert claims["sub"] == "7"
|
||||
|
||||
|
||||
def test_decode_rejects_wrong_audience() -> None:
    """A correctly signed token addressed to another audience is refused."""
    issued_at = datetime.now(timezone.utc)
    claims = {
        "iss": settings.jwt_issuer,
        "aud": "other-service",
        "sub": "1",
        "type": "access",
        "ver": 1,
        "iat": issued_at,
        "exp": issued_at + timedelta(minutes=5),
    }
    # Signed with the real key and algorithm: only the audience is wrong.
    forged = pyjwt.encode(claims, settings.secret_key, algorithm=ALGORITHM)
    with pytest.raises(pyjwt.InvalidAudienceError):
        decode_token(forged)
|
||||
|
||||
|
||||
def test_decode_rejects_none_alg() -> None:
    """An unsigned ``alg: none`` token must never be accepted."""
    issued_at = datetime.now(timezone.utc)
    claims = {
        "iss": settings.jwt_issuer,
        "aud": settings.jwt_audience,
        "sub": "1",
        "type": "access",
        "ver": 1,
        "iat": issued_at,
        "exp": issued_at + timedelta(minutes=5),
    }
    # Claims are otherwise valid; the token simply carries no signature.
    forged = pyjwt.encode(claims, "", algorithm="none")
    with pytest.raises(pyjwt.InvalidAlgorithmError):
        decode_token(forged)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Pre-migration snapshot: atomic copy + retention pruning."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from notify_bridge_server.database.snapshot import (
|
||||
prune_old_snapshots,
|
||||
snapshot_and_prune,
|
||||
snapshot_database,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(review): this async generator uses plain ``@pytest.fixture``. That is
# presumably fine if pytest-asyncio runs in auto mode; in strict mode async
# fixtures need ``pytest_asyncio.fixture`` — confirm against the project's
# pytest configuration.
@pytest.fixture
async def sqlite_engine(tmp_path: Path):
    """Yield ``(engine, db_path, backups_dir)`` for a small seeded SQLite DB.

    The database holds one table ``t`` with a single row whose ``v`` is
    ``'seed'``. The engine is disposed after the test finishes.
    """
    database_file = tmp_path / "app.db"
    backup_dir = tmp_path / "backups"
    engine = create_async_engine(f"sqlite+aiosqlite:///{database_file}")
    seed_statements = (
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)",
        "INSERT INTO t (v) VALUES ('seed')",
    )
    async with engine.begin() as conn:
        for statement in seed_statements:
            await conn.execute(text(statement))
    yield engine, database_file, backup_dir
    await engine.dispose()
|
||||
|
||||
|
||||
class TestSnapshot:
    """Tests for ``snapshot_database``."""

    @pytest.mark.asyncio
    async def test_creates_consistent_copy(self, sqlite_engine) -> None:
        """The snapshot file exists and contains the seeded row."""
        engine, _db, backups = sqlite_engine
        dest = await snapshot_database(engine, backups)
        assert dest is not None
        assert dest.exists()
        # Open the snapshot and read the seed row back — proves it is a real
        # DB copy rather than an empty or truncated file.
        copy = create_async_engine(f"sqlite+aiosqlite:///{dest}")
        try:
            async with copy.connect() as c:
                result = await c.execute(text("SELECT v FROM t"))
                rows = result.all()
        finally:
            # Dispose even when connecting or querying fails, so a broken
            # snapshot does not leave the verification engine open. This
            # matches the try/finally in test_skips_when_db_missing.
            await copy.dispose()
        assert rows == [("seed",)]

    @pytest.mark.asyncio
    async def test_skips_when_db_missing(self, tmp_path: Path) -> None:
        """No snapshot is produced when the database file does not exist."""
        # Engine pointing at a path that doesn't exist yet.
        engine = create_async_engine(
            f"sqlite+aiosqlite:///{tmp_path / 'does-not-exist.db'}"
        )
        try:
            dest = await snapshot_database(engine, tmp_path / "backups")
        finally:
            await engine.dispose()
        assert dest is None

    @pytest.mark.asyncio
    async def test_rejects_unsafe_label(self, sqlite_engine) -> None:
        """A label containing SQL metacharacters yields no snapshot."""
        engine, _db, backups = sqlite_engine
        dest = await snapshot_database(engine, backups, label="bad'; DROP TABLE t;--")
        assert dest is None
|
||||
|
||||
|
||||
class TestPrune:
    """Tests for ``prune_old_snapshots``."""

    def _make_snapshot(self, backups: Path, age_seconds: int) -> Path:
        """Create a one-byte ``pre-migrate-<timestamp>.db`` file of a given age.

        Both the timestamp in the file name and the file's atime/mtime are
        set ``age_seconds`` in the past (UTC).
        """
        import os

        backups.mkdir(parents=True, exist_ok=True)
        stamp = datetime.now(timezone.utc) - timedelta(seconds=age_seconds)
        snapshot = backups / f"pre-migrate-{stamp.strftime('%Y-%m-%dT%H-%M-%S')}.db"
        snapshot.write_bytes(b"x")
        epoch = stamp.timestamp()
        os.utime(snapshot, (epoch, epoch))
        return snapshot

    def test_keeps_n_newest(self, tmp_path: Path) -> None:
        """With six snapshots and ``keep=3``, three are deleted and three remain."""
        backups = tmp_path / "backups"
        for age in (100, 80, 60, 40, 20, 0):
            self._make_snapshot(backups, age)

        removed = prune_old_snapshots(backups, keep=3)
        survivors = sorted(backups.glob("pre-migrate-*.db"))
        assert len(removed) == 3
        assert len(survivors) == 3

    def test_keep_zero_deletes_all(self, tmp_path: Path) -> None:
        """``keep=0`` leaves no snapshot files behind."""
        backups = tmp_path / "backups"
        for age in (30, 20, 10):
            self._make_snapshot(backups, age)
        prune_old_snapshots(backups, keep=0)
        leftovers = list(backups.glob("pre-migrate-*.db"))
        assert leftovers == []

    def test_missing_dir_is_noop(self, tmp_path: Path) -> None:
        """Pruning a directory that was never created returns an empty list."""
        absent = tmp_path / "never-created"
        assert prune_old_snapshots(absent, keep=5) == []
|
||||
|
||||
|
||||
class TestSnapshotAndPrune:
    """Tests for the combined ``snapshot_and_prune`` helper."""

    @pytest.mark.asyncio
    async def test_keep_zero_disables(self, sqlite_engine) -> None:
        """``keep=0`` returns ``None`` and leaves no ``.db`` file in backups."""
        engine, _db, backups = sqlite_engine
        outcome = await snapshot_and_prune(engine, backups, keep=0)
        assert outcome is None
        assert not backups.exists() or list(backups.glob("*.db")) == []

    @pytest.mark.asyncio
    async def test_end_to_end(self, sqlite_engine) -> None:
        """Two consecutive runs with ``keep=5`` leave two distinct snapshots."""
        engine, _db, backups = sqlite_engine
        first = await snapshot_and_prune(engine, backups, keep=5)
        # Snapshot names carry a second-resolution timestamp, so wait just
        # over a second to guarantee a different filename.
        await asyncio.sleep(1.05)
        second = await snapshot_and_prune(engine, backups, keep=5)
        assert first
        assert second
        assert first != second
        assert first.exists() and second.exists()
|
||||
@@ -0,0 +1,59 @@
|
||||
"""SSRF guard regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import (
|
||||
UnsafeURLError,
|
||||
avalidate_outbound_url,
|
||||
validate_outbound_url,
|
||||
)
|
||||
|
||||
|
||||
class TestScheme:
    """URL-scheme checks for ``validate_outbound_url``."""

    def test_rejects_file_scheme(self) -> None:
        """``file://`` URLs raise ``UnsafeURLError``."""
        with pytest.raises(UnsafeURLError):
            validate_outbound_url("file:///etc/passwd")

    def test_rejects_gopher(self) -> None:
        """``gopher://`` URLs raise ``UnsafeURLError``."""
        with pytest.raises(UnsafeURLError):
            validate_outbound_url("gopher://example.com/")

    def test_accepts_https(self) -> None:
        """An ``https://`` URL to a public host passes validation.

        Validation resolves the host via real DNS, so the test is skipped
        when the raised error mentions DNS (offline environment).
        """
        try:
            validate_outbound_url("https://example.com/")
        except UnsafeURLError as err:
            if "DNS" not in str(err):
                raise
            pytest.skip("No DNS in test environment")
|
||||
|
||||
|
||||
class TestBlockedRanges:
    """Literal private, loopback and link-local addresses must be refused."""

    @pytest.mark.parametrize(
        "url",
        [
            "http://127.0.0.1/",
            "http://10.0.0.1/",
            "http://192.168.1.1/",
            "http://169.254.169.254/latest/meta-data/",
            "http://[::1]/",
        ],
    )
    def test_rejects_literal_private(self, url: str) -> None:
        """Each listed URL raises ``UnsafeURLError``."""
        with pytest.raises(UnsafeURLError):
            validate_outbound_url(url)
|
||||
|
||||
|
||||
class TestAsyncValidator:
    """The async validator must refuse the same inputs as the sync one."""

    @pytest.mark.asyncio
    async def test_async_rejects_loopback(self) -> None:
        """A literal loopback address raises ``UnsafeURLError``."""
        loopback = "http://127.0.0.1/"
        with pytest.raises(UnsafeURLError):
            await avalidate_outbound_url(loopback)

    @pytest.mark.asyncio
    async def test_async_rejects_bad_scheme(self) -> None:
        """A ``file://`` URL raises ``UnsafeURLError``."""
        local_file = "file:///etc/passwd"
        with pytest.raises(UnsafeURLError):
            await avalidate_outbound_url(local_file)
|
||||
@@ -24,7 +24,7 @@ fi
|
||||
|
||||
# Start backend
|
||||
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
||||
export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
|
||||
export NOTIFY_BRIDGE_SECRET_KEY=dev-only-pwIOUsKmfn4CYWQ9hCRs5lmI3GgrVlXSu2nqFzGW
|
||||
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
|
||||
# guard rejects private addresses by default, which would make trackers fail.
|
||||
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
|
||||
Reference in New Issue
Block a user