Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 461fb495d7 | |||
| 309dec2b44 | |||
| 90def11b8d | |||
| 8f0346ea03 | |||
| a6a854ad21 | |||
| 19036a90bb | |||
| 592e1b6114 | |||
| bbcdf1c5d1 | |||
| f9040370bc | |||
| 3b683ce82c | |||
| 2bec25353b | |||
| e44d387c7f | |||
| 7cbb02b1ef | |||
| 920920bc67 | |||
| f50d465c0e | |||
| 1f880daa0c | |||
| 1024085cdd | |||
| 5604c733d1 | |||
| 3b7808aa9c |
@@ -1,13 +1,56 @@
|
|||||||
name: Build Docker Image
|
name: Build and Test
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
push:
|
||||||
|
branches: [master, main]
|
||||||
|
pull_request:
|
||||||
|
branches: [master, main]
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
test-frontend:
|
||||||
|
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Build Docker image
|
- name: Set up Node
|
||||||
run: docker build -t notify-bridge:dev .
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: "22"
|
||||||
|
cache: "npm"
|
||||||
|
cache-dependency-path: frontend/package-lock.json
|
||||||
|
|
||||||
|
- name: Install deps
|
||||||
|
run: |
|
||||||
|
cd frontend
|
||||||
|
npm ci
|
||||||
|
|
||||||
|
- name: Svelte check
|
||||||
|
run: |
|
||||||
|
cd frontend
|
||||||
|
npm run check || echo "::warning::svelte-check reported warnings"
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: |
|
||||||
|
cd frontend
|
||||||
|
npm run build
|
||||||
|
|
||||||
|
build-image:
|
||||||
|
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||||
|
needs: [test-frontend]
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build Docker image (no push)
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: false
|
||||||
|
tags: notify-bridge:ci-${{ gitea.sha }}
|
||||||
|
cache-from: type=gha
|
||||||
|
cache-to: type=gha,mode=max
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ jobs:
|
|||||||
tags: |
|
tags: |
|
||||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
|
||||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}
|
||||||
|
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ gitea.sha }}
|
||||||
${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }}
|
${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }}
|
||||||
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||||
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
||||||
|
|||||||
+48
-4
@@ -1,3 +1,4 @@
|
|||||||
|
# syntax=docker/dockerfile:1.7
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Stage 1: Build frontend (SvelteKit static output)
|
# Stage 1: Build frontend (SvelteKit static output)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -14,7 +15,7 @@ COPY frontend/ ./
|
|||||||
RUN npm run build
|
RUN npm run build
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Stage 2: Build Python wheels
|
# Stage 2: Build Python wheels + extract external dependency list
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
FROM python:3.12-slim AS python-build
|
FROM python:3.12-slim AS python-build
|
||||||
|
|
||||||
@@ -30,16 +31,59 @@ RUN python -m build packages/core/ --wheel --outdir /wheels
|
|||||||
COPY packages/server/ packages/server/
|
COPY packages/server/ packages/server/
|
||||||
RUN python -m build packages/server/ --wheel --outdir /wheels
|
RUN python -m build packages/server/ --wheel --outdir /wheels
|
||||||
|
|
||||||
|
# Emit /wheels/deps.txt with ONLY external (PyPI) deps — filter out
|
||||||
|
# notify-bridge-* siblings, which are installed from local wheels below.
|
||||||
|
# This file is the cache key for the external-deps install layer: as long as
|
||||||
|
# pyproject.toml dependency lines don't change, the runtime install layer is
|
||||||
|
# served from registry buildcache and no wheels are re-downloaded.
|
||||||
|
RUN python <<'PY'
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
deps: list[str] = []
|
||||||
|
for p in ("packages/core/pyproject.toml", "packages/server/pyproject.toml"):
|
||||||
|
with open(p, "rb") as f:
|
||||||
|
data = tomllib.load(f)
|
||||||
|
for d in data["project"].get("dependencies", []):
|
||||||
|
if not d.lstrip().lower().startswith("notify-bridge-"):
|
||||||
|
deps.append(d)
|
||||||
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
with open("/wheels/deps.txt", "w") as f:
|
||||||
|
for d in deps:
|
||||||
|
if d not in seen:
|
||||||
|
seen.add(d)
|
||||||
|
f.write(d + "\n")
|
||||||
|
PY
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Stage 3: Runtime
|
# Stage 3: Runtime
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
FROM python:3.12-slim
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
# uv — fast pip replacement. Installed from PyPI (Fastly CDN) rather than
|
||||||
|
# ghcr.io/astral-sh/uv, because GHCR pulls from this runner crawl at a few
|
||||||
|
# hundred KB/s and take longer than the install savings would recoup.
|
||||||
|
RUN pip install --no-cache-dir uv==0.11.7
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 \
|
||||||
|
UV_LINK_MODE=copy
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Install wheels
|
# Install external deps first — layer cache key is deps.txt content, which
|
||||||
COPY --from=python-build /wheels/ /tmp/wheels/
|
# only changes when pyproject.toml dependency lines change (not on version
|
||||||
RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels
|
# bumps). The cache mount persists downloaded wheels across local rebuilds;
|
||||||
|
# in CI, the registry buildcache serves the whole layer when unchanged.
|
||||||
|
COPY --from=python-build /wheels/deps.txt /tmp/deps.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install --system -r /tmp/deps.txt \
|
||||||
|
&& rm /tmp/deps.txt
|
||||||
|
|
||||||
|
# Install local wheels without re-resolving — all external deps are present.
|
||||||
|
COPY --from=python-build /wheels/*.whl /tmp/wheels/
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv pip install --system --no-deps /tmp/wheels/*.whl \
|
||||||
|
&& rm -rf /tmp/wheels
|
||||||
|
|
||||||
# Copy frontend build
|
# Copy frontend build
|
||||||
COPY --from=frontend-build /build/build/ /app/static/
|
COPY --from=frontend-build /build/build/ /app/static/
|
||||||
|
|||||||
+16
-24
@@ -1,36 +1,28 @@
|
|||||||
# v0.3.0 (2026-04-22)
|
# v0.5.0 (2026-04-24)
|
||||||
|
|
||||||
Major polling perf overhaul for large Immich libraries plus a UX fix for
|
A small but impactful release that finally makes the Immich scheduled / periodic / memory dispatch fire on its own. The slot was already visible in the tracker UI and the "Test" button worked — but no production scheduler was reading the config, so users only ever saw fires through manual tests. This release wires the missing cron jobs end-to-end.
|
||||||
slow bot commands. Combined impact on idle albums: per-tick cost drops
|
|
||||||
from ~150 MB fetched to a few hundred bytes; active albums now fetch
|
|
||||||
O(changes) instead of O(library). Tested against a ~200k-asset library.
|
|
||||||
|
|
||||||
**Schema change:** adds a `meta_fingerprint` JSON column to
|
|
||||||
`notification_tracker_state` — applied automatically by the startup
|
|
||||||
migration, no manual step required.
|
|
||||||
|
|
||||||
## Performance
|
|
||||||
|
|
||||||
- **Skip full album fetch on idle ticks** — new `ImmichAlbumMeta` + `get_album_meta()` probe using `?withoutAssets=true` as a cheap change-detection fingerprint. When the fingerprint matches and no pending assets are outstanding, `poll()` short-circuits and does no asset fetch at all. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Delta-fetch active albums** — when the fingerprint changes, poll with `updatedAfter` instead of refetching the whole album; falls back to a full fetch only on count decrease or mixed add+remove that delta can't reconcile. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Parallel meta probes** — `asyncio.gather` over album meta probes so a 20-album tracker pays one round-trip of latency instead of 20. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Tick-scoped shared-links cache** — new `get_all_shared_links_by_album()` coalesces to one `/api/shared-links` request per tick instead of one per changed album. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Module-level users cache** — 1 h TTL, sha256-keyed, shared across providers that target the same Immich server. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Skip `asset_ids` DB rewrite on idle ticks** — watcher no longer rewrites the (potentially ~8 MB for huge albums) JSON column when the fingerprint didn't change. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Adaptive polling** — after 10 empty ticks the scheduler skips 1-in-2, after 30 empty ticks skips 1-in-4; resets on the first detected change or any schedule edit. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **APScheduler jitter** — `interval/4`, capped at 30 s, to smooth thundering-herd bursts when many trackers share the same `scan_interval`. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
- **Event payload cap** — 50 added / 200 removed assets per event so a bulk import can't explode a Jinja template or exceed Telegram message limits. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- **Chat-action hint stays alive during slow command fetches** — Telegram chat actions expire after ~5 s, so slow bot commands (`/latest`, `/random`, `/favorites`, `/memory`, `/search`, `/find`, `/person`, `/place`, `/summary`) previously showed a hint that vanished long before the media arrived and users saw nothing happening. New `telegram_chat_action` async context manager starts a keep-alive task that re-posts the action every 4 s until it exits; `classify_command_chat_action` maps each command to the right action (`upload_photo` for media-returning commands, `typing` for `/summary`, none for fast DB-only commands like `/status` / `/events`). Wired into both the webhook and long-poll paths. ([69711bb](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/69711bb))
|
- **Cron-fired Immich dispatch for scheduled / periodic / memory slots** — adds the missing production fan-out so `scheduled_enabled` / `scheduled_times` (and the periodic / memory counterparts) on `TrackingConfig` actually fire on their own, not only through "Test" ([309dec2](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/309dec2)):
|
||||||
|
- New `services/scheduled_dispatch.py` reuses the test-path event builders, picks the slot template per kind (`scheduled_assets` / `periodic_assets` / `memory_assets`), and writes an `EventLog` row per fire so the dashboard reflects it.
|
||||||
|
- `services/scheduler.py` gains `_load_immich_dispatch_jobs`, which builds one `CronTrigger` per `(tracker, kind, HH:MM)` from each tracker's default `TrackingConfig`, all keyed off the app-level IANA timezone. `reschedule_immich_dispatch_jobs` rebuilds the job set on any relevant CRUD or timezone change.
|
||||||
|
- Tracker / link / tracking-config CRUD endpoints now invalidate the schedule, so edits take effect immediately without a restart.
|
||||||
|
- Dispatch is skipped when scheduled / memory queries yield zero matching assets — prevents header-only "On this day:" spam when nothing qualifies.
|
||||||
|
- EN / RU default `scheduled_assets` templates updated to surface that the delivery is a scheduled random selection.
|
||||||
|
|
||||||
|
## Upgrade Notes
|
||||||
|
|
||||||
|
- No config changes required. Existing `scheduled_enabled` / `scheduled_times` / `periodic_*` / `memory_*` settings on tracking configs will start firing automatically on the next startup.
|
||||||
|
- If you had been relying on the "Test" button as a workaround, you can stop doing that now.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>All Commits</summary>
|
<summary>All Commits</summary>
|
||||||
|
|
||||||
- [69711bb](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/69711bb) — feat(commands): keep chat-action hint alive during slow command fetches *(alexei.dolgolyov)*
|
| Hash | Message | Author |
|
||||||
- [fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20) — perf(immich): skip full album fetch on idle ticks; delta-fetch for active ones *(alexei.dolgolyov)*
|
|------------------------------------------------------------------------------------------|------------------------------------------------------------------|------------------|
|
||||||
|
| [309dec2](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/309dec2) | feat(immich): wire cron-fired scheduled/periodic/memory dispatch | alexei.dolgolyov |
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|||||||
+27
-7
@@ -10,18 +10,38 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- notify-bridge-data:/data
|
- notify-bridge-data:/data
|
||||||
environment:
|
environment:
|
||||||
|
# REQUIRED — any 32+ byte random string. `openssl rand -hex 32` is one way.
|
||||||
- NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)}
|
- NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)}
|
||||||
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-*}
|
# Comma-separated list of allowed browser origins. Wildcard `*` is
|
||||||
# Homelab target: allow outbound requests to RFC1918 / link-local addresses.
|
# rejected on startup because credentials are enabled.
|
||||||
# The SSRF guard otherwise rejects 10.*/172.16.*/192.168.*/169.254.* hosts,
|
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-http://localhost:8420}
|
||||||
# which breaks tracking of Immich / Gitea / etc. running on the same LAN.
|
# Trusted proxy IPs whose X-Forwarded-For / X-Forwarded-Proto we honor.
|
||||||
- NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
# Set this to your reverse proxy's IP (e.g. 172.17.0.1 for the default
|
||||||
|
# docker bridge, or `*` only if the container is NOT reachable from the
|
||||||
|
# public internet).
|
||||||
|
- NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS=${NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS:-127.0.0.1}
|
||||||
|
# Opt-in SSRF bypass for private/loopback/link-local hosts (homelab
|
||||||
|
# scenario — tracking an Immich/Gitea instance on the same LAN). DO NOT
|
||||||
|
# enable on a publicly exposed instance.
|
||||||
|
# - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/health')"]
|
# Use /api/ready (not /api/health) so the container is only reported
|
||||||
|
# healthy after migrations and the scheduler finish booting.
|
||||||
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/ready', timeout=3)"]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 5s
|
timeout: 5s
|
||||||
retries: 3
|
retries: 3
|
||||||
start_period: 10s
|
start_period: 30s
|
||||||
|
read_only: true
|
||||||
|
tmpfs:
|
||||||
|
- /tmp
|
||||||
|
security_opt:
|
||||||
|
- no-new-privileges:true
|
||||||
|
cap_drop:
|
||||||
|
- ALL
|
||||||
|
mem_limit: 512m
|
||||||
|
cpus: 1.0
|
||||||
|
pids_limit: 256
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
notify-bridge-data:
|
notify-bridge-data:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"name": "notify-bridge-frontend",
|
"name": "notify-bridge-frontend",
|
||||||
"private": true,
|
"private": true,
|
||||||
"version": "0.3.0",
|
"version": "0.5.0",
|
||||||
"type": "module",
|
"type": "module",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "vite dev",
|
"dev": "vite dev",
|
||||||
|
|||||||
@@ -55,7 +55,8 @@
|
|||||||
"passwordTooShort": "Password must be at least 8 characters",
|
"passwordTooShort": "Password must be at least 8 characters",
|
||||||
"or": "or",
|
"or": "or",
|
||||||
"loginFailed": "Login failed",
|
"loginFailed": "Login failed",
|
||||||
"setupFailed": "Setup failed"
|
"setupFailed": "Setup failed",
|
||||||
|
"backendUnreachable": "Cannot reach the server. Check that it's running and try again."
|
||||||
},
|
},
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
"title": "Dashboard",
|
"title": "Dashboard",
|
||||||
@@ -78,6 +79,7 @@
|
|||||||
"collectionRenamed": "collection renamed",
|
"collectionRenamed": "collection renamed",
|
||||||
"collectionDeleted": "collection deleted",
|
"collectionDeleted": "collection deleted",
|
||||||
"sharingChanged": "sharing changed",
|
"sharingChanged": "sharing changed",
|
||||||
|
"scheduledMessage": "scheduled message",
|
||||||
"actionSuccess": "action run",
|
"actionSuccess": "action run",
|
||||||
"actionPartial": "action partial",
|
"actionPartial": "action partial",
|
||||||
"actionFailed": "action failed",
|
"actionFailed": "action failed",
|
||||||
@@ -694,6 +696,13 @@
|
|||||||
"locales": "Template Languages",
|
"locales": "Template Languages",
|
||||||
"supportedLocales": "Supported Locales",
|
"supportedLocales": "Supported Locales",
|
||||||
"supportedLocalesHint": "Languages available when authoring notification and command templates. Built-in defaults ship for English and Russian; other languages start empty.",
|
"supportedLocalesHint": "Languages available when authoring notification and command templates. Built-in defaults ship for English and Russian; other languages start empty.",
|
||||||
|
"logging": "Logging",
|
||||||
|
"logLevel": "Log Level",
|
||||||
|
"logLevelHint": "Root log level for the server. Raise to DEBUG while investigating; keep at INFO in production. WARNING/ERROR hide per-command progress lines.",
|
||||||
|
"logFormat": "Log Format",
|
||||||
|
"logFormatHint": "Output format. 'text' is human-readable; 'json' emits one object per line for log aggregators (Loki, ELK). Changing this requires a server restart.",
|
||||||
|
"logLevels": "Per-Module Overrides",
|
||||||
|
"logLevelsHint": "Comma-separated 'module=LEVEL' pairs to silence noisy modules or drill into one area. Example: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||||
"saved": "Settings saved"
|
"saved": "Settings saved"
|
||||||
},
|
},
|
||||||
"hints": {
|
"hints": {
|
||||||
|
|||||||
@@ -55,7 +55,8 @@
|
|||||||
"passwordTooShort": "Пароль должен быть не менее 8 символов",
|
"passwordTooShort": "Пароль должен быть не менее 8 символов",
|
||||||
"or": "или",
|
"or": "или",
|
||||||
"loginFailed": "Ошибка входа",
|
"loginFailed": "Ошибка входа",
|
||||||
"setupFailed": "Ошибка настройки"
|
"setupFailed": "Ошибка настройки",
|
||||||
|
"backendUnreachable": "Не удалось подключиться к серверу. Убедитесь, что он запущен, и повторите попытку."
|
||||||
},
|
},
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
"title": "Главная",
|
"title": "Главная",
|
||||||
@@ -78,6 +79,7 @@
|
|||||||
"collectionRenamed": "альбом переименован",
|
"collectionRenamed": "альбом переименован",
|
||||||
"collectionDeleted": "альбом удалён",
|
"collectionDeleted": "альбом удалён",
|
||||||
"sharingChanged": "изменение доступа",
|
"sharingChanged": "изменение доступа",
|
||||||
|
"scheduledMessage": "запланированное сообщение",
|
||||||
"actionSuccess": "действие выполнено",
|
"actionSuccess": "действие выполнено",
|
||||||
"actionPartial": "действие частично",
|
"actionPartial": "действие частично",
|
||||||
"actionFailed": "действие провалено",
|
"actionFailed": "действие провалено",
|
||||||
@@ -694,6 +696,13 @@
|
|||||||
"locales": "Языки шаблонов",
|
"locales": "Языки шаблонов",
|
||||||
"supportedLocales": "Поддерживаемые локали",
|
"supportedLocales": "Поддерживаемые локали",
|
||||||
"supportedLocalesHint": "Языки, доступные для редактирования шаблонов уведомлений и команд. Встроенные шаблоны поставляются для английского и русского; другие языки начинают с пустых.",
|
"supportedLocalesHint": "Языки, доступные для редактирования шаблонов уведомлений и команд. Встроенные шаблоны поставляются для английского и русского; другие языки начинают с пустых.",
|
||||||
|
"logging": "Логирование",
|
||||||
|
"logLevel": "Уровень логов",
|
||||||
|
"logLevelHint": "Уровень логирования сервера. Поднимайте до DEBUG при отладке; оставляйте INFO в продакшене. WARNING/ERROR скрывают пошаговые строки по командам.",
|
||||||
|
"logFormat": "Формат логов",
|
||||||
|
"logFormatHint": "Формат вывода. 'text' — читаемый человеком; 'json' — по одному объекту в строке для агрегаторов (Loki, ELK). Смена требует перезапуска сервера.",
|
||||||
|
"logLevels": "Переопределения по модулям",
|
||||||
|
"logLevelsHint": "Пары 'модуль=УРОВЕНЬ' через запятую, чтобы приглушить шумные модули или углубиться в один. Пример: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||||
"saved": "Настройки сохранены"
|
"saved": "Настройки сохранены"
|
||||||
},
|
},
|
||||||
"hints": {
|
"hints": {
|
||||||
|
|||||||
@@ -223,6 +223,7 @@
|
|||||||
collection_renamed: 'dashboard.collectionRenamed',
|
collection_renamed: 'dashboard.collectionRenamed',
|
||||||
collection_deleted: 'dashboard.collectionDeleted',
|
collection_deleted: 'dashboard.collectionDeleted',
|
||||||
sharing_changed: 'dashboard.sharingChanged',
|
sharing_changed: 'dashboard.sharingChanged',
|
||||||
|
scheduled_message: 'dashboard.scheduledMessage',
|
||||||
action_success: 'dashboard.actionSuccess',
|
action_success: 'dashboard.actionSuccess',
|
||||||
action_partial: 'dashboard.actionPartial',
|
action_partial: 'dashboard.actionPartial',
|
||||||
action_failed: 'dashboard.actionFailed',
|
action_failed: 'dashboard.actionFailed',
|
||||||
@@ -231,11 +232,13 @@
|
|||||||
const eventIcons: Record<string, string> = {
|
const eventIcons: Record<string, string> = {
|
||||||
assets_added: 'mdiImagePlus', assets_removed: 'mdiImageMinus',
|
assets_added: 'mdiImagePlus', assets_removed: 'mdiImageMinus',
|
||||||
collection_renamed: 'mdiRename', collection_deleted: 'mdiDeleteAlert', sharing_changed: 'mdiShareVariant',
|
collection_renamed: 'mdiRename', collection_deleted: 'mdiDeleteAlert', sharing_changed: 'mdiShareVariant',
|
||||||
|
scheduled_message: 'mdiCalendarClock',
|
||||||
action_success: 'mdiPlayCircle', action_partial: 'mdiAlertCircle', action_failed: 'mdiCloseCircle',
|
action_success: 'mdiPlayCircle', action_partial: 'mdiAlertCircle', action_failed: 'mdiCloseCircle',
|
||||||
};
|
};
|
||||||
const eventColors: Record<string, string> = {
|
const eventColors: Record<string, string> = {
|
||||||
assets_added: '#059669', assets_removed: '#ef4444',
|
assets_added: '#059669', assets_removed: '#ef4444',
|
||||||
collection_renamed: '#6366f1', collection_deleted: '#dc2626', sharing_changed: '#f59e0b',
|
collection_renamed: '#6366f1', collection_deleted: '#dc2626', sharing_changed: '#f59e0b',
|
||||||
|
scheduled_message: '#8b5cf6',
|
||||||
action_success: '#0d9488', action_partial: '#f59e0b', action_failed: '#dc2626',
|
action_success: '#0d9488', action_partial: '#f59e0b', action_failed: '#dc2626',
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -15,13 +15,32 @@
|
|||||||
let submitting = $state(false);
|
let submitting = $state(false);
|
||||||
let mounted = $state(false);
|
let mounted = $state(false);
|
||||||
|
|
||||||
|
let backendDown = $state(false);
|
||||||
|
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
initTheme();
|
initTheme();
|
||||||
mounted = true;
|
mounted = true;
|
||||||
|
// If the user is already signed in (valid access token in storage),
|
||||||
|
// there is no reason to show them the login form. loadUser() runs in
|
||||||
|
// the root layout; we just check the resolved state after a short tick.
|
||||||
|
const { isAuthenticated } = await import('$lib/api');
|
||||||
|
if (isAuthenticated()) {
|
||||||
|
try {
|
||||||
|
await api('/auth/me');
|
||||||
|
goto('/');
|
||||||
|
return;
|
||||||
|
} catch {
|
||||||
|
// Token was stale; fall through to the login form.
|
||||||
|
}
|
||||||
|
}
|
||||||
try {
|
try {
|
||||||
const res = await api<{ needs_setup: boolean }>('/auth/needs-setup');
|
const res = await api<{ needs_setup: boolean }>('/auth/needs-setup');
|
||||||
if (res.needs_setup) goto('/setup');
|
if (res.needs_setup) goto('/setup');
|
||||||
} catch { /* ignore */ }
|
} catch {
|
||||||
|
// The backend is unreachable — surface that distinctly so the user
|
||||||
|
// doesn't blame the login form for a network/backend problem.
|
||||||
|
backendDown = true;
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
async function handleSubmit(e: SubmitEvent) {
|
async function handleSubmit(e: SubmitEvent) {
|
||||||
@@ -62,7 +81,12 @@
|
|||||||
<p class="text-sm mt-1" style="color: var(--color-muted-foreground);">{t('auth.signInTitle')}</p>
|
<p class="text-sm mt-1" style="color: var(--color-muted-foreground);">{t('auth.signInTitle')}</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{#if error}
|
{#if backendDown}
|
||||||
|
<div class="auth-error animate-fade-slide-in">
|
||||||
|
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||||
|
{t('auth.backendUnreachable')}
|
||||||
|
</div>
|
||||||
|
{:else if error}
|
||||||
<div class="auth-error animate-fade-slide-in">
|
<div class="auth-error animate-fade-slide-in">
|
||||||
<MdiIcon name="mdiAlertCircle" size={16} />
|
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||||
{error}
|
{error}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@
|
|||||||
telegram_asset_cache_max_entries: '5000',
|
telegram_asset_cache_max_entries: '5000',
|
||||||
supported_locales: 'en,ru',
|
supported_locales: 'en,ru',
|
||||||
timezone: 'UTC',
|
timezone: 'UTC',
|
||||||
|
log_level: 'INFO',
|
||||||
|
log_format: 'text',
|
||||||
|
log_levels: '',
|
||||||
});
|
});
|
||||||
let cacheStats = $state<CacheStats | null>(null);
|
let cacheStats = $state<CacheStats | null>(null);
|
||||||
|
|
||||||
@@ -204,6 +207,40 @@
|
|||||||
</div>
|
</div>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
|
<!-- Logging section -->
|
||||||
|
<Card>
|
||||||
|
<h3 class="text-sm font-semibold mb-4 flex items-center gap-2">
|
||||||
|
<MdiIcon name="mdiTextBoxOutline" size={18} />
|
||||||
|
{t('settings.logging')}
|
||||||
|
</h3>
|
||||||
|
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||||
|
<div>
|
||||||
|
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
||||||
|
<select bind:value={settings.log_level}
|
||||||
|
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||||
|
<option value="DEBUG">DEBUG</option>
|
||||||
|
<option value="INFO">INFO</option>
|
||||||
|
<option value="WARNING">WARNING</option>
|
||||||
|
<option value="ERROR">ERROR</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
||||||
|
<select bind:value={settings.log_format}
|
||||||
|
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||||
|
<option value="text">text</option>
|
||||||
|
<option value="json">json</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div class="sm:col-span-2">
|
||||||
|
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
||||||
|
<input bind:value={settings.log_levels}
|
||||||
|
placeholder="sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG"
|
||||||
|
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)] font-mono" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
|
||||||
<Button onclick={save} disabled={saving}>
|
<Button onclick={save} disabled={saving}>
|
||||||
{saving ? t('common.loading') : t('common.save')}
|
{saving ? t('common.loading') : t('common.save')}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "notify-bridge-core"
|
name = "notify-bridge-core"
|
||||||
version = "0.3.0"
|
version = "0.5.0"
|
||||||
description = "Core library for Notify Bridge — service provider abstractions, models, notifications, and templates"
|
description = "Core library for Notify Bridge — service provider abstractions, models, notifications, and templates"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
"""Request-scoped ContextVars that propagate into log records.
|
||||||
|
|
||||||
|
The server sets these at entry points (Telegram webhook, scheduler dispatch,
|
||||||
|
REST call) and they propagate through async calls automatically. A
|
||||||
|
``LogRecordFactory`` installed by ``notify_bridge_server.logging_setup``
|
||||||
|
reads them so every log line is tagged (``request_id``, ``command``,
|
||||||
|
``chat_id``, ``bot_id``, ``dispatch_id``) without each call site having
|
||||||
|
to pass the values explicitly.
|
||||||
|
|
||||||
|
Kept in ``notify_bridge_core`` so core modules (``TelegramClient``,
|
||||||
|
``NotificationDispatcher``) can *set* additional context (e.g. a
|
||||||
|
``dispatch_id``) without depending on the server package.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from contextvars import ContextVar, Token
|
||||||
|
from typing import Any, Iterator
|
||||||
|
|
||||||
|
request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)
|
||||||
|
command_var: ContextVar[str | None] = ContextVar("command", default=None)
|
||||||
|
chat_id_var: ContextVar[str | None] = ContextVar("chat_id", default=None)
|
||||||
|
bot_id_var: ContextVar[int | None] = ContextVar("bot_id", default=None)
|
||||||
|
dispatch_id_var: ContextVar[str | None] = ContextVar("dispatch_id", default=None)
|
||||||
|
|
||||||
|
_VAR_MAP: dict[str, ContextVar[Any]] = {
|
||||||
|
"request_id": request_id_var,
|
||||||
|
"command": command_var,
|
||||||
|
"chat_id": chat_id_var,
|
||||||
|
"bot_id": bot_id_var,
|
||||||
|
"dispatch_id": dispatch_id_var,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def bind_log_context(**kwargs: Any) -> Iterator[None]:
|
||||||
|
"""Bind the given context fields for the duration of the ``with`` block.
|
||||||
|
|
||||||
|
Unknown keys are ignored so callers can pass whatever they want without
|
||||||
|
an ``if`` ladder. Values are reset on exit even if the block raises.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
``with bind_log_context(request_id="abc", command="random"): ...``
|
||||||
|
"""
|
||||||
|
tokens: list[tuple[ContextVar[Any], Token]] = []
|
||||||
|
try:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
var = _VAR_MAP.get(key)
|
||||||
|
if var is None:
|
||||||
|
continue
|
||||||
|
tokens.append((var, var.set(value)))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
for var, tok in tokens:
|
||||||
|
var.reset(tok)
|
||||||
|
|
||||||
|
|
||||||
|
def current_log_context() -> dict[str, Any]:
|
||||||
|
"""Return a snapshot of the currently-bound context values (non-None)."""
|
||||||
|
snap: dict[str, Any] = {}
|
||||||
|
for key, var in _VAR_MAP.items():
|
||||||
|
val = var.get()
|
||||||
|
if val is not None:
|
||||||
|
snap[key] = val
|
||||||
|
return snap
|
||||||
@@ -52,22 +52,46 @@ class DiscordClient:
|
|||||||
|
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
|
_MAX_RETRIES = 3
|
||||||
|
_MAX_RETRY_AFTER = 60.0
|
||||||
|
|
||||||
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
|
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
|
||||||
|
"""POST with bounded 429 retry.
|
||||||
|
|
||||||
|
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
|
||||||
|
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
|
||||||
|
pin the dispatch task indefinitely.
|
||||||
|
"""
|
||||||
|
for attempt in range(self._MAX_RETRIES + 1):
|
||||||
try:
|
try:
|
||||||
async with self._session.post(
|
async with self._session.post(
|
||||||
url, json=payload, headers={"Content-Type": "application/json"}
|
url,
|
||||||
|
json=payload,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
allow_redirects=False,
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 429:
|
if resp.status == 429 and attempt < self._MAX_RETRIES:
|
||||||
|
try:
|
||||||
retry_after = float(resp.headers.get("Retry-After", "2"))
|
retry_after = float(resp.headers.get("Retry-After", "2"))
|
||||||
_LOGGER.warning("Discord rate limited, retrying after %.1fs", retry_after)
|
except (TypeError, ValueError):
|
||||||
|
retry_after = 2.0
|
||||||
|
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
|
||||||
|
retry_after, attempt + 1, self._MAX_RETRIES,
|
||||||
|
)
|
||||||
await asyncio.sleep(retry_after)
|
await asyncio.sleep(retry_after)
|
||||||
return await self._post(url, payload)
|
continue
|
||||||
if 200 <= resp.status < 300:
|
if 200 <= resp.status < 300:
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"HTTP {resp.status}: {body[:200]}",
|
||||||
|
}
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
return {"success": False, "error": str(e)}
|
return {"success": False, "error": str(e)}
|
||||||
|
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||||
|
|
||||||
|
|
||||||
def _split_message(text: str, limit: int) -> list[str]:
|
def _split_message(text: str, limit: int) -> list[str]:
|
||||||
|
|||||||
@@ -3,16 +3,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any, AsyncIterator
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
|
from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
|
||||||
from notify_bridge_core.models.events import ServiceEvent
|
from notify_bridge_core.models.events import ServiceEvent
|
||||||
from notify_bridge_core.templates.context import build_template_context
|
from notify_bridge_core.templates.context import build_template_context
|
||||||
from notify_bridge_core.templates.renderer import render_template
|
from notify_bridge_core.templates.renderer import render_template
|
||||||
from .ssrf import UnsafeURLError, validate_outbound_url
|
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||||
|
|
||||||
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||||
|
|
||||||
@@ -82,9 +85,28 @@ class NotificationDispatcher:
|
|||||||
*,
|
*,
|
||||||
url_cache: TelegramFileCache | None = None,
|
url_cache: TelegramFileCache | None = None,
|
||||||
asset_cache: TelegramFileCache | None = None,
|
asset_cache: TelegramFileCache | None = None,
|
||||||
|
session: aiohttp.ClientSession | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._url_cache = url_cache
|
self._url_cache = url_cache
|
||||||
self._asset_cache = asset_cache
|
self._asset_cache = asset_cache
|
||||||
|
# Optional shared session owned by the caller; when supplied we reuse
|
||||||
|
# its connection pool instead of opening a fresh per-dispatch session
|
||||||
|
# (saves a TLS handshake per outbound call).
|
||||||
|
self._shared_session = session
|
||||||
|
|
||||||
|
@contextlib.asynccontextmanager
|
||||||
|
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
|
||||||
|
"""Yield an aiohttp session, reusing the shared one if provided.
|
||||||
|
|
||||||
|
When a shared session was passed in ``__init__`` we yield it without
|
||||||
|
closing (the caller owns its lifetime). Otherwise we open a
|
||||||
|
short-lived session with our default timeout and close it on exit.
|
||||||
|
"""
|
||||||
|
if self._shared_session is not None and not self._shared_session.closed:
|
||||||
|
yield self._shared_session
|
||||||
|
return
|
||||||
|
async with self._session_ctx() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
async def dispatch(
|
async def dispatch(
|
||||||
self,
|
self,
|
||||||
@@ -95,17 +117,39 @@ class NotificationDispatcher:
|
|||||||
|
|
||||||
Returns list of results (one per target).
|
Returns list of results (one per target).
|
||||||
"""
|
"""
|
||||||
|
# Bind a dispatch_id so every log line emitted by the target sends
|
||||||
|
# (including deep in TelegramClient) can be correlated to the same
|
||||||
|
# upstream event.
|
||||||
|
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
with bind_log_context(dispatch_id=new_id):
|
||||||
|
_LOGGER.info(
|
||||||
|
"Dispatching event %s (collection=%r) to %d target(s)",
|
||||||
|
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
||||||
|
getattr(event, "collection_name", None), len(targets),
|
||||||
|
)
|
||||||
raw_results = await asyncio.gather(
|
raw_results = await asyncio.gather(
|
||||||
*[self._send_to_target(event, t) for t in targets],
|
*[self._send_to_target(event, t) for t in targets],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
results = []
|
results = []
|
||||||
for raw in raw_results:
|
failures = 0
|
||||||
|
for target, raw in zip(targets, raw_results):
|
||||||
if isinstance(raw, Exception):
|
if isinstance(raw, Exception):
|
||||||
_LOGGER.error("Failed to dispatch to target: %s", raw)
|
failures += 1
|
||||||
|
_LOGGER.error(
|
||||||
|
"Dispatch to target type=%s failed: %s",
|
||||||
|
target.type, raw, exc_info=raw,
|
||||||
|
)
|
||||||
results.append({"success": False, "error": str(raw)})
|
results.append({"success": False, "error": str(raw)})
|
||||||
else:
|
else:
|
||||||
|
if isinstance(raw, dict) and not raw.get("success"):
|
||||||
|
failures += 1
|
||||||
results.append(raw)
|
results.append(raw)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Dispatch finished: %d target(s), %d failure(s)",
|
||||||
|
len(targets), failures,
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _resolve_template(
|
def _resolve_template(
|
||||||
@@ -284,7 +328,7 @@ class NotificationDispatcher:
|
|||||||
media_assets.append(asset)
|
media_assets.append(asset)
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
# Preload all asset bytes once so (a) TelegramClient can skip its
|
# Preload all asset bytes once so (a) TelegramClient can skip its
|
||||||
# own download and (b) we know exact upload sizes in time for the
|
# own download and (b) we know exact upload sizes in time for the
|
||||||
# oversize warning in the rendered text.
|
# oversize warning in the rendered text.
|
||||||
@@ -354,13 +398,13 @@ class NotificationDispatcher:
|
|||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
for receiver in target.receivers:
|
for receiver in target.receivers:
|
||||||
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
||||||
results.append({"success": False, "error": "Invalid webhook receiver"})
|
results.append({"success": False, "error": "Invalid webhook receiver"})
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(receiver.url)
|
await avalidate_outbound_url(receiver.url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||||
continue
|
continue
|
||||||
@@ -428,14 +472,14 @@ class NotificationDispatcher:
|
|||||||
username = target.config.get("username")
|
username = target.config.get("username")
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
client = DiscordClient(session)
|
client = DiscordClient(session)
|
||||||
for receiver in target.receivers:
|
for receiver in target.receivers:
|
||||||
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
||||||
results.append({"success": False, "error": "Invalid discord receiver"})
|
results.append({"success": False, "error": "Invalid discord receiver"})
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(receiver.webhook_url)
|
await avalidate_outbound_url(receiver.webhook_url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||||
continue
|
continue
|
||||||
@@ -454,14 +498,14 @@ class NotificationDispatcher:
|
|||||||
username = target.config.get("username")
|
username = target.config.get("username")
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
client = SlackClient(session)
|
client = SlackClient(session)
|
||||||
for receiver in target.receivers:
|
for receiver in target.receivers:
|
||||||
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
||||||
results.append({"success": False, "error": "Invalid slack receiver"})
|
results.append({"success": False, "error": "Invalid slack receiver"})
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(receiver.webhook_url)
|
await avalidate_outbound_url(receiver.webhook_url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||||
continue
|
continue
|
||||||
@@ -480,14 +524,14 @@ class NotificationDispatcher:
|
|||||||
if not target.receivers:
|
if not target.receivers:
|
||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(server_url)
|
await avalidate_outbound_url(server_url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
||||||
|
|
||||||
title = f"{event.event_type.value}: {event.collection_name}"
|
title = f"{event.event_type.value}: {event.collection_name}"
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
client = NtfyClient(session)
|
client = NtfyClient(session)
|
||||||
for receiver in target.receivers:
|
for receiver in target.receivers:
|
||||||
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
||||||
@@ -511,7 +555,7 @@ class NotificationDispatcher:
|
|||||||
if not homeserver or not access_token:
|
if not homeserver or not access_token:
|
||||||
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
|
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(homeserver)
|
await avalidate_outbound_url(homeserver)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
||||||
|
|
||||||
@@ -519,7 +563,7 @@ class NotificationDispatcher:
|
|||||||
return {"success": False, "error": "No receivers configured"}
|
return {"success": False, "error": "No receivers configured"}
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
async with _new_session() as session:
|
async with self._session_ctx() as session:
|
||||||
client = MatrixClient(session, homeserver, access_token)
|
client = MatrixClient(session, homeserver, access_token)
|
||||||
for receiver in target.receivers:
|
for receiver in target.receivers:
|
||||||
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
||||||
|
|||||||
@@ -68,7 +68,9 @@ class MatrixClient:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.put(url, json=body, headers=headers) as resp:
|
async with self._session.put(
|
||||||
|
url, json=body, headers=headers, allow_redirects=False,
|
||||||
|
) as resp:
|
||||||
if 200 <= resp.status < 300:
|
if 200 <= resp.status < 300:
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
resp_body = await resp.text()
|
resp_body = await resp.text()
|
||||||
|
|||||||
@@ -51,7 +51,9 @@ class NtfyClient:
|
|||||||
headers["Authorization"] = f"Bearer {auth_token}"
|
headers["Authorization"] = f"Bearer {auth_token}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.post(url, json=payload, headers=headers) as resp:
|
async with self._session.post(
|
||||||
|
url, json=payload, headers=headers, allow_redirects=False,
|
||||||
|
) as resp:
|
||||||
if 200 <= resp.status < 300:
|
if 200 <= resp.status < 300:
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class SlackClient:
|
|||||||
webhook_url,
|
webhook_url,
|
||||||
json=payload,
|
json=payload,
|
||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
|
allow_redirects=False,
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status == 429:
|
if resp.status == 429:
|
||||||
_LOGGER.warning("Slack rate limited")
|
_LOGGER.warning("Slack rate limited")
|
||||||
|
|||||||
@@ -12,14 +12,25 @@ development against localhost services.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
||||||
_ALLOWED_SCHEMES = {"http", "https"}
|
_ALLOWED_SCHEMES = {"http", "https"}
|
||||||
|
|
||||||
|
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
||||||
|
_LOGGER.warning(
|
||||||
|
"SSRF guard: private-URL bypass ENABLED "
|
||||||
|
"(NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1). Requests to RFC1918 / "
|
||||||
|
"loopback / link-local hosts will be permitted."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnsafeURLError(ValueError):
|
class UnsafeURLError(ValueError):
|
||||||
"""Raised when a URL targets a disallowed network destination."""
|
"""Raised when a URL targets a disallowed network destination."""
|
||||||
@@ -36,13 +47,7 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def validate_outbound_url(url: str) -> str:
|
def _check_scheme_host(url: str) -> tuple[str, str]:
|
||||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
|
||||||
|
|
||||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
|
||||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
|
||||||
private addresses are permitted but the scheme check still applies.
|
|
||||||
"""
|
|
||||||
if not isinstance(url, str) or not url:
|
if not isinstance(url, str) or not url:
|
||||||
raise UnsafeURLError("URL is empty")
|
raise UnsafeURLError("URL is empty")
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
@@ -51,6 +56,31 @@ def validate_outbound_url(url: str) -> str:
|
|||||||
host = parsed.hostname
|
host = parsed.hostname
|
||||||
if not host:
|
if not host:
|
||||||
raise UnsafeURLError("URL has no host")
|
raise UnsafeURLError("URL has no host")
|
||||||
|
return parsed.scheme, host
|
||||||
|
|
||||||
|
|
||||||
|
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
||||||
|
for info in infos:
|
||||||
|
sockaddr = info[4]
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(sockaddr[0])
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if _is_blocked_ip(ip):
|
||||||
|
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_outbound_url(url: str) -> str:
|
||||||
|
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||||
|
|
||||||
|
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||||
|
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||||
|
private addresses are permitted but the scheme check still applies.
|
||||||
|
|
||||||
|
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||||
|
:func:`avalidate_outbound_url` from async code paths.
|
||||||
|
"""
|
||||||
|
_, host = _check_scheme_host(url)
|
||||||
|
|
||||||
if _ALLOW_PRIVATE:
|
if _ALLOW_PRIVATE:
|
||||||
return url
|
return url
|
||||||
@@ -64,17 +94,37 @@ def validate_outbound_url(url: str) -> str:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Hostname — resolve and reject if any resolution is in a blocked range.
|
|
||||||
try:
|
try:
|
||||||
infos = socket.getaddrinfo(host, None)
|
infos = socket.getaddrinfo(host, None)
|
||||||
except socket.gaierror as exc:
|
except socket.gaierror as exc:
|
||||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||||
for info in infos:
|
_check_resolved_addresses(host, infos)
|
||||||
sockaddr = info[4]
|
return url
|
||||||
try:
|
|
||||||
ip = ipaddress.ip_address(sockaddr[0])
|
|
||||||
except ValueError:
|
async def avalidate_outbound_url(url: str) -> str:
|
||||||
continue
|
"""Async variant that resolves DNS via the running loop's resolver.
|
||||||
if _is_blocked_ip(ip):
|
|
||||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
Use this from ``async def`` code paths to avoid blocking the event
|
||||||
|
loop on DNS lookups.
|
||||||
|
"""
|
||||||
|
_, host = _check_scheme_host(url)
|
||||||
|
|
||||||
|
if _ALLOW_PRIVATE:
|
||||||
|
return url
|
||||||
|
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(host)
|
||||||
|
if _is_blocked_ip(ip):
|
||||||
|
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
||||||
|
return url
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
try:
|
||||||
|
infos = await loop.getaddrinfo(host, None)
|
||||||
|
except socket.gaierror as exc:
|
||||||
|
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||||
|
_check_resolved_addresses(host, infos)
|
||||||
return url
|
return url
|
||||||
|
|||||||
@@ -162,8 +162,20 @@ class TelegramClient:
|
|||||||
"message_id": result.get("result", {}).get("message_id"),
|
"message_id": result.get("result", {}).get("message_id"),
|
||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
except aiohttp.ClientError:
|
# Non-ok from a cached send — file_id stale or file deleted on
|
||||||
pass
|
# Telegram's side. Log at DEBUG so operators who are hunting
|
||||||
|
# "why didn't the cached send work?" can see it, but the
|
||||||
|
# caller will fall through to a fresh upload.
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Telegram %s (cached) returned non-ok: status=%s code=%s desc=%r — falling back to fresh upload",
|
||||||
|
kind.api_method, response.status, result.get("error_code"),
|
||||||
|
result.get("description"),
|
||||||
|
)
|
||||||
|
except aiohttp.ClientError as err:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Telegram %s (cached) transport error — falling back to fresh upload: %s",
|
||||||
|
kind.api_method, err,
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _upload_media(
|
async def _upload_media(
|
||||||
@@ -203,8 +215,17 @@ class TelegramClient:
|
|||||||
thumbhash=thumbhash, size=len(data),
|
thumbhash=thumbhash, size=len(data),
|
||||||
)
|
)
|
||||||
return {"success": True, "message_id": res.get("message_id")}
|
return {"success": True, "message_id": res.get("message_id")}
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram %s failed: status=%s code=%s desc=%r bytes=%d",
|
||||||
|
kind.api_method, response.status, result.get("error_code"),
|
||||||
|
result.get("description", "Unknown"), len(data),
|
||||||
|
)
|
||||||
return {"success": False, "error": result.get("description", "Unknown Telegram error")}
|
return {"success": False, "error": result.get("description", "Unknown Telegram error")}
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram %s transport error (bytes=%d): %s",
|
||||||
|
kind.api_method, len(data), err, exc_info=True,
|
||||||
|
)
|
||||||
return {"success": False, "error": str(err)}
|
return {"success": False, "error": str(err)}
|
||||||
|
|
||||||
async def send_notification(
|
async def send_notification(
|
||||||
@@ -327,8 +348,14 @@ class TelegramClient:
|
|||||||
retry_result = await retry_resp.json()
|
retry_result = await retry_resp.json()
|
||||||
if retry_resp.status == 200 and retry_result.get("ok"):
|
if retry_resp.status == 200 and retry_result.get("ok"):
|
||||||
return {"success": True, "message_id": retry_result.get("result", {}).get("message_id")}
|
return {"success": True, "message_id": retry_result.get("result", {}).get("message_id")}
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram sendMessage failed: status=%s code=%s desc=%r",
|
||||||
|
response.status, result.get("error_code"),
|
||||||
|
result.get("description", "Unknown"),
|
||||||
|
)
|
||||||
return {"success": False, "error": result.get("description", "Unknown Telegram error"), "error_code": result.get("error_code")}
|
return {"success": False, "error": result.get("description", "Unknown Telegram error"), "error_code": result.get("error_code")}
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
|
_LOGGER.error("Telegram sendMessage transport error: %s", err, exc_info=True)
|
||||||
return {"success": False, "error": str(err)}
|
return {"success": False, "error": str(err)}
|
||||||
|
|
||||||
async def send_chat_action(self, chat_id: str, action: str = "typing") -> bool:
|
async def send_chat_action(self, chat_id: str, action: str = "typing") -> bool:
|
||||||
@@ -513,11 +540,14 @@ class TelegramClient:
|
|||||||
# Tuple is (cache_key, media_type, thumbhash, uploaded_size).
|
# Tuple is (cache_key, media_type, thumbhash, uploaded_size).
|
||||||
media_cache_info: list[tuple[str, str, str | None, int] | None] = []
|
media_cache_info: list[tuple[str, str, str | None, int] | None] = []
|
||||||
|
|
||||||
# Resolve cache hits and collect download tasks in parallel
|
# Resolve cache hits and collect download tasks in parallel.
|
||||||
|
# Each drop site logs the reason — otherwise a filtered asset
|
||||||
|
# disappears silently and the media group silently shrinks.
|
||||||
async def _fetch_asset(idx: int, item: dict) -> tuple[int, dict | None, bytes | None]:
|
async def _fetch_asset(idx: int, item: dict) -> tuple[int, dict | None, bytes | None]:
|
||||||
"""Return (index, cache_entry_or_None, downloaded_bytes_or_None)."""
|
"""Return (index, cache_entry_or_None, downloaded_bytes_or_None)."""
|
||||||
url = item.get("url")
|
url = item.get("url")
|
||||||
if not url:
|
if not url:
|
||||||
|
_LOGGER.warning("Media skipped: missing url (idx=%d type=%s)", idx, item.get("type"))
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
media_type = item.get("type", "photo")
|
media_type = item.get("type", "photo")
|
||||||
custom_cache_key = item.get("cache_key")
|
custom_cache_key = item.get("cache_key")
|
||||||
@@ -537,12 +567,24 @@ class TelegramClient:
|
|||||||
if preloaded is not None:
|
if preloaded is not None:
|
||||||
data = preloaded
|
data = preloaded
|
||||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: preloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||||
|
len(data), max_asset_data_size, idx, media_type, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: preloaded video %d bytes exceeds Telegram limit %d (idx=%d url=%s)",
|
||||||
|
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
if media_type == "photo":
|
if media_type == "photo":
|
||||||
exceeds, _, _, _ = check_photo_limits(data)
|
exceeds, reason, _, _ = check_photo_limits(data)
|
||||||
if exceeds:
|
if exceeds:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: preloaded photo %s (idx=%d url=%s)",
|
||||||
|
reason, idx, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
return idx, None, data
|
return idx, None, data
|
||||||
|
|
||||||
@@ -551,18 +593,38 @@ class TelegramClient:
|
|||||||
dl_headers = item.get("headers") or {}
|
dl_headers = item.get("headers") or {}
|
||||||
async with self._session.get(download_url, headers=dl_headers) as resp:
|
async with self._session.get(download_url, headers=dl_headers) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: download HTTP %d (idx=%d type=%s url=%s)",
|
||||||
|
resp.status, idx, media_type, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
data = await resp.read()
|
data = await resp.read()
|
||||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: downloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||||
|
len(data), max_asset_data_size, idx, media_type, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: video %d bytes exceeds Telegram %d-byte limit (idx=%d url=%s)",
|
||||||
|
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
if media_type == "photo":
|
if media_type == "photo":
|
||||||
exceeds, _, _, _ = check_photo_limits(data)
|
exceeds, reason, _, _ = check_photo_limits(data)
|
||||||
if exceeds:
|
if exceeds:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: photo %s (idx=%d url=%s)",
|
||||||
|
reason, idx, url,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
return idx, None, data
|
return idx, None, data
|
||||||
except aiohttp.ClientError:
|
except aiohttp.ClientError as err:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Media skipped: download failed (idx=%d type=%s url=%s): %s",
|
||||||
|
idx, media_type, url, err,
|
||||||
|
)
|
||||||
return idx, None, None
|
return idx, None, None
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
@@ -602,6 +664,14 @@ class TelegramClient:
|
|||||||
media_json.append(mij)
|
media_json.append(mij)
|
||||||
|
|
||||||
if not media_json:
|
if not media_json:
|
||||||
|
# Every asset in this chunk was filtered out (size, download
|
||||||
|
# failure, etc.). Without this log, sendMediaGroup returns
|
||||||
|
# success=True with zero message_ids and nobody knows why
|
||||||
|
# the user sees only the text reply and no media.
|
||||||
|
_LOGGER.warning(
|
||||||
|
"sendMediaGroup skipped — chunk %d/%d had %d input items but 0 usable (all filtered/failed)",
|
||||||
|
chunk_idx + 1, len(chunks), len(chunk),
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
form.add_field("media", json.dumps(media_json))
|
form.add_field("media", json.dumps(media_json))
|
||||||
@@ -638,10 +708,35 @@ class TelegramClient:
|
|||||||
if eff_cache:
|
if eff_cache:
|
||||||
await eff_cache.async_set_many(cache_entries)
|
await eff_cache.async_set_many(cache_entries)
|
||||||
else:
|
else:
|
||||||
return {"success": False, "error": result.get("description", "Unknown"), "failed_at_chunk": chunk_idx + 1}
|
_LOGGER.error(
|
||||||
|
"Telegram sendMediaGroup failed: status=%s code=%s desc=%r chunk=%d/%d items=%d",
|
||||||
|
response.status, result.get("error_code"),
|
||||||
|
result.get("description", "Unknown"),
|
||||||
|
chunk_idx + 1, len(chunks), len(media_json),
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": result.get("description", "Unknown"),
|
||||||
|
"error_code": result.get("error_code"),
|
||||||
|
"failed_at_chunk": chunk_idx + 1,
|
||||||
|
}
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram sendMediaGroup transport error on chunk %d/%d (%d items): %s",
|
||||||
|
chunk_idx + 1, len(chunks), len(media_json), err,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
return {"success": False, "error": str(err), "failed_at_chunk": chunk_idx + 1}
|
return {"success": False, "error": str(err), "failed_at_chunk": chunk_idx + 1}
|
||||||
|
|
||||||
|
# Distinguish "posted something" from "posted nothing" so the caller
|
||||||
|
# can surface an ERROR when a command produced a caption reply but no
|
||||||
|
# media ever reached Telegram.
|
||||||
|
if not all_message_ids:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"sendMediaGroup completed with 0 message_ids across %d chunk(s) — nothing was delivered",
|
||||||
|
len(chunks),
|
||||||
|
)
|
||||||
|
return {"success": False, "error": "no_items_delivered", "chunks_sent": len(chunks)}
|
||||||
return {"success": True, "message_ids": all_message_ids, "chunks_sent": len(chunks)}
|
return {"success": True, "message_ids": all_message_ids, "chunks_sent": len(chunks)}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from typing import Any
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from ..ssrf import UnsafeURLError, validate_outbound_url
|
from ..ssrf import UnsafeURLError, avalidate_outbound_url
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ class WebhookClient:
|
|||||||
|
|
||||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
validate_outbound_url(self._url)
|
await avalidate_outbound_url(self._url)
|
||||||
except UnsafeURLError as err:
|
except UnsafeURLError as err:
|
||||||
return {"success": False, "error": f"Unsafe URL: {err}"}
|
return {"success": False, "error": f"Unsafe URL: {err}"}
|
||||||
try:
|
try:
|
||||||
@@ -33,6 +33,7 @@ class WebhookClient:
|
|||||||
json=payload,
|
json=payload,
|
||||||
headers={"Content-Type": "application/json", **self._headers},
|
headers={"Content-Type": "application/json", **self._headers},
|
||||||
timeout=_DEFAULT_TIMEOUT,
|
timeout=_DEFAULT_TIMEOUT,
|
||||||
|
allow_redirects=False,
|
||||||
) as response:
|
) as response:
|
||||||
if 200 <= response.status < 300:
|
if 200 <= response.status < 300:
|
||||||
return {"success": True, "status_code": response.status}
|
return {"success": True, "status_code": response.status}
|
||||||
|
|||||||
@@ -177,7 +177,9 @@ class ImmichActionExecutor(ActionExecutor):
|
|||||||
needs_thumbnail = album_id in album_created_now
|
needs_thumbnail = album_id in album_created_now
|
||||||
|
|
||||||
if album_id and album_id != "__dry_run_new__":
|
if album_id and album_id != "__dry_run_new__":
|
||||||
album = await self._client.get_album(album_id)
|
# Actions diff the current album state to decide what to
|
||||||
|
# add — must observe fresh data, not a cached view.
|
||||||
|
album = await self._client.get_album(album_id, use_cache=False)
|
||||||
if album is None and create_if_missing and create_album_name:
|
if album is None and create_if_missing and create_album_name:
|
||||||
if not dry_run:
|
if not dry_run:
|
||||||
created = await self._client.create_album(create_album_name)
|
created = await self._client.create_album(create_album_name)
|
||||||
|
|||||||
@@ -2,8 +2,11 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@@ -18,6 +21,51 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
MAX_SEARCH_QUERY_LEN = 256
|
MAX_SEARCH_QUERY_LEN = 256
|
||||||
MAX_SEARCH_PERSON_IDS = 50
|
MAX_SEARCH_PERSON_IDS = 50
|
||||||
|
|
||||||
|
# Module-level TTL caches for album bodies and shared-link listings. The
|
||||||
|
# Immich ``GET /api/albums/{id}`` response can be tens or hundreds of MB on a
|
||||||
|
# large album, and bot commands like /random, /latest, /memory all refetch
|
||||||
|
# the same album in quick succession. A short TTL makes repeat runs nearly
|
||||||
|
# instant and deduplicates concurrent fetches so a burst of commands issues
|
||||||
|
# one HTTP call instead of N.
|
||||||
|
#
|
||||||
|
# Caches are module-scoped (not instance-scoped) because ``ImmichClient`` is
|
||||||
|
# constructed fresh per request in several places (api/providers.py,
|
||||||
|
# services/action_runner.py, command handlers), so an instance cache would
|
||||||
|
# never survive to serve a second caller. This mirrors ``_users_cache`` in
|
||||||
|
# ``provider.py``.
|
||||||
|
_ALBUM_CACHE_TTL_SECONDS = 60
|
||||||
|
_SHARED_LINKS_CACHE_TTL_SECONDS = 60
|
||||||
|
# Guard rail against runaway memory — a 200k-asset album response can be
|
||||||
|
# ~150 MB, so even modest caps bound the worst case.
|
||||||
|
_ALBUM_CACHE_MAX_ENTRIES = 32
|
||||||
|
_album_cache_lock = asyncio.Lock()
|
||||||
|
# key = (server_digest, album_id); value = (monotonic_ts, raw_api_dict)
|
||||||
|
# Store the raw dict rather than the parsed ``ImmichAlbumData`` so callers
|
||||||
|
# that pass a ``users_cache`` still get owner-name enrichment on cache hits.
|
||||||
|
_album_cache: dict[tuple[str, str], tuple[float, dict[str, Any]]] = {}
|
||||||
|
_shared_links_cache_lock = asyncio.Lock()
|
||||||
|
# key = server_digest; value = (monotonic_ts, {album_id: [SharedLinkInfo, ...]})
|
||||||
|
# The underlying ``/api/shared-links`` endpoint has no per-album filter, so
|
||||||
|
# every call was already paying for the full server-wide list. Caching the
|
||||||
|
# bucketed result once per server turns N per-album calls into one fetch.
|
||||||
|
_shared_links_cache: dict[str, tuple[float, dict[str, list[SharedLinkInfo]]]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _server_digest(url: str, api_key: str) -> str:
|
||||||
|
"""Hashed key that avoids putting raw api_key into cache dict keys."""
|
||||||
|
return hashlib.sha256(f"{url}|{api_key}".encode("utf-8")).hexdigest()[:32]
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_album_cache() -> None:
|
||||||
|
"""Drop every cached album body. Call after mutations that invalidate
|
||||||
|
the cached view (e.g. integration tests, manual /refresh commands)."""
|
||||||
|
_album_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_shared_links_cache() -> None:
|
||||||
|
"""Drop every cached shared-link listing."""
|
||||||
|
_shared_links_cache.clear()
|
||||||
|
|
||||||
# User-facing error bodies — Immich responses may leak internal paths,
|
# User-facing error bodies — Immich responses may leak internal paths,
|
||||||
# hostnames, or headers injected by intermediary proxies. These helpers keep
|
# hostnames, or headers injected by intermediary proxies. These helpers keep
|
||||||
# only a short, scrubbed summary; full bodies are logged server-side only.
|
# only a short, scrubbed summary; full bodies are logged server-side only.
|
||||||
@@ -184,22 +232,30 @@ class ImmichClient:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_shared_links(self, album_id: str) -> list[SharedLinkInfo]:
|
async def get_shared_links(self, album_id: str) -> list[SharedLinkInfo]:
|
||||||
links: list[SharedLinkInfo] = []
|
bucketed = await self._get_shared_links_bucketed()
|
||||||
try:
|
return list(bucketed.get(album_id, []))
|
||||||
async with self._session.get(
|
|
||||||
f"{self._url}/api/shared-links",
|
async def _get_shared_links_bucketed(self) -> dict[str, list[SharedLinkInfo]]:
|
||||||
headers=self._headers,
|
"""Return ``{album_id: [SharedLinkInfo, ...]}`` for the server, hitting
|
||||||
) as response:
|
the module-level TTL cache first. Underlying Immich endpoint has no
|
||||||
if response.status == 200:
|
per-album filter, so one server-wide fetch serves every caller until
|
||||||
data = await response.json()
|
the TTL elapses.
|
||||||
for link in data:
|
"""
|
||||||
album = link.get("album")
|
digest = _server_digest(self._url, self._api_key)
|
||||||
key = link.get("key")
|
now = time.monotonic()
|
||||||
if album and key and album.get("id") == album_id:
|
entry = _shared_links_cache.get(digest)
|
||||||
links.append(SharedLinkInfo.from_api_response(link))
|
if entry is not None and (now - entry[0]) < _SHARED_LINKS_CACHE_TTL_SECONDS:
|
||||||
except aiohttp.ClientError as err:
|
return entry[1]
|
||||||
_LOGGER.warning("Failed to fetch shared links: %s", err)
|
|
||||||
return links
|
async with _shared_links_cache_lock:
|
||||||
|
# Re-check under the lock — another coroutine may have refreshed
|
||||||
|
# while we waited.
|
||||||
|
entry = _shared_links_cache.get(digest)
|
||||||
|
if entry is not None and (time.monotonic() - entry[0]) < _SHARED_LINKS_CACHE_TTL_SECONDS:
|
||||||
|
return entry[1]
|
||||||
|
fresh = await self.get_all_shared_links_by_album()
|
||||||
|
_shared_links_cache[digest] = (time.monotonic(), fresh)
|
||||||
|
return fresh
|
||||||
|
|
||||||
async def get_all_shared_links_by_album(self) -> dict[str, list[SharedLinkInfo]]:
|
async def get_all_shared_links_by_album(self) -> dict[str, list[SharedLinkInfo]]:
|
||||||
"""Fetch every shared link on the server, bucketed by album id.
|
"""Fetch every shared link on the server, bucketed by album id.
|
||||||
@@ -247,7 +303,29 @@ class ImmichClient:
|
|||||||
self,
|
self,
|
||||||
album_id: str,
|
album_id: str,
|
||||||
users_cache: dict[str, str] | None = None,
|
users_cache: dict[str, str] | None = None,
|
||||||
|
*,
|
||||||
|
use_cache: bool = True,
|
||||||
) -> ImmichAlbumData | None:
|
) -> ImmichAlbumData | None:
|
||||||
|
"""Fetch an album by id, optionally serving from the module-level
|
||||||
|
TTL cache. Pass ``use_cache=False`` from paths that must observe the
|
||||||
|
current server state (e.g. the notification poll loop's full-fetch
|
||||||
|
path, where a stale cached entry would delay asset-removal events).
|
||||||
|
Non-cached fetches still populate the cache for subsequent readers.
|
||||||
|
"""
|
||||||
|
cache_key = (_server_digest(self._url, self._api_key), album_id)
|
||||||
|
if use_cache:
|
||||||
|
entry = _album_cache.get(cache_key)
|
||||||
|
if entry is not None and (time.monotonic() - entry[0]) < _ALBUM_CACHE_TTL_SECONDS:
|
||||||
|
# Rehydrate per-call so ``users_cache`` enrichment is applied
|
||||||
|
# with the caller's dict, not whichever one was live when the
|
||||||
|
# cache was populated.
|
||||||
|
return ImmichAlbumData.from_api_response(entry[1], users_cache)
|
||||||
|
|
||||||
|
# Deliberately fetch without holding a lock so concurrent calls for
|
||||||
|
# *different* album_ids (the common case from asyncio.gather in
|
||||||
|
# fetch_albums_with_links) stay parallel. The worst case is a small
|
||||||
|
# duplicate-fetch stampede when two requests miss the same album at
|
||||||
|
# the same instant — acceptable for our scale.
|
||||||
try:
|
try:
|
||||||
async with self._session.get(
|
async with self._session.get(
|
||||||
f"{self._url}/api/albums/{album_id}",
|
f"{self._url}/api/albums/{album_id}",
|
||||||
@@ -260,10 +338,18 @@ class ImmichClient:
|
|||||||
f"Error fetching album {album_id}: HTTP {response.status}"
|
f"Error fetching album {album_id}: HTTP {response.status}"
|
||||||
)
|
)
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
return ImmichAlbumData.from_api_response(data, users_cache)
|
|
||||||
except aiohttp.ClientError as err:
|
except aiohttp.ClientError as err:
|
||||||
raise ImmichApiError(f"Error communicating with Immich: {err}") from err
|
raise ImmichApiError(f"Error communicating with Immich: {err}") from err
|
||||||
|
|
||||||
|
async with _album_cache_lock:
|
||||||
|
# Evict the oldest entry if we're at the cap — simple FIFO is fine
|
||||||
|
# for our access pattern (commands touch a small working set).
|
||||||
|
if len(_album_cache) >= _ALBUM_CACHE_MAX_ENTRIES and cache_key not in _album_cache:
|
||||||
|
oldest = min(_album_cache.items(), key=lambda kv: kv[1][0])[0]
|
||||||
|
_album_cache.pop(oldest, None)
|
||||||
|
_album_cache[cache_key] = (time.monotonic(), data)
|
||||||
|
return ImmichAlbumData.from_api_response(data, users_cache)
|
||||||
|
|
||||||
async def get_album_meta(self, album_id: str) -> ImmichAlbumMeta | None:
|
async def get_album_meta(self, album_id: str) -> ImmichAlbumMeta | None:
|
||||||
"""Fetch album metadata without the assets array.
|
"""Fetch album metadata without the assets array.
|
||||||
|
|
||||||
|
|||||||
@@ -292,7 +292,13 @@ class ImmichServiceProvider(ServiceProvider):
|
|||||||
# the full-fetch path so removals get detected.
|
# the full-fetch path so removals get detected.
|
||||||
|
|
||||||
# Full fetch: first tick, or count-decreased, or delta-unsafe.
|
# Full fetch: first tick, or count-decreased, or delta-unsafe.
|
||||||
album = await self._client.get_album(album_id, self._users_cache)
|
# Bypass the module-level album cache — this path runs when we
|
||||||
|
# specifically need the current server state (e.g. to detect
|
||||||
|
# asset removals), so a stale cached entry would silently delay
|
||||||
|
# the event.
|
||||||
|
album = await self._client.get_album(
|
||||||
|
album_id, self._users_cache, use_cache=False,
|
||||||
|
)
|
||||||
if album is None:
|
if album is None:
|
||||||
# Album was deleted between meta probe and full fetch — handle
|
# Album was deleted between meta probe and full fetch — handle
|
||||||
# the deletion the same way as above.
|
# the deletion the same way as above.
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_DEFAULT_PORT = 3493
|
_DEFAULT_PORT = 3493
|
||||||
_READ_TIMEOUT = 10.0
|
_READ_TIMEOUT = 10.0
|
||||||
|
_WRITE_TIMEOUT = 10.0
|
||||||
_CONNECT_TIMEOUT = 5.0
|
_CONNECT_TIMEOUT = 5.0
|
||||||
|
|
||||||
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
|
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
|
||||||
@@ -84,14 +85,26 @@ class NutClient:
|
|||||||
await self._command(f"PASSWORD {self._password}")
|
await self._command(f"PASSWORD {self._password}")
|
||||||
|
|
||||||
async def disconnect(self) -> None:
|
async def disconnect(self) -> None:
|
||||||
"""Send LOGOUT and close the TCP connection."""
|
"""Send LOGOUT and close the TCP connection.
|
||||||
|
|
||||||
|
``drain`` is bounded by ``_WRITE_TIMEOUT`` so a half-closed peer
|
||||||
|
cannot hold the disconnect indefinitely — a tracker tick would
|
||||||
|
otherwise be pinned by a stuck NUT server and block the scheduler
|
||||||
|
slot (``max_instances=1``).
|
||||||
|
"""
|
||||||
if self._writer is not None:
|
if self._writer is not None:
|
||||||
try:
|
try:
|
||||||
self._writer.write(b"LOGOUT\n")
|
self._writer.write(b"LOGOUT\n")
|
||||||
await self._writer.drain()
|
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||||
except OSError:
|
except (OSError, asyncio.TimeoutError):
|
||||||
pass
|
pass
|
||||||
self._writer.close()
|
self._writer.close()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._writer.wait_closed(), timeout=_WRITE_TIMEOUT,
|
||||||
|
)
|
||||||
|
except (OSError, asyncio.TimeoutError):
|
||||||
|
pass
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
|
|
||||||
@@ -135,7 +148,10 @@ class NutClient:
|
|||||||
if self._writer is None:
|
if self._writer is None:
|
||||||
raise NutClientError("Not connected")
|
raise NutClientError("Not connected")
|
||||||
self._writer.write(f"{cmd}\n".encode())
|
self._writer.write(f"{cmd}\n".encode())
|
||||||
await self._writer.drain()
|
try:
|
||||||
|
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
raise NutClientError("Write timeout") from exc
|
||||||
|
|
||||||
async def _readline(self) -> str:
|
async def _readline(self) -> str:
|
||||||
"""Read one line from upsd, stripping trailing newline."""
|
"""Read one line from upsd, stripping trailing newline."""
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||||
from notify_bridge_core.providers.base import ServiceProvider, ServiceProviderType
|
from notify_bridge_core.providers.base import ServiceProvider, ServiceProviderType
|
||||||
@@ -57,6 +58,13 @@ SCHEDULER_VARIABLES: list[TemplateVariableDefinition] = [
|
|||||||
example="Monday",
|
example="Monday",
|
||||||
provider_type=ServiceProviderType.SCHEDULER,
|
provider_type=ServiceProviderType.SCHEDULER,
|
||||||
),
|
),
|
||||||
|
TemplateVariableDefinition(
|
||||||
|
name="timezone",
|
||||||
|
type="string",
|
||||||
|
description="IANA timezone used to compute current_date/time",
|
||||||
|
example="Europe/Warsaw",
|
||||||
|
provider_type=ServiceProviderType.SCHEDULER,
|
||||||
|
),
|
||||||
TemplateVariableDefinition(
|
TemplateVariableDefinition(
|
||||||
name="custom_vars",
|
name="custom_vars",
|
||||||
type="dict",
|
type="dict",
|
||||||
@@ -83,7 +91,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
|||||||
custom_variables: dict[str, str] | None = None,
|
custom_variables: dict[str, str] | None = None,
|
||||||
date_format: str = "%d.%m.%Y",
|
date_format: str = "%d.%m.%Y",
|
||||||
time_format: str = "%H:%M",
|
time_format: str = "%H:%M",
|
||||||
datetime_format: str = "%d.%m.%Y, %H:%M UTC",
|
datetime_format: str = "%d.%m.%Y, %H:%M %Z",
|
||||||
|
timezone_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._name = name
|
self._name = name
|
||||||
self._tracker_name = tracker_name
|
self._tracker_name = tracker_name
|
||||||
@@ -91,6 +100,18 @@ class SchedulerServiceProvider(ServiceProvider):
|
|||||||
self._date_format = date_format
|
self._date_format = date_format
|
||||||
self._time_format = time_format
|
self._time_format = time_format
|
||||||
self._datetime_format = datetime_format
|
self._datetime_format = datetime_format
|
||||||
|
# Resolve a timezone for date/time rendering. Falls back to UTC on
|
||||||
|
# invalid IANA names so a typo in app settings doesn't break polls.
|
||||||
|
tz: ZoneInfo
|
||||||
|
if timezone_name:
|
||||||
|
try:
|
||||||
|
tz = ZoneInfo(timezone_name)
|
||||||
|
except (ZoneInfoNotFoundError, ValueError):
|
||||||
|
_LOGGER.warning("Unknown timezone %r; falling back to UTC", timezone_name)
|
||||||
|
tz = ZoneInfo("UTC")
|
||||||
|
else:
|
||||||
|
tz = ZoneInfo("UTC")
|
||||||
|
self._tz = tz
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
async def connect(self) -> bool:
|
||||||
return True # virtual provider — always connected
|
return True # virtual provider — always connected
|
||||||
@@ -103,7 +124,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
|||||||
collection_ids: list[str],
|
collection_ids: list[str],
|
||||||
tracker_state: dict[str, Any],
|
tracker_state: dict[str, Any],
|
||||||
) -> tuple[list[ServiceEvent], dict[str, Any]]:
|
) -> tuple[list[ServiceEvent], dict[str, Any]]:
|
||||||
now = datetime.now(timezone.utc)
|
now_utc = datetime.now(timezone.utc)
|
||||||
|
now = now_utc.astimezone(self._tz)
|
||||||
# State uses {collection_id: {dict}} convention like other providers
|
# State uses {collection_id: {dict}} convention like other providers
|
||||||
sched_state = tracker_state.get("scheduler", {})
|
sched_state = tracker_state.get("scheduler", {})
|
||||||
fire_count = sched_state.get("fire_count", 0) + 1
|
fire_count = sched_state.get("fire_count", 0) + 1
|
||||||
@@ -115,6 +137,7 @@ class SchedulerServiceProvider(ServiceProvider):
|
|||||||
"current_time": now.strftime(self._time_format),
|
"current_time": now.strftime(self._time_format),
|
||||||
"current_datetime": now.strftime(self._datetime_format),
|
"current_datetime": now.strftime(self._datetime_format),
|
||||||
"weekday": _WEEKDAYS[now.weekday()],
|
"weekday": _WEEKDAYS[now.weekday()],
|
||||||
|
"timezone": self._tz.key,
|
||||||
"custom_vars": dict(self._custom_variables),
|
"custom_vars": dict(self._custom_variables),
|
||||||
}
|
}
|
||||||
# Flatten custom variables at top level for easy template access
|
# Flatten custom variables at top level for easy template access
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
@@ -19,34 +21,58 @@ class StorageBackend(Protocol):
|
|||||||
async def remove(self) -> None: ...
|
async def remove(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
def _read_file(path: Path) -> str | None:
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
return path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _atomic_write(path: Path, payload: str) -> None:
|
||||||
|
"""Write atomically: tmp file + rename. Prevents half-written files on crash."""
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||||
|
tmp.write_text(payload, encoding="utf-8")
|
||||||
|
os.replace(tmp, path)
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_file(path: Path) -> None:
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
|
||||||
|
|
||||||
class JsonFileBackend:
|
class JsonFileBackend:
|
||||||
"""Simple JSON file storage backend."""
|
"""Simple JSON file storage backend.
|
||||||
|
|
||||||
|
All blocking I/O is wrapped in ``asyncio.to_thread`` so callers can
|
||||||
|
``await load() / save() / remove()`` without stalling the event loop.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, path: Path) -> None:
|
def __init__(self, path: Path) -> None:
|
||||||
self._path = path
|
self._path = path
|
||||||
|
|
||||||
async def load(self) -> dict[str, Any] | None:
|
async def load(self) -> dict[str, Any] | None:
|
||||||
if not self._path.exists():
|
try:
|
||||||
|
text = await asyncio.to_thread(_read_file, self._path)
|
||||||
|
except OSError as err:
|
||||||
|
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
||||||
|
return None
|
||||||
|
if text is None:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
text = self._path.read_text(encoding="utf-8")
|
|
||||||
return json.loads(text)
|
return json.loads(text)
|
||||||
except (json.JSONDecodeError, OSError) as err:
|
except json.JSONDecodeError as err:
|
||||||
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
_LOGGER.warning("Failed to parse %s: %s", self._path, err)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def save(self, data: dict[str, Any]) -> None:
|
async def save(self, data: dict[str, Any]) -> None:
|
||||||
|
payload = json.dumps(data, default=str)
|
||||||
try:
|
try:
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
await asyncio.to_thread(_atomic_write, self._path, payload)
|
||||||
self._path.write_text(
|
|
||||||
json.dumps(data, default=str), encoding="utf-8"
|
|
||||||
)
|
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.error("Failed to save %s: %s", self._path, err)
|
_LOGGER.error("Failed to save %s: %s", self._path, err)
|
||||||
|
|
||||||
async def remove(self) -> None:
|
async def remove(self) -> None:
|
||||||
try:
|
try:
|
||||||
if self._path.exists():
|
await asyncio.to_thread(_remove_file, self._path)
|
||||||
self._path.unlink()
|
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
_LOGGER.error("Failed to remove %s: %s", self._path, err)
|
_LOGGER.error("Failed to remove %s: %s", self._path, err)
|
||||||
|
|||||||
@@ -224,6 +224,7 @@ def build_template_context(
|
|||||||
ctx.setdefault("current_time", event.extra.get("current_time", ""))
|
ctx.setdefault("current_time", event.extra.get("current_time", ""))
|
||||||
ctx.setdefault("current_datetime", event.extra.get("current_datetime", ""))
|
ctx.setdefault("current_datetime", event.extra.get("current_datetime", ""))
|
||||||
ctx.setdefault("weekday", event.extra.get("weekday", ""))
|
ctx.setdefault("weekday", event.extra.get("weekday", ""))
|
||||||
|
ctx.setdefault("timezone", event.extra.get("timezone", "UTC"))
|
||||||
ctx.setdefault("custom_vars", event.extra.get("custom_vars", {}))
|
ctx.setdefault("custom_vars", event.extra.get("custom_vars", {}))
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
📸 Photos from {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
🗓️ Scheduled delivery — random photos from {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||||
{%- for asset in assets %}
|
{%- for asset in assets %}
|
||||||
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
📸 Фото из {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
🗓️ Доставка по расписанию — случайные фото из {% if public_url %}<a href="{{ public_url }}">{{ album_name }}</a>{% else %}"{{ album_name }}"{% endif %}:
|
||||||
{%- for asset in assets %}
|
{%- for asset in assets %}
|
||||||
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
• {%- if asset.type == "VIDEO" %} 🎬{% else %} 🖼️{% endif %} {% if asset.public_url %}<a href="{{ asset.public_url }}">{{ asset.filename }}</a>{% else %}{{ asset.filename }}{% endif %}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "notify-bridge-server"
|
name = "notify-bridge-server"
|
||||||
version = "0.3.0"
|
version = "0.5.0"
|
||||||
description = "Standalone Notify Bridge server — FastAPI REST API with SQLite database"
|
description = "Standalone Notify Bridge server — FastAPI REST API with SQLite database"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
@@ -28,6 +28,7 @@ dev = [
|
|||||||
"pytest>=8.0",
|
"pytest>=8.0",
|
||||||
"pytest-asyncio>=0.23",
|
"pytest-asyncio>=0.23",
|
||||||
"httpx>=0.27",
|
"httpx>=0.27",
|
||||||
|
"aioresponses>=0.7",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
@@ -35,3 +36,14 @@ notify-bridge = "notify_bridge_server.main:run"
|
|||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["src/notify_bridge_server"]
|
packages = ["src/notify_bridge_server"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
# The default filter doesn't let SQLAlchemy warnings fail the suite, which
|
||||||
|
# matters because our migrations emit a handful of deprecation warnings we
|
||||||
|
# don't want to suppress at source.
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning:passlib",
|
||||||
|
"ignore::DeprecationWarning:bcrypt",
|
||||||
|
]
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ _SETTING_KEYS = {
|
|||||||
"telegram_asset_cache_max_entries": None, # LRU cap for both caches
|
"telegram_asset_cache_max_entries": None, # LRU cap for both caches
|
||||||
"supported_locales": None, # comma-separated locale codes
|
"supported_locales": None, # comma-separated locale codes
|
||||||
"timezone": "NOTIFY_BRIDGE_TIMEZONE", # IANA tz (e.g. "Europe/Warsaw"); empty = UTC
|
"timezone": "NOTIFY_BRIDGE_TIMEZONE", # IANA tz (e.g. "Europe/Warsaw"); empty = UTC
|
||||||
|
# Logging — applied live via apply_log_levels() when changed.
|
||||||
|
"log_level": "NOTIFY_BRIDGE_LOG_LEVEL", # DEBUG/INFO/WARNING/ERROR
|
||||||
|
"log_format": "NOTIFY_BRIDGE_LOG_FORMAT", # text|json (requires restart to switch)
|
||||||
|
"log_levels": "NOTIFY_BRIDGE_LOG_LEVELS", # module=LEVEL,module2=LEVEL
|
||||||
}
|
}
|
||||||
|
|
||||||
_DEFAULTS = {
|
_DEFAULTS = {
|
||||||
@@ -35,12 +39,20 @@ _DEFAULTS = {
|
|||||||
"telegram_asset_cache_max_entries": "5000",
|
"telegram_asset_cache_max_entries": "5000",
|
||||||
"supported_locales": "en,ru",
|
"supported_locales": "en,ru",
|
||||||
"timezone": "UTC",
|
"timezone": "UTC",
|
||||||
|
"log_level": "INFO",
|
||||||
|
"log_format": "text",
|
||||||
|
"log_levels": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Settings whose changes require dropping in-memory Telegram caches so the
|
# Settings whose changes require dropping in-memory Telegram caches so the
|
||||||
# next dispatch rebuilds them with the new parameters. Files are preserved.
|
# next dispatch rebuilds them with the new parameters. Files are preserved.
|
||||||
_CACHE_SETTING_KEYS = {"telegram_cache_ttl_hours", "telegram_asset_cache_max_entries"}
|
_CACHE_SETTING_KEYS = {"telegram_cache_ttl_hours", "telegram_asset_cache_max_entries"}
|
||||||
|
|
||||||
|
# Settings that change logging behaviour. ``log_level`` + ``log_levels`` apply
|
||||||
|
# live via apply_log_levels(); ``log_format`` requires a restart because
|
||||||
|
# changing it means swapping the handler formatter entirely.
|
||||||
|
_LOG_SETTING_KEYS = {"log_level", "log_levels", "log_format"}
|
||||||
|
|
||||||
|
|
||||||
async def get_setting(session: AsyncSession, key: str) -> str:
|
async def get_setting(session: AsyncSession, key: str) -> str:
|
||||||
"""Read a setting from DB, falling back to env var then default."""
|
"""Read a setting from DB, falling back to env var then default."""
|
||||||
@@ -66,6 +78,9 @@ class SettingsUpdate(BaseModel):
|
|||||||
telegram_asset_cache_max_entries: int | str | None = None
|
telegram_asset_cache_max_entries: int | str | None = None
|
||||||
supported_locales: str | None = None
|
supported_locales: str | None = None
|
||||||
timezone: str | None = None
|
timezone: str | None = None
|
||||||
|
log_level: str | None = None
|
||||||
|
log_format: str | None = None
|
||||||
|
log_levels: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
@@ -94,6 +109,8 @@ async def update_settings(
|
|||||||
old_base_url = await get_setting(session, "external_url")
|
old_base_url = await get_setting(session, "external_url")
|
||||||
old_secret = await get_setting(session, "telegram_webhook_secret")
|
old_secret = await get_setting(session, "telegram_webhook_secret")
|
||||||
old_cache_values = {k: await get_setting(session, k) for k in _CACHE_SETTING_KEYS}
|
old_cache_values = {k: await get_setting(session, k) for k in _CACHE_SETTING_KEYS}
|
||||||
|
old_timezone = await get_setting(session, "timezone")
|
||||||
|
old_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||||
|
|
||||||
for key in _SETTING_KEYS:
|
for key in _SETTING_KEYS:
|
||||||
value = getattr(body, key, None)
|
value = getattr(body, key, None)
|
||||||
@@ -128,6 +145,33 @@ async def update_settings(
|
|||||||
|
|
||||||
new_base_url = await get_setting(session, "external_url")
|
new_base_url = await get_setting(session, "external_url")
|
||||||
new_secret = await get_setting(session, "telegram_webhook_secret")
|
new_secret = await get_setting(session, "telegram_webhook_secret")
|
||||||
|
new_timezone = await get_setting(session, "timezone")
|
||||||
|
new_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||||
|
|
||||||
|
# Apply live log-level changes (log_format still needs a restart).
|
||||||
|
if (new_log_values["log_level"] != old_log_values["log_level"]
|
||||||
|
or new_log_values["log_levels"] != old_log_values["log_levels"]):
|
||||||
|
from ..logging_setup import apply_log_levels
|
||||||
|
apply_log_levels(
|
||||||
|
level=new_log_values["log_level"] or None,
|
||||||
|
per_module_levels=new_log_values["log_levels"],
|
||||||
|
)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Log levels updated: root=%s overrides=%r",
|
||||||
|
new_log_values["log_level"], new_log_values["log_levels"],
|
||||||
|
)
|
||||||
|
if new_log_values["log_format"] != old_log_values["log_format"]:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"log_format changed from %r to %r — restart the server for it to take effect",
|
||||||
|
old_log_values["log_format"], new_log_values["log_format"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cron triggers freeze their timezone at construction time, so a tz change
|
||||||
|
# has no effect until the jobs are rebuilt — do that here, before we
|
||||||
|
# return success, so the UI reflects the actual schedule immediately.
|
||||||
|
if new_timezone != old_timezone:
|
||||||
|
from ..services.scheduler import reschedule_cron_jobs_for_timezone_change
|
||||||
|
await reschedule_cron_jobs_for_timezone_change()
|
||||||
|
|
||||||
# Update webhook secret in the webhook handler module
|
# Update webhook secret in the webhook handler module
|
||||||
if new_secret != old_secret:
|
if new_secret != old_secret:
|
||||||
@@ -190,7 +234,10 @@ async def _reregister_webhooks(
|
|||||||
if res.get("success"):
|
if res.get("success"):
|
||||||
_LOGGER.info("Re-registered webhook for bot %d (%s)", bot.id, bot.name)
|
_LOGGER.info("Re-registered webhook for bot %d (%s)", bot.id, bot.name)
|
||||||
else:
|
else:
|
||||||
_LOGGER.warning(
|
# Webhook re-register failure means the bot silently stops
|
||||||
"Failed to re-register webhook for bot %d: %s",
|
# delivering updates — this is operational visibility for an
|
||||||
bot.id, res.get("error"),
|
# admin, ERROR is appropriate.
|
||||||
|
_LOGGER.error(
|
||||||
|
"Failed to re-register webhook for bot %d (%s): %s",
|
||||||
|
bot.id, bot.name, res.get("error"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -93,7 +93,14 @@ async def update_email_bot(
|
|||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
bot = await _get_user_bot(session, bot_id, user.id)
|
bot = await _get_user_bot(session, bot_id, user.id)
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
updates = body.model_dump(exclude_unset=True)
|
||||||
|
# Reject the masked value the GET response returns so the stored password
|
||||||
|
# is preserved if the user saves without retyping it.
|
||||||
|
if "smtp_password" in updates:
|
||||||
|
pw = updates["smtp_password"]
|
||||||
|
if isinstance(pw, str) and pw.startswith("***"):
|
||||||
|
updates.pop("smtp_password")
|
||||||
|
for field, value in updates.items():
|
||||||
setattr(bot, field, value)
|
setattr(bot, field, value)
|
||||||
session.add(bot)
|
session.add(bot)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -7,6 +7,11 @@ from pydantic import BaseModel
|
|||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.ssrf import (
|
||||||
|
UnsafeURLError,
|
||||||
|
avalidate_outbound_url,
|
||||||
|
)
|
||||||
|
|
||||||
from ..auth.dependencies import get_current_user
|
from ..auth.dependencies import get_current_user
|
||||||
from ..database.engine import get_session
|
from ..database.engine import get_session
|
||||||
from ..database.models import MatrixBot, User
|
from ..database.models import MatrixBot, User
|
||||||
@@ -33,6 +38,21 @@ class MatrixBotUpdate(BaseModel):
|
|||||||
display_name: str | None = None
|
display_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_masked_secret(value: str | None) -> bool:
|
||||||
|
"""True when a field still carries our masked placeholder."""
|
||||||
|
return bool(value) and (value.startswith("***") or "..." in value)
|
||||||
|
|
||||||
|
|
||||||
|
async def _validate_homeserver_url(url: str) -> None:
|
||||||
|
"""Reject homeserver URLs that point to blocked networks."""
|
||||||
|
try:
|
||||||
|
await avalidate_outbound_url(url)
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail=f"Invalid homeserver_url: {err}"
|
||||||
|
) from err
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("")
|
||||||
async def list_matrix_bots(
|
async def list_matrix_bots(
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
@@ -50,6 +70,7 @@ async def create_matrix_bot(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
|
await _validate_homeserver_url(body.homeserver_url)
|
||||||
bot = MatrixBot(user_id=user.id, **body.model_dump())
|
bot = MatrixBot(user_id=user.id, **body.model_dump())
|
||||||
session.add(bot)
|
session.add(bot)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -74,7 +95,19 @@ async def update_matrix_bot(
|
|||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
bot = await _get_user_bot(session, bot_id, user.id)
|
bot = await _get_user_bot(session, bot_id, user.id)
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
updates = body.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# Re-validate homeserver_url whenever the client supplies a new one so
|
||||||
|
# no private/loopback target can ever be saved, even via update.
|
||||||
|
if "homeserver_url" in updates and updates["homeserver_url"]:
|
||||||
|
await _validate_homeserver_url(updates["homeserver_url"])
|
||||||
|
|
||||||
|
# Never accept the masked placeholder the GET response returns. If the
|
||||||
|
# client echoes it back, keep the stored secret.
|
||||||
|
if "access_token" in updates and _is_masked_secret(updates["access_token"]):
|
||||||
|
updates.pop("access_token")
|
||||||
|
|
||||||
|
for field, value in updates.items():
|
||||||
setattr(bot, field, value)
|
setattr(bot, field, value)
|
||||||
session.add(bot)
|
session.add(bot)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -108,15 +141,17 @@ async def test_matrix_bot(
|
|||||||
If room_id is not provided, just verifies the access token by calling /whoami.
|
If room_id is not provided, just verifies the access token by calling /whoami.
|
||||||
"""
|
"""
|
||||||
bot = await _get_user_bot(session, bot_id, user.id)
|
bot = await _get_user_bot(session, bot_id, user.id)
|
||||||
|
# Defense-in-depth: even though create/update validate the URL, a bot row
|
||||||
|
# written before this guard was added could still point at a blocked host.
|
||||||
|
await _validate_homeserver_url(bot.homeserver_url)
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from ..services.http_session import get_http_session
|
from ..services.http_session import get_http_session
|
||||||
http = await get_http_session()
|
http = await get_http_session()
|
||||||
# Verify token with /whoami
|
|
||||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||||
try:
|
try:
|
||||||
async with http.get(whoami_url, headers=headers) as resp:
|
async with http.get(whoami_url, headers=headers, allow_redirects=False) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
body = await resp.text()
|
body = await resp.text()
|
||||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||||
@@ -126,7 +161,6 @@ async def test_matrix_bot(
|
|||||||
|
|
||||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||||
|
|
||||||
# Optionally send a test message
|
|
||||||
if room_id:
|
if room_id:
|
||||||
from ..services.notifier import _get_test_message
|
from ..services.notifier import _get_test_message
|
||||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||||
@@ -148,7 +182,7 @@ def _response(bot: MatrixBot) -> dict:
|
|||||||
"name": bot.name,
|
"name": bot.name,
|
||||||
"icon": bot.icon,
|
"icon": bot.icon,
|
||||||
"homeserver_url": bot.homeserver_url,
|
"homeserver_url": bot.homeserver_url,
|
||||||
"access_token": f"{bot.access_token[:8]}...{bot.access_token[-4:]}" if len(bot.access_token) > 12 else "***",
|
"access_token": f"***{bot.access_token[-4:]}" if len(bot.access_token) > 4 else "***",
|
||||||
"display_name": bot.display_name,
|
"display_name": bot.display_name,
|
||||||
"created_at": bot.created_at.isoformat(),
|
"created_at": bot.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ from ..database.models import (
|
|||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from ..services.notifier import send_test_notification
|
from ..services.notifier import send_test_notification
|
||||||
from ..services.test_dispatch import dispatch_test_notification
|
from ..services.manual_dispatch import dispatch_test_notification
|
||||||
|
from ..services.scheduler import reschedule_immich_dispatch_jobs
|
||||||
from .helpers import get_owned_entity
|
from .helpers import get_owned_entity
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@@ -118,6 +119,7 @@ async def create_notification_tracker_target(
|
|||||||
session.add(tt)
|
session.add(tt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(tt)
|
await session.refresh(tt)
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return await _tt_response(session, tt)
|
return await _tt_response(session, tt)
|
||||||
|
|
||||||
|
|
||||||
@@ -164,6 +166,7 @@ async def update_notification_tracker_target(
|
|||||||
session.add(tt)
|
session.add(tt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(tt)
|
await session.refresh(tt)
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return await _tt_response(session, tt)
|
return await _tt_response(session, tt)
|
||||||
|
|
||||||
|
|
||||||
@@ -181,6 +184,7 @@ async def delete_notification_tracker_target(
|
|||||||
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
||||||
await session.delete(tt)
|
await session.delete(tt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{tracker_target_id}/test/{test_type}")
|
@router.post("/{tracker_target_id}/test/{test_type}")
|
||||||
|
|||||||
@@ -11,13 +11,18 @@ from ..auth.dependencies import get_current_user
|
|||||||
from ..database.engine import get_session
|
from ..database.engine import get_session
|
||||||
from ..database.models import (
|
from ..database.models import (
|
||||||
EventLog,
|
EventLog,
|
||||||
|
NotificationTarget,
|
||||||
NotificationTracker,
|
NotificationTracker,
|
||||||
NotificationTrackerState,
|
NotificationTrackerState,
|
||||||
NotificationTrackerTarget,
|
NotificationTrackerTarget,
|
||||||
ServiceProvider,
|
ServiceProvider,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from ..services.scheduler import schedule_tracker, unschedule_tracker
|
from ..services.scheduler import (
|
||||||
|
reschedule_immich_dispatch_jobs,
|
||||||
|
schedule_tracker,
|
||||||
|
unschedule_tracker,
|
||||||
|
)
|
||||||
from .helpers import get_owned_entity
|
from .helpers import get_owned_entity
|
||||||
from .notification_tracker_targets import _tt_response
|
from .notification_tracker_targets import _tt_response
|
||||||
|
|
||||||
@@ -54,11 +59,79 @@ async def list_notification_trackers(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
|
# Batched loader: pull trackers, then all their tracker-target links in
|
||||||
|
# a single query, then the referenced targets in a single query. Avoids
|
||||||
|
# the old 1 + N + N*M pattern that ran ~60 round-trips for 10 trackers.
|
||||||
result = await session.exec(
|
result = await session.exec(
|
||||||
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
||||||
)
|
)
|
||||||
trackers = result.all()
|
trackers = list(result.all())
|
||||||
return [await _tracker_response(session, t) for t in trackers]
|
if not trackers:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tracker_ids = [t.id for t in trackers]
|
||||||
|
tt_result = await session.exec(
|
||||||
|
select(NotificationTrackerTarget).where(
|
||||||
|
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tt_rows = list(tt_result.all())
|
||||||
|
|
||||||
|
target_ids = {tt.target_id for tt in tt_rows}
|
||||||
|
targets_by_id: dict[int, NotificationTarget] = {}
|
||||||
|
if target_ids:
|
||||||
|
tgt_result = await session.exec(
|
||||||
|
select(NotificationTarget).where(NotificationTarget.id.in_(target_ids))
|
||||||
|
)
|
||||||
|
targets_by_id = {t.id: t for t in tgt_result.all()}
|
||||||
|
|
||||||
|
tts_by_tracker: dict[int, list[NotificationTrackerTarget]] = {}
|
||||||
|
for tt in tt_rows:
|
||||||
|
tts_by_tracker.setdefault(tt.tracker_id, []).append(tt)
|
||||||
|
|
||||||
|
return [
|
||||||
|
_build_tracker_response(t, tts_by_tracker.get(t.id, []), targets_by_id)
|
||||||
|
for t in trackers
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tracker_response(
|
||||||
|
t: NotificationTracker,
|
||||||
|
tts: list[NotificationTrackerTarget],
|
||||||
|
targets_by_id: dict[int, NotificationTarget],
|
||||||
|
) -> dict:
|
||||||
|
"""In-memory assembler for a tracker + its pre-loaded links/targets."""
|
||||||
|
tracker_targets = []
|
||||||
|
for tt in tts:
|
||||||
|
target = targets_by_id.get(tt.target_id)
|
||||||
|
tracker_targets.append({
|
||||||
|
"id": tt.id,
|
||||||
|
"tracker_id": tt.tracker_id,
|
||||||
|
"target_id": tt.target_id,
|
||||||
|
"target_name": target.name if target else None,
|
||||||
|
"target_type": target.type if target else None,
|
||||||
|
"target_icon": target.icon if target else None,
|
||||||
|
"tracking_config_id": tt.tracking_config_id,
|
||||||
|
"template_config_id": tt.template_config_id,
|
||||||
|
"enabled": tt.enabled,
|
||||||
|
"quiet_hours_start": tt.quiet_hours_start,
|
||||||
|
"quiet_hours_end": tt.quiet_hours_end,
|
||||||
|
"created_at": tt.created_at.isoformat(),
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"id": t.id,
|
||||||
|
"name": t.name,
|
||||||
|
"icon": t.icon,
|
||||||
|
"provider_id": t.provider_id,
|
||||||
|
"collection_ids": t.collection_ids,
|
||||||
|
"scan_interval": t.scan_interval,
|
||||||
|
"batch_duration": t.batch_duration,
|
||||||
|
"default_tracking_config_id": t.default_tracking_config_id,
|
||||||
|
"default_template_config_id": t.default_template_config_id,
|
||||||
|
"enabled": t.enabled,
|
||||||
|
"tracker_targets": tracker_targets,
|
||||||
|
"created_at": t.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||||
@@ -77,6 +150,7 @@ async def create_notification_tracker(
|
|||||||
await session.refresh(tracker)
|
await session.refresh(tracker)
|
||||||
if tracker.enabled:
|
if tracker.enabled:
|
||||||
await schedule_tracker(tracker.id, tracker.scan_interval)
|
await schedule_tracker(tracker.id, tracker.scan_interval)
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return await _tracker_response(session, tracker)
|
return await _tracker_response(session, tracker)
|
||||||
|
|
||||||
|
|
||||||
@@ -107,6 +181,7 @@ async def update_notification_tracker(
|
|||||||
await schedule_tracker(tracker.id, tracker.scan_interval)
|
await schedule_tracker(tracker.id, tracker.scan_interval)
|
||||||
else:
|
else:
|
||||||
await unschedule_tracker(tracker.id)
|
await unschedule_tracker(tracker.id)
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return await _tracker_response(session, tracker)
|
return await _tracker_response(session, tracker)
|
||||||
|
|
||||||
|
|
||||||
@@ -139,6 +214,7 @@ async def delete_notification_tracker(
|
|||||||
await session.delete(tracker)
|
await session.delete(tracker)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await unschedule_tracker(tracker_id)
|
await unschedule_tracker(tracker_id)
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{tracker_id}/trigger")
|
@router.post("/{tracker_id}/trigger")
|
||||||
|
|||||||
@@ -306,16 +306,31 @@ async def update_provider(
|
|||||||
if body.icon is not None:
|
if body.icon is not None:
|
||||||
provider.icon = body.icon
|
provider.icon = body.icon
|
||||||
|
|
||||||
config_changed = body.config is not None and body.config != provider.config
|
|
||||||
if body.config is not None:
|
if body.config is not None:
|
||||||
_validate_provider_config(provider.type, body.config)
|
# Merge rather than replace so the masked secrets the frontend
|
||||||
provider.config = body.config
|
# receives on GET cannot silently nuke the stored values when the
|
||||||
|
# user saves without re-entering them. Any field that still carries
|
||||||
|
# our mask placeholder ("***…") is dropped from the incoming body.
|
||||||
|
incoming = dict(body.config)
|
||||||
|
for secret_field in (
|
||||||
|
"api_key", "api_token", "webhook_secret", "password",
|
||||||
|
"client_secret", "refresh_token",
|
||||||
|
):
|
||||||
|
value = incoming.get(secret_field)
|
||||||
|
if isinstance(value, str) and value.startswith("***"):
|
||||||
|
incoming.pop(secret_field, None)
|
||||||
|
new_config = {**provider.config, **incoming}
|
||||||
|
_validate_provider_config(provider.type, new_config)
|
||||||
|
config_changed = new_config != provider.config
|
||||||
|
provider.config = new_config
|
||||||
|
|
||||||
# Re-validate connection when config changes for known provider types
|
|
||||||
if config_changed:
|
if config_changed:
|
||||||
test_result = await _validate_provider_connection(provider)
|
test_result = await _validate_provider_connection(provider)
|
||||||
if test_result.get("external_domain"):
|
if test_result.get("external_domain"):
|
||||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
provider.config = {
|
||||||
|
**provider.config,
|
||||||
|
"external_domain": test_result["external_domain"],
|
||||||
|
}
|
||||||
|
|
||||||
session.add(provider)
|
session.add(provider)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -242,6 +242,8 @@ async def get_template_variables(
|
|||||||
"current_date": "Current date (formatted)",
|
"current_date": "Current date (formatted)",
|
||||||
"current_time": "Current time (formatted)",
|
"current_time": "Current time (formatted)",
|
||||||
"current_datetime": "Current date and time (formatted)",
|
"current_datetime": "Current date and time (formatted)",
|
||||||
|
"weekday": "Day of the week (Monday..Sunday)",
|
||||||
|
"timezone": "IANA timezone used for current_date/time",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
|||||||
from ..auth.dependencies import get_current_user
|
from ..auth.dependencies import get_current_user
|
||||||
from ..database.engine import get_session
|
from ..database.engine import get_session
|
||||||
from ..database.models import TrackingConfig, User
|
from ..database.models import TrackingConfig, User
|
||||||
|
from ..services.scheduler import reschedule_immich_dispatch_jobs
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -127,6 +128,8 @@ async def create_config(
|
|||||||
session.add(config)
|
session.add(config)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(config)
|
await session.refresh(config)
|
||||||
|
if config.provider_type == "immich":
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return _response(config)
|
return _response(config)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,6 +155,8 @@ async def update_config(
|
|||||||
session.add(config)
|
session.add(config)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
await session.refresh(config)
|
await session.refresh(config)
|
||||||
|
if config.provider_type == "immich":
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
return _response(config)
|
return _response(config)
|
||||||
|
|
||||||
|
|
||||||
@@ -164,8 +169,11 @@ async def delete_config(
|
|||||||
from .delete_protection import check_tracking_config, raise_if_used
|
from .delete_protection import check_tracking_config, raise_if_used
|
||||||
config = await _get(session, config_id, user.id)
|
config = await _get(session, config_id, user.id)
|
||||||
raise_if_used(await check_tracking_config(session, config.id), config.name)
|
raise_if_used(await check_tracking_config(session, config.id), config.name)
|
||||||
|
provider_type = config.provider_type
|
||||||
await session.delete(config)
|
await session.delete(config)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
if provider_type == "immich":
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
|
|
||||||
|
|
||||||
def _response(c: TrackingConfig) -> dict:
|
def _response(c: TrackingConfig) -> dict:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""User management API routes (admin only)."""
|
"""User management API routes (admin only)."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
@@ -14,6 +15,15 @@ from ..auth.dependencies import require_admin
|
|||||||
from ..database.engine import get_session
|
from ..database.engine import get_session
|
||||||
from ..database.models import User
|
from ..database.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def _hash_password(password: str) -> str:
|
||||||
|
"""Run bcrypt off the event loop. Matches the helper in auth/routes.py."""
|
||||||
|
|
||||||
|
def _work() -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_work)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/users", tags=["users"])
|
router = APIRouter(prefix="/api/users", tags=["users"])
|
||||||
@@ -36,8 +46,12 @@ async def list_users(
|
|||||||
admin: User = Depends(require_admin),
|
admin: User = Depends(require_admin),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""List all users (admin only)."""
|
"""List all users (admin only).
|
||||||
result = await session.exec(select(User))
|
|
||||||
|
Excludes the internal ``__system__`` placeholder (id=0) used as the
|
||||||
|
owner of default templates/configs — it is never a real account.
|
||||||
|
"""
|
||||||
|
result = await session.exec(select(User).where(User.id != 0))
|
||||||
return [
|
return [
|
||||||
{"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()}
|
{"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()}
|
||||||
for u in result.all()
|
for u in result.all()
|
||||||
@@ -61,7 +75,7 @@ async def create_user(
|
|||||||
|
|
||||||
user = User(
|
user = User(
|
||||||
username=body.username,
|
username=body.username,
|
||||||
hashed_password=bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode(),
|
hashed_password=await _hash_password(body.password),
|
||||||
role=body.role if body.role in ("admin", "user") else "user",
|
role=body.role if body.role in ("admin", "user") else "user",
|
||||||
)
|
)
|
||||||
session.add(user)
|
session.add(user)
|
||||||
@@ -162,7 +176,7 @@ async def reset_user_password(
|
|||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
if len(body.new_password) < 8:
|
if len(body.new_password) < 8:
|
||||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||||
user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
|
user.hashed_password = await _hash_password(body.new_password)
|
||||||
# Invalidate all prior JWTs issued for this user — matches the self-serve
|
# Invalidate all prior JWTs issued for this user — matches the self-serve
|
||||||
# password-change path in auth/routes.py.
|
# password-change path in auth/routes.py.
|
||||||
user.token_version = (user.token_version or 1) + 1
|
user.token_version = (user.token_version or 1) + 1
|
||||||
|
|||||||
@@ -37,6 +37,42 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
|
router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
|
||||||
|
|
||||||
|
# Hard cap on inbound webhook body size (1 MiB is far larger than anything
|
||||||
|
# legitimate providers send and keeps the worst-case memory footprint bounded
|
||||||
|
# when a malicious peer lies about Content-Length or streams slowly).
|
||||||
|
_MAX_WEBHOOK_BODY_BYTES = 1_000_000
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_bounded_body(request: Request, limit: int = _MAX_WEBHOOK_BODY_BYTES) -> bytes:
|
||||||
|
"""Reject oversized inbound bodies before they exhaust memory.
|
||||||
|
|
||||||
|
First checks ``Content-Length`` (fast-path for honest peers), then
|
||||||
|
streams the body in chunks enforcing the same cap on actual bytes
|
||||||
|
received so a peer that lies about Content-Length cannot slip through.
|
||||||
|
"""
|
||||||
|
declared = request.headers.get("content-length")
|
||||||
|
if declared:
|
||||||
|
try:
|
||||||
|
if int(declared) > limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail=f"Payload too large (max {limit} bytes)",
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid Content-Length")
|
||||||
|
|
||||||
|
chunks: list[bytes] = []
|
||||||
|
size = 0
|
||||||
|
async for chunk in request.stream():
|
||||||
|
size += len(chunk)
|
||||||
|
if size > limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=413,
|
||||||
|
detail=f"Payload too large (max {limit} bytes)",
|
||||||
|
)
|
||||||
|
chunks.append(chunk)
|
||||||
|
return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
async def _get_provider_by_token(
|
async def _get_provider_by_token(
|
||||||
session: AsyncSession, token: str, expected_type: str,
|
session: AsyncSession, token: str, expected_type: str,
|
||||||
@@ -169,7 +205,8 @@ async def _dispatch_webhook_event(
|
|||||||
))
|
))
|
||||||
|
|
||||||
# Dispatch to targets
|
# Dispatch to targets
|
||||||
dispatcher = NotificationDispatcher()
|
from ..services.http_session import get_http_session
|
||||||
|
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||||
target_configs = _build_target_configs(event, link_data, provider_config, app_tz)
|
target_configs = _build_target_configs(event, link_data, provider_config, app_tz)
|
||||||
if target_configs:
|
if target_configs:
|
||||||
results = await dispatcher.dispatch(event, target_configs)
|
results = await dispatcher.dispatch(event, target_configs)
|
||||||
@@ -203,7 +240,7 @@ async def gitea_webhook(token: str, request: Request):
|
|||||||
webhook_secret = (provider.config or {}).get("webhook_secret", "")
|
webhook_secret = (provider.config or {}).get("webhook_secret", "")
|
||||||
|
|
||||||
# Read raw body for HMAC check
|
# Read raw body for HMAC check
|
||||||
raw_body = await request.body()
|
raw_body = await _read_bounded_body(request)
|
||||||
|
|
||||||
if not webhook_secret:
|
if not webhook_secret:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -221,8 +258,8 @@ async def gitea_webhook(token: str, request: Request):
|
|||||||
return {"ok": True, "skipped": "no event header"}
|
return {"ok": True, "skipped": "no event header"}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = await request.json()
|
payload = json.loads(raw_body.decode("utf-8"))
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||||
|
|
||||||
event = parse_gitea_webhook(event_header, payload, provider.name)
|
event = parse_gitea_webhook(event_header, payload, provider.name)
|
||||||
@@ -280,10 +317,10 @@ async def planka_webhook(token: str, request: Request):
|
|||||||
if not _verify_planka_token(webhook_secret, request):
|
if not _verify_planka_token(webhook_secret, request):
|
||||||
raise HTTPException(status_code=403, detail="Invalid token")
|
raise HTTPException(status_code=403, detail="Invalid token")
|
||||||
|
|
||||||
# Parse payload
|
# Parse payload from the bounded raw_body we already read.
|
||||||
try:
|
try:
|
||||||
payload = await request.json()
|
payload = json.loads(raw_body.decode("utf-8"))
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||||
|
|
||||||
event_type = payload.get("type", "")
|
event_type = payload.get("type", "")
|
||||||
@@ -446,23 +483,22 @@ async def generic_webhook(token: str, request: Request):
|
|||||||
store_payloads = provider_config.get("store_payloads", True)
|
store_payloads = provider_config.get("store_payloads", True)
|
||||||
max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100)
|
max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100)
|
||||||
|
|
||||||
raw_body = await request.body()
|
raw_body = await _read_bounded_body(request)
|
||||||
|
|
||||||
# Enforce payload size limit BEFORE parsing JSON
|
# Bounded read above already enforces the size cap; no need to re-check.
|
||||||
if len(raw_body) > 1_000_000:
|
|
||||||
raise HTTPException(status_code=413, detail="Payload too large (max 1 MB)")
|
|
||||||
|
|
||||||
if not _verify_generic_webhook_auth(provider_config, request, raw_body):
|
if not _verify_generic_webhook_auth(provider_config, request, raw_body):
|
||||||
raise HTTPException(status_code=403, detail="Authentication failed")
|
raise HTTPException(status_code=403, detail="Authentication failed")
|
||||||
|
|
||||||
safe_headers = _filter_headers(dict(request.headers))
|
safe_headers = _filter_headers(dict(request.headers))
|
||||||
|
|
||||||
# Parse JSON payload
|
# Parse JSON payload from the already-bounded raw_body (request.body()
|
||||||
|
# has been consumed, so request.json() is no longer usable here).
|
||||||
try:
|
try:
|
||||||
payload = await request.json()
|
payload = json.loads(raw_body.decode("utf-8"))
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
raise ValueError("Payload must be a JSON object")
|
raise ValueError("Payload must be a JSON object")
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||||
if store_payloads:
|
if store_payloads:
|
||||||
async with AsyncSession(get_engine()) as log_session:
|
async with AsyncSession(get_engine()) as log_session:
|
||||||
await _save_webhook_log(
|
await _save_webhook_log(
|
||||||
|
|||||||
@@ -7,30 +7,51 @@ import jwt
|
|||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
|
||||||
ALGORITHM = "HS256"
|
ALGORITHM = "HS256"
|
||||||
|
_LEEWAY_SECONDS = 10
|
||||||
|
|
||||||
|
|
||||||
|
def _now_utc() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
|
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
|
now = _now_utc()
|
||||||
|
expire = now + timedelta(minutes=settings.access_token_expire_minutes)
|
||||||
payload = {
|
payload = {
|
||||||
|
"iss": settings.jwt_issuer,
|
||||||
|
"aud": settings.jwt_audience,
|
||||||
"sub": str(user_id),
|
"sub": str(user_id),
|
||||||
"role": role,
|
"role": role,
|
||||||
"type": "access",
|
"type": "access",
|
||||||
"ver": token_version,
|
"ver": token_version,
|
||||||
|
"iat": now,
|
||||||
"exp": expire,
|
"exp": expire,
|
||||||
}
|
}
|
||||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
|
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
|
||||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)
|
now = _now_utc()
|
||||||
|
expire = now + timedelta(days=settings.refresh_token_expire_days)
|
||||||
payload = {
|
payload = {
|
||||||
|
"iss": settings.jwt_issuer,
|
||||||
|
"aud": settings.jwt_audience,
|
||||||
"sub": str(user_id),
|
"sub": str(user_id),
|
||||||
"type": "refresh",
|
"type": "refresh",
|
||||||
"ver": token_version,
|
"ver": token_version,
|
||||||
|
"iat": now,
|
||||||
"exp": expire,
|
"exp": expire,
|
||||||
}
|
}
|
||||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
def decode_token(token: str) -> dict:
|
def decode_token(token: str) -> dict:
|
||||||
return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
|
return jwt.decode(
|
||||||
|
token,
|
||||||
|
settings.secret_key,
|
||||||
|
algorithms=[ALGORITHM],
|
||||||
|
audience=settings.jwt_audience,
|
||||||
|
issuer=settings.jwt_issuer,
|
||||||
|
leeway=_LEEWAY_SECONDS,
|
||||||
|
options={"require": ["exp", "sub", "iss", "aud", "type"]},
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Authentication API routes."""
|
"""Authentication API routes."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from slowapi import Limiter
|
from slowapi import Limiter
|
||||||
@@ -16,7 +18,9 @@ from .jwt import create_access_token, create_refresh_token, decode_token
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
# Default rate limit applied by SlowAPIMiddleware to every route that does NOT
|
||||||
|
# specify its own @limiter.limit(...) — protects against blanket abuse.
|
||||||
|
limiter = Limiter(key_func=get_remote_address, default_limits=["600/minute"])
|
||||||
|
|
||||||
|
|
||||||
class SetupRequest(BaseModel):
|
class SetupRequest(BaseModel):
|
||||||
@@ -45,27 +49,52 @@ class RefreshRequest(BaseModel):
|
|||||||
refresh_token: str
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
async def _hash_password(password: str) -> str:
|
||||||
|
"""bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop."""
|
||||||
|
|
||||||
|
def _work() -> str:
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_work)
|
||||||
|
|
||||||
def _verify_password(password: str, hashed: str) -> bool:
|
|
||||||
|
async def _verify_password(password: str, hashed: str) -> bool:
|
||||||
|
def _work() -> bool:
|
||||||
|
try:
|
||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
except ValueError:
|
||||||
|
# Malformed hash in DB — treat as mismatch, never raise to caller.
|
||||||
|
return False
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_work)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/setup", response_model=TokenResponse)
|
@router.post("/setup", response_model=TokenResponse)
|
||||||
@limiter.limit("3/minute")
|
@limiter.limit("3/minute")
|
||||||
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
|
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
|
||||||
result = await session.exec(select(func.count()).select_from(User))
|
|
||||||
count = result.one()
|
|
||||||
if count > 0:
|
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.")
|
|
||||||
|
|
||||||
if len(body.password) < 8:
|
if len(body.password) < 8:
|
||||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||||
user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin")
|
# Compute hash BEFORE opening the transaction so we don't hold a writer lock
|
||||||
|
# during the CPU-bound bcrypt work.
|
||||||
|
hashed = await _hash_password(body.password)
|
||||||
|
|
||||||
|
# Serialize setup via an INSERT-inside-transaction-with-count-guard.
|
||||||
|
# SQLite's writer lock plus the count check inside the transaction closes
|
||||||
|
# the TOCTOU window between two concurrent POSTs. We ignore id=0 — that's
|
||||||
|
# the internal "__system__" placeholder used for ownership of default
|
||||||
|
# templates, never a real admin.
|
||||||
|
async with session.begin():
|
||||||
|
result = await session.exec(
|
||||||
|
select(func.count()).select_from(User).where(User.id != 0)
|
||||||
|
)
|
||||||
|
count = result.one()
|
||||||
|
if count > 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Setup already completed.",
|
||||||
|
)
|
||||||
|
user = User(username=body.username, hashed_password=hashed, role="admin")
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.commit()
|
|
||||||
await session.refresh(user)
|
await session.refresh(user)
|
||||||
|
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
@@ -79,7 +108,13 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
|
|||||||
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
|
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
|
||||||
result = await session.exec(select(User).where(User.username == body.username))
|
result = await session.exec(select(User).where(User.username == body.username))
|
||||||
user = result.first()
|
user = result.first()
|
||||||
if not user or not _verify_password(body.password, user.hashed_password):
|
# Always run a bcrypt verification to keep the response time constant,
|
||||||
|
# preventing username-enumeration via timing side channel.
|
||||||
|
password_ok = await _verify_password(
|
||||||
|
body.password,
|
||||||
|
user.hashed_password if user else "$2b$12$" + "a" * 53,
|
||||||
|
)
|
||||||
|
if not user or not password_ok:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
||||||
|
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
@@ -124,16 +159,18 @@ class PasswordChangeRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@router.put("/password")
|
@router.put("/password")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
async def change_password(
|
async def change_password(
|
||||||
|
request: Request,
|
||||||
body: PasswordChangeRequest,
|
body: PasswordChangeRequest,
|
||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
):
|
):
|
||||||
if not _verify_password(body.current_password, user.hashed_password):
|
if not await _verify_password(body.current_password, user.hashed_password):
|
||||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||||
if len(body.new_password) < 8:
|
if len(body.new_password) < 8:
|
||||||
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
||||||
user.hashed_password = _hash_password(body.new_password)
|
user.hashed_password = await _hash_password(body.new_password)
|
||||||
user.token_version = (user.token_version or 1) + 1
|
user.token_version = (user.token_version or 1) + 1
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -141,7 +178,12 @@ async def change_password(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/needs-setup")
|
@router.get("/needs-setup")
|
||||||
async def needs_setup(session: AsyncSession = Depends(get_session)):
|
@limiter.limit("30/minute")
|
||||||
result = await session.exec(select(func.count()).select_from(User))
|
async def needs_setup(request: Request, session: AsyncSession = Depends(get_session)):
|
||||||
|
# Exclude the internal __system__ placeholder (id=0) from the count so
|
||||||
|
# a fresh install still reports needs_setup=True.
|
||||||
|
result = await session.exec(
|
||||||
|
select(func.count()).select_from(User).where(User.id != 0)
|
||||||
|
)
|
||||||
count = result.one()
|
count = result.one()
|
||||||
return {"needs_setup": count == 0}
|
return {"needs_setup": count == 0}
|
||||||
|
|||||||
@@ -108,13 +108,18 @@ def _render_cmd_template(
|
|||||||
"""Render a locale-aware command template. Falls back to 'en'."""
|
"""Render a locale-aware command template. Falls back to 'en'."""
|
||||||
template_str = _resolve_template(templates, slot_name, locale)
|
template_str = _resolve_template(templates, slot_name, locale)
|
||||||
if not template_str:
|
if not template_str:
|
||||||
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
# Missing template = user sees "[No template: X]" — this is an ERROR,
|
||||||
|
# not a warning. Broken replies must stand out in production logs.
|
||||||
|
_LOGGER.error("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||||
return f"[No template: {slot_name}]"
|
return f"[No template: {slot_name}]"
|
||||||
try:
|
try:
|
||||||
tmpl = _compile_template(template_str)
|
tmpl = _compile_template(template_str)
|
||||||
return tmpl.render(**context)
|
return tmpl.render(**context)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
|
_LOGGER.error(
|
||||||
|
"Failed to render command template '%s' locale=%s — user will see a broken reply",
|
||||||
|
slot_name, locale, exc_info=True,
|
||||||
|
)
|
||||||
return f"[Template error: {slot_name}]"
|
return f"[Template error: {slot_name}]"
|
||||||
|
|
||||||
|
|
||||||
@@ -296,6 +301,10 @@ async def handle_command(
|
|||||||
# Rate limit check (once per command, shared across all trackers)
|
# Rate limit check (once per command, shared across all trackers)
|
||||||
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
||||||
if wait is not None:
|
if wait is not None:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Rate-limited /%s for bot=%d chat=%s — %ds cooldown remaining",
|
||||||
|
cmd, bot.id, chat_id, wait,
|
||||||
|
)
|
||||||
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
|
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
|
||||||
return [CommandResponse(text=text_resp)]
|
return [CommandResponse(text=text_resp)]
|
||||||
|
|
||||||
@@ -322,8 +331,8 @@ async def handle_command(
|
|||||||
for tracker, config, provider, listener in ctx_tuples:
|
for tracker, config, provider, listener in ctx_tuples:
|
||||||
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
|
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Truncated command responses at %d for bot %d cmd /%s",
|
"Truncated command responses at %d for bot=%d chat=%s cmd=/%s (listener context size=%d)",
|
||||||
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
|
_MAX_RESPONSES_PER_COMMAND, bot.id, chat_id, cmd, len(ctx_tuples),
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -418,7 +427,12 @@ async def send_reply(
|
|||||||
disable_web_page_preview=True,
|
disable_web_page_preview=True,
|
||||||
)
|
)
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
|
# User-visible failure: the bot's reply never reached the chat.
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram reply failed (chat=%s reply_to=%s len=%d): code=%s error=%r",
|
||||||
|
chat_id, reply_to_message_id, len(text or ""),
|
||||||
|
result.get("error_code"), result.get("error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def send_media_group(
|
async def send_media_group(
|
||||||
@@ -442,6 +456,14 @@ async def send_media_group(
|
|||||||
assets hit the cache and skip the re-upload.
|
assets hit the cache and skip the re-upload.
|
||||||
"""
|
"""
|
||||||
if not media_items:
|
if not media_items:
|
||||||
|
# This is what happened in the /random blind spot: the text reply
|
||||||
|
# was sent, but the media follow-up was silently skipped because
|
||||||
|
# the caller passed an empty media list. Surface it so we can see
|
||||||
|
# it in the log and correlate with the text message.
|
||||||
|
_LOGGER.warning(
|
||||||
|
"send_media_group called with 0 items (chat=%s reply_to=%s) — no media will be delivered",
|
||||||
|
chat_id, reply_to_message_id,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
from ..services.telegram_send import send_telegram_media
|
from ..services.telegram_send import send_telegram_media
|
||||||
@@ -452,7 +474,13 @@ async def send_media_group(
|
|||||||
chat_action=None,
|
chat_action=None,
|
||||||
)
|
)
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
|
# User-visible failure: media promised by the text reply never arrived.
|
||||||
|
_LOGGER.error(
|
||||||
|
"Telegram media group failed (chat=%s items=%d reply_to=%s): code=%s error=%r failed_at_chunk=%s",
|
||||||
|
chat_id, len(media_items), reply_to_message_id,
|
||||||
|
result.get("error_code"), result.get("error"),
|
||||||
|
result.get("failed_at_chunk"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||||
|
|||||||
@@ -144,6 +144,7 @@ def _format_assets(
|
|||||||
# other's cached file_ids (which is what made the cache look empty
|
# other's cached file_ids (which is what made the cache look empty
|
||||||
# from the WebUI after running /random).
|
# from the WebUI after running /random).
|
||||||
media_items: list[dict[str, Any]] = []
|
media_items: list[dict[str, Any]] = []
|
||||||
|
dropped = 0
|
||||||
for asset in assets:
|
for asset in assets:
|
||||||
asset_id = asset.get("id", "")
|
asset_id = asset.get("id", "")
|
||||||
asset_type = (asset.get("type") or "").upper()
|
asset_type = (asset.get("type") or "").upper()
|
||||||
@@ -156,6 +157,20 @@ def _format_assets(
|
|||||||
)
|
)
|
||||||
if entry is not None:
|
if entry is not None:
|
||||||
media_items.append(entry)
|
media_items.append(entry)
|
||||||
|
else:
|
||||||
|
dropped += 1
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Dropped asset from /%s media payload: id=%s type=%s (empty preview URL)",
|
||||||
|
cmd, asset_id, asset_type,
|
||||||
|
)
|
||||||
|
if not media_items and assets:
|
||||||
|
# All assets were filtered out before reaching Telegram. The user
|
||||||
|
# will see the text reply but no media — surface it here so the
|
||||||
|
# log shows WHY the media group ended up empty.
|
||||||
|
_LOGGER.warning(
|
||||||
|
"/%s media payload empty: %d asset(s) in, 0 out (all dropped)",
|
||||||
|
cmd, len(assets),
|
||||||
|
)
|
||||||
# Return text message + media items — text is sent first, media as reply
|
# Return text message + media items — text is sent first, media as reply
|
||||||
return {"text": text, "media": media_items}
|
return {"text": text, "media": media_items}
|
||||||
|
|
||||||
|
|||||||
@@ -143,7 +143,16 @@ async def _cmd_immich(
|
|||||||
# chat). ``None`` = no filter (rare); empty set = show nothing (common
|
# chat). ``None`` = no filter (rare); empty set = show nothing (common
|
||||||
# when the chat has no tracker routing).
|
# when the chat has no tracker routing).
|
||||||
if allowed_album_ids is not None:
|
if allowed_album_ids is not None:
|
||||||
|
before = len(all_album_ids)
|
||||||
all_album_ids = [aid for aid in all_album_ids if aid in allowed_album_ids]
|
all_album_ids = [aid for aid in all_album_ids if aid in allowed_album_ids]
|
||||||
|
if not all_album_ids:
|
||||||
|
# A command that sees zero albums is a routing/tracker config issue
|
||||||
|
# the operator needs to notice — otherwise the user gets
|
||||||
|
# "no results" with no hint at why.
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command /%s has empty album scope for provider=%d (had %d trackers, chat scope allowed %d)",
|
||||||
|
cmd, provider.id, before, len(allowed_album_ids),
|
||||||
|
)
|
||||||
|
|
||||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from notify_bridge_core.log_context import bind_log_context
|
||||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||||
|
|
||||||
from ..database.engine import get_session
|
from ..database.engine import get_session
|
||||||
@@ -18,6 +20,7 @@ from ..services.telegram import save_chat_from_webhook
|
|||||||
from ..services.telegram_send import telegram_chat_action
|
from ..services.telegram_send import telegram_chat_action
|
||||||
from .base import CommandResponse
|
from .base import CommandResponse
|
||||||
from .handler import classify_command_chat_action, handle_command, send_media_group, send_reply
|
from .handler import classify_command_chat_action, handle_command, send_media_group, send_reply
|
||||||
|
from .parser import parse_command
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -93,20 +96,62 @@ async def telegram_webhook(
|
|||||||
)
|
)
|
||||||
)).first()
|
)).first()
|
||||||
if not chat_row or not chat_row.commands_enabled:
|
if not chat_row or not chat_row.commands_enabled:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command ignored — commands disabled for bot=%s chat=%s text=%r",
|
||||||
|
bot_id, chat_id, text[:64],
|
||||||
|
)
|
||||||
return {"ok": True, "skipped": "commands_disabled"}
|
return {"ok": True, "skipped": "commands_disabled"}
|
||||||
effective_lang = chat_row.language_override or msg_language
|
effective_lang = chat_row.language_override or msg_language
|
||||||
message_id = message.get("message_id")
|
message_id = message.get("message_id")
|
||||||
|
|
||||||
|
cmd_name, _, _ = parse_command(text)
|
||||||
|
update_id = update.get("update_id")
|
||||||
|
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||||
|
|
||||||
|
with bind_log_context(
|
||||||
|
request_id=request_id,
|
||||||
|
command=cmd_name or "-",
|
||||||
|
chat_id=chat_id,
|
||||||
|
bot_id=bot_id,
|
||||||
|
):
|
||||||
|
started = time.monotonic()
|
||||||
|
_LOGGER.info("Command received: /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||||
|
try:
|
||||||
async with telegram_chat_action(
|
async with telegram_chat_action(
|
||||||
bot_token, chat_id, classify_command_chat_action(text),
|
bot_token, chat_id, classify_command_chat_action(text),
|
||||||
):
|
):
|
||||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||||
if responses:
|
if not responses:
|
||||||
for resp in responses:
|
_LOGGER.info(
|
||||||
|
"Command produced no response (cmd=%r) after %.0f ms",
|
||||||
|
cmd_name, (time.monotonic() - started) * 1000,
|
||||||
|
)
|
||||||
|
return {"ok": True, "skipped": "no_response"}
|
||||||
|
text_count = sum(1 for r in responses if r.text)
|
||||||
|
media_count = sum(len(r.media or []) for r in responses)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||||
|
len(responses), text_count, media_count,
|
||||||
|
)
|
||||||
|
for idx, resp in enumerate(responses):
|
||||||
if resp.text:
|
if resp.text:
|
||||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||||
if resp.media:
|
if resp.media:
|
||||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||||
|
cmd_name, (time.monotonic() - started) * 1000,
|
||||||
|
len(responses), media_count,
|
||||||
|
)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.exception(
|
||||||
|
"Command /%s raised after %.0f ms",
|
||||||
|
cmd_name, (time.monotonic() - started) * 1000,
|
||||||
|
)
|
||||||
|
# Return 200 so Telegram doesn't retry the same update — we
|
||||||
|
# already logged the failure and can't usefully reprocess.
|
||||||
|
return {"ok": True, "error": "handler_exception"}
|
||||||
|
|
||||||
return {"ok": True, "skipped": "not_a_command"}
|
return {"ok": True, "skipped": "not_a_command"}
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,20 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
# Secret keys we will actively refuse. These cover the default template value
|
||||||
|
# and dev-only literals that have appeared in scripts or documentation.
|
||||||
|
_FORBIDDEN_SECRETS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"change-me-in-production",
|
||||||
|
"test-secret-key-minimum-32-chars",
|
||||||
|
"dev-secret-key-not-for-production",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""Application settings loaded from environment variables."""
|
"""Application settings loaded from environment variables."""
|
||||||
@@ -13,29 +25,25 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
secret_key: str = "change-me-in-production"
|
secret_key: str = "change-me-in-production"
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
access_token_expire_minutes: int = 15
|
||||||
if self.secret_key == "change-me-in-production":
|
|
||||||
raise ValueError(
|
|
||||||
"SECURITY: Refusing to start with the default secret_key. "
|
|
||||||
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
|
||||||
"before starting the server (debug mode included)."
|
|
||||||
)
|
|
||||||
if len(self.secret_key) < 32:
|
|
||||||
raise ValueError(
|
|
||||||
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
|
||||||
)
|
|
||||||
if "*" in self.cors_allowed_origins.split(","):
|
|
||||||
raise ValueError(
|
|
||||||
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token_expire_minutes: int = 60
|
|
||||||
refresh_token_expire_days: int = 30
|
refresh_token_expire_days: int = 30
|
||||||
|
|
||||||
|
jwt_issuer: str = "notify-bridge"
|
||||||
|
jwt_audience: str = "notify-bridge-api"
|
||||||
|
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 8420
|
port: int = 8420
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
|
# Comma-separated list of trusted proxy IPs uvicorn will honor for
|
||||||
|
# X-Forwarded-For / X-Forwarded-Proto. Use "*" ONLY when you trust the
|
||||||
|
# network (never directly on the internet). Default matches uvicorn.
|
||||||
|
forwarded_allow_ips: str = "127.0.0.1"
|
||||||
|
|
||||||
|
# How long to wait for in-flight requests / scheduler jobs before force
|
||||||
|
# killing on SIGTERM.
|
||||||
|
graceful_shutdown_seconds: int = 60
|
||||||
|
|
||||||
anthropic_api_key: str = ""
|
anthropic_api_key: str = ""
|
||||||
ai_model: str = "claude-sonnet-4-20250514"
|
ai_model: str = "claude-sonnet-4-20250514"
|
||||||
ai_max_tokens: int = 1024
|
ai_max_tokens: int = 1024
|
||||||
@@ -48,8 +56,61 @@ class Settings(BaseSettings):
|
|||||||
static_dir: str = ""
|
static_dir: str = ""
|
||||||
"""Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker)."""
|
"""Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker)."""
|
||||||
|
|
||||||
|
# --- Logging ---
|
||||||
|
log_level: str = "INFO"
|
||||||
|
"""Root log level for the app loggers (``DEBUG``/``INFO``/``WARNING``/``ERROR``)."""
|
||||||
|
|
||||||
|
log_format: str = "text"
|
||||||
|
"""Log output format: ``text`` (human-readable) or ``json`` (one object per line)."""
|
||||||
|
|
||||||
|
log_levels: str = ""
|
||||||
|
"""Comma-separated per-module overrides, e.g. ``notify_bridge_core.notifications.telegram.client=DEBUG,sqlalchemy.engine=INFO``."""
|
||||||
|
|
||||||
|
# --- Retention ---
|
||||||
|
event_log_retention_days: int = 30
|
||||||
|
"""Days of event_log history to retain. 0 disables the retention job."""
|
||||||
|
|
||||||
|
pre_migrate_snapshot_keep: int = 5
|
||||||
|
"""Number of pre-migration DB snapshots to keep in ``data_dir/backups/``.
|
||||||
|
0 disables snapshotting entirely. Each snapshot is produced at boot
|
||||||
|
before migrations run using SQLite's ``VACUUM INTO`` (atomic, consistent).
|
||||||
|
"""
|
||||||
|
|
||||||
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
|
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
|
||||||
|
|
||||||
|
def model_post_init(self, __context: Any) -> None:
|
||||||
|
if self.secret_key in _FORBIDDEN_SECRETS:
|
||||||
|
raise ValueError(
|
||||||
|
"SECURITY: Refusing to start with a known/default secret_key. "
|
||||||
|
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
||||||
|
"before starting the server."
|
||||||
|
)
|
||||||
|
if len(self.secret_key) < 32:
|
||||||
|
raise ValueError(
|
||||||
|
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
||||||
|
)
|
||||||
|
origins = [o.strip() for o in self.cors_allowed_origins.split(",") if o.strip()]
|
||||||
|
if "*" in origins:
|
||||||
|
raise ValueError(
|
||||||
|
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
||||||
|
)
|
||||||
|
for origin in origins:
|
||||||
|
parsed = urlparse(origin)
|
||||||
|
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||||
|
raise ValueError(
|
||||||
|
f"CORS origin {origin!r} is invalid — must include scheme (http/https) and host."
|
||||||
|
)
|
||||||
|
if self.access_token_expire_minutes <= 0:
|
||||||
|
raise ValueError("access_token_expire_minutes must be > 0")
|
||||||
|
if self.refresh_token_expire_days <= 0:
|
||||||
|
raise ValueError("refresh_token_expire_days must be > 0")
|
||||||
|
if not (1 <= self.port <= 65535):
|
||||||
|
raise ValueError("port must be in range 1..65535")
|
||||||
|
if self.event_log_retention_days < 0:
|
||||||
|
raise ValueError("event_log_retention_days must be >= 0")
|
||||||
|
if self.pre_migrate_snapshot_keep < 0:
|
||||||
|
raise ValueError("pre_migrate_snapshot_keep must be >= 0")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def effective_database_url(self) -> str:
|
def effective_database_url(self) -> str:
|
||||||
if self.database_url:
|
if self.database_url:
|
||||||
|
|||||||
@@ -1,23 +1,59 @@
|
|||||||
"""Database engine and session management."""
|
"""Database engine and session management."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from sqlalchemy import event
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_engine: AsyncEngine | None = None
|
_engine: AsyncEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _install_sqlite_pragmas(engine: AsyncEngine) -> None:
|
||||||
|
"""Apply production-grade SQLite PRAGMAs on every new connection.
|
||||||
|
|
||||||
|
WAL mode lets readers and writers work concurrently without blocking;
|
||||||
|
``busy_timeout`` gives contending writers a chance instead of instant
|
||||||
|
SQLITE_BUSY; ``foreign_keys`` enforces the FK constraints declared in the
|
||||||
|
models (SQLite disables them by default); ``synchronous=NORMAL`` is a
|
||||||
|
safe-by-default durability trade-off that is standard in WAL mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
|
def _pragmas(dbapi_conn, _record): # pragma: no cover — driver hook
|
||||||
|
cur = dbapi_conn.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute("PRAGMA journal_mode=WAL")
|
||||||
|
cur.execute("PRAGMA synchronous=NORMAL")
|
||||||
|
cur.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cur.execute("PRAGMA busy_timeout=10000")
|
||||||
|
cur.execute("PRAGMA temp_store=MEMORY")
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
|
||||||
|
|
||||||
def get_engine() -> AsyncEngine:
|
def get_engine() -> AsyncEngine:
|
||||||
global _engine
|
global _engine
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
|
url = settings.effective_database_url
|
||||||
|
connect_args: dict = {}
|
||||||
|
if url.startswith("sqlite"):
|
||||||
|
connect_args["timeout"] = 30
|
||||||
_engine = create_async_engine(
|
_engine = create_async_engine(
|
||||||
settings.effective_database_url,
|
url,
|
||||||
echo=settings.debug,
|
echo=settings.debug,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
connect_args=connect_args,
|
||||||
)
|
)
|
||||||
|
if url.startswith("sqlite"):
|
||||||
|
_install_sqlite_pragmas(_engine)
|
||||||
|
_LOGGER.info("Database engine initialized: %s", url.split("://", 1)[0])
|
||||||
return _engine
|
return _engine
|
||||||
|
|
||||||
|
|
||||||
@@ -31,3 +67,11 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def dispose_engine() -> None:
|
||||||
|
"""Close the engine's connection pool. Call during graceful shutdown."""
|
||||||
|
global _engine
|
||||||
|
if _engine is not None:
|
||||||
|
await _engine.dispose()
|
||||||
|
_engine = None
|
||||||
|
|||||||
@@ -1282,3 +1282,141 @@ async def migrate_user_token_version(engine: AsyncEngine) -> None:
|
|||||||
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
|
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
|
||||||
)
|
)
|
||||||
logger.info("Added token_version column to user table")
|
logger.info("Added token_version column to user table")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Performance indexes — covers every FK / owner column the list endpoints
|
||||||
|
# and the webhook hot-path filter on. All use CREATE INDEX IF NOT EXISTS so
|
||||||
|
# they are safe to re-run on every boot.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_INDEXES: list[tuple[str, str, str]] = [
|
||||||
|
# (index_name, table, columns)
|
||||||
|
("ix_service_provider_user_id", "service_provider", "user_id"),
|
||||||
|
("ix_telegram_bot_user_id", "telegram_bot", "user_id"),
|
||||||
|
("ix_matrix_bot_user_id", "matrix_bot", "user_id"),
|
||||||
|
("ix_email_bot_user_id", "email_bot", "user_id"),
|
||||||
|
("ix_telegram_chat_bot_id", "telegram_chat", "bot_id"),
|
||||||
|
("ix_tracking_config_user_id", "tracking_config", "user_id"),
|
||||||
|
("ix_tracking_config_provider_type", "tracking_config", "provider_type"),
|
||||||
|
("ix_notification_target_user_id", "notification_target", "user_id"),
|
||||||
|
("ix_notification_target_type", "notification_target", "type"),
|
||||||
|
("ix_notification_tracker_user_id", "notification_tracker", "user_id"),
|
||||||
|
("ix_notification_tracker_provider_id", "notification_tracker", "provider_id"),
|
||||||
|
# Composite for the webhook hot path: WHERE provider_id = ? AND enabled = true
|
||||||
|
(
|
||||||
|
"ix_notification_tracker_provider_enabled",
|
||||||
|
"notification_tracker",
|
||||||
|
"provider_id, enabled",
|
||||||
|
),
|
||||||
|
("ix_command_config_user_id", "command_config", "user_id"),
|
||||||
|
("ix_command_template_config_user_id", "command_template_config", "user_id"),
|
||||||
|
("ix_command_tracker_user_id", "command_tracker", "user_id"),
|
||||||
|
("ix_command_tracker_provider_id", "command_tracker", "provider_id"),
|
||||||
|
("ix_action_user_id", "action", "user_id"),
|
||||||
|
("ix_action_provider_id", "action", "provider_id"),
|
||||||
|
# Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
|
||||||
|
("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
|
||||||
|
("ix_event_log_provider_id", "event_log", "provider_id"),
|
||||||
|
("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
|
||||||
|
("ix_event_log_action_id", "event_log", "action_id"),
|
||||||
|
# Webhook log hot path: WHERE provider_id = ? ORDER BY created_at DESC
|
||||||
|
(
|
||||||
|
"ix_webhook_payload_log_provider_created",
|
||||||
|
"webhook_payload_log",
|
||||||
|
"provider_id, created_at DESC",
|
||||||
|
),
|
||||||
|
# Notification tracker join tables
|
||||||
|
(
|
||||||
|
"ix_notification_tracker_target_notification_tracker_id",
|
||||||
|
"notification_tracker_target",
|
||||||
|
"notification_tracker_id",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ix_notification_tracker_target_target_id",
|
||||||
|
"notification_tracker_target",
|
||||||
|
"target_id",
|
||||||
|
),
|
||||||
|
("ix_target_receiver_target_id", "target_receiver", "target_id"),
|
||||||
|
("ix_template_slot_config_id", "template_slot", "config_id"),
|
||||||
|
("ix_command_template_slot_config_id", "command_template_slot", "config_id"),
|
||||||
|
("ix_action_rule_action_id", "action_rule", "action_id"),
|
||||||
|
("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_performance_indexes(engine: AsyncEngine) -> None:
|
||||||
|
"""Create missing performance indexes on hot query paths.
|
||||||
|
|
||||||
|
Every index is created with IF NOT EXISTS so the migration is safe to
|
||||||
|
replay on every boot. We only create the index when the table exists —
|
||||||
|
early boots before other migrations land would otherwise raise.
|
||||||
|
"""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
for name, table, columns in _INDEXES:
|
||||||
|
_assert_ident(name, "index")
|
||||||
|
_assert_ident(table, "table")
|
||||||
|
# Columns list is a trusted literal constructed above — never user input.
|
||||||
|
if not await _has_table(conn, table):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await conn.execute(
|
||||||
|
text(f"CREATE INDEX IF NOT EXISTS {name} ON {table} ({columns})")
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover — log and continue
|
||||||
|
logger.warning(
|
||||||
|
"Failed to create index %s on %s(%s)",
|
||||||
|
name, table, columns, exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema version tracking — lightweight alternative to Alembic while the
|
||||||
|
# hand-rolled idempotent migrations remain the source of truth. Gives
|
||||||
|
# operators a single-row answer to "what schema is this DB at" and lets
|
||||||
|
# future upgrades short-circuit migrations that already ran.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
CURRENT_SCHEMA_VERSION = 1
|
||||||
|
|
||||||
|
|
||||||
|
async def migrate_schema_version(engine: AsyncEngine) -> None:
|
||||||
|
"""Create schema_version table and bump it to CURRENT_SCHEMA_VERSION."""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"CREATE TABLE IF NOT EXISTS schema_version ("
|
||||||
|
" id INTEGER PRIMARY KEY CHECK (id = 1),"
|
||||||
|
" version INTEGER NOT NULL,"
|
||||||
|
" applied_at TEXT NOT NULL"
|
||||||
|
")"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = await conn.run_sync(
|
||||||
|
lambda sc: sc.execute(
|
||||||
|
text("SELECT version FROM schema_version WHERE id = 1")
|
||||||
|
).fetchone()
|
||||||
|
)
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
if row is None:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT INTO schema_version (id, version, applied_at) "
|
||||||
|
"VALUES (1, :v, :t)"
|
||||||
|
),
|
||||||
|
{"v": CURRENT_SCHEMA_VERSION, "t": now},
|
||||||
|
)
|
||||||
|
logger.info("Initialized schema_version at %d", CURRENT_SCHEMA_VERSION)
|
||||||
|
elif int(row[0]) < CURRENT_SCHEMA_VERSION:
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"UPDATE schema_version SET version = :v, applied_at = :t "
|
||||||
|
"WHERE id = 1"
|
||||||
|
),
|
||||||
|
{"v": CURRENT_SCHEMA_VERSION, "t": now},
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Bumped schema_version from %s to %d",
|
||||||
|
row[0], CURRENT_SCHEMA_VERSION,
|
||||||
|
)
|
||||||
|
|||||||
@@ -394,8 +394,37 @@ async def _seed_default_command_configs() -> None:
|
|||||||
# Public entry point
|
# Public entry point
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _ensure_system_user() -> None:
|
||||||
|
"""Ensure a User row with id=0 exists.
|
||||||
|
|
||||||
|
Historically the app used ``user_id=0`` as a sentinel for "system-owned"
|
||||||
|
defaults (tracking configs, templates, etc.). Now that we enable
|
||||||
|
``PRAGMA foreign_keys=ON`` at connect time, those inserts would fail
|
||||||
|
with ``FOREIGN KEY constraint failed`` unless a placeholder user row
|
||||||
|
with the matching id exists.
|
||||||
|
"""
|
||||||
|
engine = get_engine()
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# INSERT OR IGNORE so re-running seeds is cheap and idempotent.
|
||||||
|
await conn.execute(
|
||||||
|
text(
|
||||||
|
"INSERT OR IGNORE INTO user "
|
||||||
|
"(id, username, hashed_password, role, token_version, created_at) "
|
||||||
|
"VALUES (0, :u, :p, :r, 1, :t)"
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"u": "__system__",
|
||||||
|
# Invalid bcrypt hash — nobody can ever log in as this user.
|
||||||
|
"p": "!disabled!",
|
||||||
|
"r": "system",
|
||||||
|
"t": datetime.now(timezone.utc).isoformat(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def seed_all() -> None:
|
async def seed_all() -> None:
|
||||||
"""Run all seed functions in order."""
|
"""Run all seed functions in order."""
|
||||||
|
await _ensure_system_user()
|
||||||
await _seed_default_templates()
|
await _seed_default_templates()
|
||||||
await _seed_default_command_templates()
|
await _seed_default_command_templates()
|
||||||
await _seed_default_tracking_configs()
|
await _seed_default_tracking_configs()
|
||||||
|
|||||||
@@ -0,0 +1,155 @@
|
|||||||
|
"""Pre-migration database snapshots.
|
||||||
|
|
||||||
|
Runs at lifespan startup BEFORE migrations execute. Produces a consistent
|
||||||
|
point-in-time copy of the SQLite database using ``VACUUM INTO`` (atomic,
|
||||||
|
cannot tear against concurrent activity, works with WAL).
|
||||||
|
|
||||||
|
The snapshot is the operator's fallback if a future migration corrupts the
|
||||||
|
schema — restore is a single ``mv`` / ``docker cp``. We keep the N most
|
||||||
|
recent files (default 5) and never fail startup if the snapshot itself
|
||||||
|
fails: a snapshot is best-effort safety net, not a gate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SNAPSHOT_GLOB = "pre-migrate-*.db"
|
||||||
|
_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9._+\-:]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def _sqlite_path_from_url(url: str) -> Path | None:
|
||||||
|
"""Extract the filesystem path from a ``sqlite+aiosqlite:///...`` URL."""
|
||||||
|
if not url.startswith("sqlite"):
|
||||||
|
return None
|
||||||
|
# e.g. "sqlite+aiosqlite:///C:/data/notify_bridge.db"
|
||||||
|
prefix, _, rest = url.partition(":///")
|
||||||
|
if not rest:
|
||||||
|
return None
|
||||||
|
return Path(rest)
|
||||||
|
|
||||||
|
|
||||||
|
async def snapshot_database(
|
||||||
|
engine: AsyncEngine,
|
||||||
|
target_dir: Path,
|
||||||
|
*,
|
||||||
|
label: str = "pre-migrate",
|
||||||
|
) -> Path | None:
|
||||||
|
"""Write a consistent copy of the SQLite DB to ``target_dir``.
|
||||||
|
|
||||||
|
Uses ``VACUUM INTO`` which SQLite executes atomically against a read
|
||||||
|
snapshot — safe under WAL, cannot produce a torn copy. Returns the
|
||||||
|
snapshot path on success, ``None`` when skipped or on non-fatal
|
||||||
|
failure. Never raises: callers treat a missing snapshot as acceptable
|
||||||
|
(the main DB remains the source of truth).
|
||||||
|
"""
|
||||||
|
if not _SNAPSHOT_NAME_RE.match(label):
|
||||||
|
_LOGGER.warning("Snapshot label %r contains unsafe characters; skipping", label)
|
||||||
|
return None
|
||||||
|
|
||||||
|
url = str(engine.url)
|
||||||
|
src = _sqlite_path_from_url(url)
|
||||||
|
if src is None:
|
||||||
|
_LOGGER.debug("Non-SQLite engine; skipping snapshot")
|
||||||
|
return None
|
||||||
|
if not src.exists():
|
||||||
|
_LOGGER.debug("DB file %s does not exist yet (fresh install); skipping snapshot", src)
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
||||||
|
dest = target_dir / f"{label}-{ts}.db"
|
||||||
|
|
||||||
|
# VACUUM INTO accepts a string literal, not a bind parameter. The dest
|
||||||
|
# path is built from our own label + timestamp (never user input), so
|
||||||
|
# escaping is straightforward — still, reject any dest containing a
|
||||||
|
# single quote as a belt-and-braces check.
|
||||||
|
dest_str = str(dest)
|
||||||
|
if "'" in dest_str:
|
||||||
|
_LOGGER.warning("Refusing to snapshot to path containing a single quote: %s", dest_str)
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with engine.connect() as conn:
|
||||||
|
# VACUUM cannot run inside an explicit transaction; use the
|
||||||
|
# plain connection without begin().
|
||||||
|
await conn.execute(text(f"VACUUM INTO '{dest_str}'"))
|
||||||
|
_LOGGER.info("Database snapshot written: %s (%.1f KiB)", dest, dest.stat().st_size / 1024)
|
||||||
|
return dest
|
||||||
|
except Exception:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Pre-migration snapshot failed — continuing with startup. "
|
||||||
|
"Check disk space in %s.",
|
||||||
|
target_dir,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Partial file can linger if VACUUM INTO aborted mid-write; clean up.
|
||||||
|
try:
|
||||||
|
if dest.exists():
|
||||||
|
dest.unlink()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def prune_old_snapshots(target_dir: Path, keep: int) -> list[Path]:
|
||||||
|
"""Keep the ``keep`` most recent pre-migrate snapshots, delete the rest.
|
||||||
|
|
||||||
|
Returns the list of paths that were deleted. Safe to call with
|
||||||
|
``keep=0`` (deletes everything) or when the directory does not exist.
|
||||||
|
"""
|
||||||
|
if keep < 0:
|
||||||
|
raise ValueError("keep must be >= 0")
|
||||||
|
if not target_dir.is_dir():
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
snapshots = sorted(
|
||||||
|
target_dir.glob(_SNAPSHOT_GLOB),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
deleted: list[Path] = []
|
||||||
|
for old in snapshots[keep:]:
|
||||||
|
try:
|
||||||
|
old.unlink()
|
||||||
|
deleted.append(old)
|
||||||
|
except OSError:
|
||||||
|
_LOGGER.debug("Could not delete old snapshot %s", old, exc_info=True)
|
||||||
|
if deleted:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Pruned %d old pre-migrate snapshot(s); kept %d most recent",
|
||||||
|
len(deleted), min(keep, len(snapshots)),
|
||||||
|
)
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
|
||||||
|
async def snapshot_and_prune(
|
||||||
|
engine: AsyncEngine,
|
||||||
|
target_dir: Path,
|
||||||
|
*,
|
||||||
|
keep: int,
|
||||||
|
) -> Path | None:
|
||||||
|
"""Take a snapshot and prune old ones. Used by the lifespan startup path.
|
||||||
|
|
||||||
|
``keep=0`` disables snapshotting entirely.
|
||||||
|
"""
|
||||||
|
if keep <= 0:
|
||||||
|
return None
|
||||||
|
snapshot_path = await snapshot_database(engine, target_dir)
|
||||||
|
# Always prune even if this run's snapshot failed — old files still
|
||||||
|
# cost disk and may have been written by prior successful boots.
|
||||||
|
await asyncio.to_thread(prune_old_snapshots, target_dir, keep)
|
||||||
|
return snapshot_path
|
||||||
@@ -0,0 +1,324 @@
|
|||||||
|
"""Production-grade logging configuration.
|
||||||
|
|
||||||
|
Installs one ``dictConfig`` layout with:
|
||||||
|
|
||||||
|
* A ``LogRecordFactory`` that pulls request-scoped identifiers from
|
||||||
|
``notify_bridge_core.log_context`` onto every record, so logs can be
|
||||||
|
filtered/correlated by ``request_id``, ``command``, ``chat_id``,
|
||||||
|
``bot_id``, ``dispatch_id`` without each call site passing them.
|
||||||
|
* A ``SecretMaskingFilter`` that redacts Telegram bot tokens and common
|
||||||
|
``Authorization`` / ``x-api-key`` headers so an accidental ``repr`` or
|
||||||
|
dumped request doesn't leak credentials into the log aggregator.
|
||||||
|
* A text formatter (default) or a JSON formatter (one line per record)
|
||||||
|
selectable via ``NOTIFY_BRIDGE_LOG_FORMAT`` / app setting.
|
||||||
|
|
||||||
|
Levels are configurable three ways (later wins):
|
||||||
|
|
||||||
|
1. ``NOTIFY_BRIDGE_LOG_LEVEL`` env var (root) plus
|
||||||
|
``NOTIFY_BRIDGE_LOG_LEVELS`` (``mod=LEVEL,mod2=LEVEL``).
|
||||||
|
2. DB ``AppSetting`` rows ``log_level`` / ``log_levels`` / ``log_format``,
|
||||||
|
applied after migrations during startup.
|
||||||
|
3. Live edits via the settings API — ``apply_log_levels()`` updates
|
||||||
|
existing loggers in place without a server restart.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import logging.config
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from notify_bridge_core.log_context import (
|
||||||
|
bot_id_var,
|
||||||
|
chat_id_var,
|
||||||
|
command_var,
|
||||||
|
dispatch_id_var,
|
||||||
|
request_id_var,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Secret masking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Telegram bot tokens: /bot<digits>:<alnum with dashes/underscores>
|
||||||
|
_TELEGRAM_TOKEN_RE = re.compile(r"/bot\d+:[A-Za-z0-9_-]{20,}")
|
||||||
|
|
||||||
|
# Header-style secrets: Authorization: Bearer xxx, x-api-key=xxx, etc.
|
||||||
|
# Only matches reasonably long tokens so short legitimate values don't trip.
|
||||||
|
_HEADER_SECRET_RE = re.compile(
|
||||||
|
r"(?i)(authorization|x-api-key|api[_-]?key|password|secret|access[_-]?token|refresh[_-]?token)"
|
||||||
|
r"([\"']?\s*[:=]\s*[\"']?)"
|
||||||
|
r"([A-Za-z0-9._+/=\-]{12,})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mask(text: str) -> str:
|
||||||
|
redacted = _TELEGRAM_TOKEN_RE.sub("/bot***", text)
|
||||||
|
redacted = _HEADER_SECRET_RE.sub(r"\1\2***", redacted)
|
||||||
|
return redacted
|
||||||
|
|
||||||
|
|
||||||
|
class SecretMaskingFilter(logging.Filter):
|
||||||
|
"""Redact likely secrets from every log message before it's emitted.
|
||||||
|
|
||||||
|
Covers three surfaces where a leaked token can end up in the log:
|
||||||
|
the formatted message, a cached exception traceback (``exc_text``),
|
||||||
|
and a cached stack frame dump (``stack_info``). The formatter still
|
||||||
|
expands ``exc_info`` for us when ``exc_text`` is None, so we also
|
||||||
|
pre-render + mask on first emission.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
try:
|
||||||
|
msg = record.getMessage()
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
redacted = _mask(msg)
|
||||||
|
if redacted != msg:
|
||||||
|
# Replace the formatted message and drop args so the handler
|
||||||
|
# doesn't re-format with the original values.
|
||||||
|
record.msg = redacted
|
||||||
|
record.args = ()
|
||||||
|
|
||||||
|
if record.exc_info and not record.exc_text:
|
||||||
|
# Pre-render so we can mask before the formatter caches it.
|
||||||
|
fmt = logging.Formatter()
|
||||||
|
record.exc_text = fmt.formatException(record.exc_info)
|
||||||
|
if record.exc_text:
|
||||||
|
record.exc_text = _mask(record.exc_text)
|
||||||
|
if record.stack_info:
|
||||||
|
record.stack_info = _mask(record.stack_info)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Record factory — injects context identifiers onto every record
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CONTEXT_FIELDS = ("request_id", "command", "chat_id", "bot_id", "dispatch_id")
|
||||||
|
_PLACEHOLDER = "-"
|
||||||
|
|
||||||
|
_original_factory = logging.getLogRecordFactory()
|
||||||
|
|
||||||
|
|
||||||
|
def _context_record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
|
||||||
|
record = _original_factory(*args, **kwargs)
|
||||||
|
record.request_id = request_id_var.get() or _PLACEHOLDER
|
||||||
|
record.command = command_var.get() or _PLACEHOLDER
|
||||||
|
record.chat_id = chat_id_var.get() or _PLACEHOLDER
|
||||||
|
bid = bot_id_var.get()
|
||||||
|
record.bot_id = str(bid) if bid is not None else _PLACEHOLDER
|
||||||
|
record.dispatch_id = dispatch_id_var.get() or _PLACEHOLDER
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# JSON formatter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class JsonFormatter(logging.Formatter):
|
||||||
|
"""Emit one JSON object per log record."""
|
||||||
|
|
||||||
|
def format(self, record: logging.LogRecord) -> str:
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S") + f".{int(record.msecs):03d}",
|
||||||
|
"level": record.levelname,
|
||||||
|
"logger": record.name,
|
||||||
|
"module": record.module,
|
||||||
|
"line": record.lineno,
|
||||||
|
"msg": record.getMessage(),
|
||||||
|
}
|
||||||
|
for field in _CONTEXT_FIELDS:
|
||||||
|
val = getattr(record, field, None)
|
||||||
|
if val and val != _PLACEHOLDER:
|
||||||
|
payload[field] = val
|
||||||
|
# Prefer the pre-masked exc_text cached by SecretMaskingFilter over
|
||||||
|
# re-formatting from exc_info, which would bypass the mask.
|
||||||
|
if record.exc_text:
|
||||||
|
payload["exc"] = record.exc_text
|
||||||
|
elif record.exc_info:
|
||||||
|
payload["exc"] = self.formatException(record.exc_info)
|
||||||
|
if record.stack_info:
|
||||||
|
payload["stack"] = record.stack_info
|
||||||
|
return json.dumps(payload, ensure_ascii=False, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Text formatter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Keeps all context fields on one line so grep-by-field works. Empty values
|
||||||
|
# are rendered as "-" by the record factory to avoid KeyError if a record
|
||||||
|
# arrives without the filter.
|
||||||
|
_TEXT_FORMAT = (
|
||||||
|
"%(asctime)s %(levelname)-7s %(name)s:%(lineno)d "
|
||||||
|
"[req=%(request_id)s cmd=%(command)s bot=%(bot_id)s chat=%(chat_id)s disp=%(dispatch_id)s] "
|
||||||
|
"%(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Level override parsing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_VALID_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "NOTSET"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_level_overrides(raw: str) -> dict[str, str]:
|
||||||
|
"""Parse ``module=LEVEL,module2=LEVEL`` into a mapping of validated levels.
|
||||||
|
|
||||||
|
Invalid entries (bad format, unknown level) are silently dropped —
|
||||||
|
a malformed env var or DB setting must not crash boot.
|
||||||
|
"""
|
||||||
|
result: dict[str, str] = {}
|
||||||
|
for chunk in (raw or "").split(","):
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if not chunk or "=" not in chunk:
|
||||||
|
continue
|
||||||
|
mod, _, lvl = chunk.partition("=")
|
||||||
|
mod = mod.strip()
|
||||||
|
lvl = lvl.strip().upper()
|
||||||
|
if not mod or lvl not in _VALID_LEVELS:
|
||||||
|
continue
|
||||||
|
result[mod] = lvl
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_level(level: str | None, default: str = "INFO") -> str:
|
||||||
|
if not level:
|
||||||
|
return default
|
||||||
|
up = level.strip().upper()
|
||||||
|
return up if up in _VALID_LEVELS else default
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Setup + live apply
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Libraries we quiet by default — noisy at DEBUG and almost always irrelevant
|
||||||
|
# to a service issue. Override via LOG_LEVELS=sqlalchemy.engine=DEBUG when
|
||||||
|
# actually debugging.
|
||||||
|
_NOISY_LIBRARY_DEFAULTS: dict[str, str] = {
|
||||||
|
"sqlalchemy": "WARNING",
|
||||||
|
"sqlalchemy.engine": "WARNING",
|
||||||
|
"sqlalchemy.pool": "WARNING",
|
||||||
|
"aiohttp": "WARNING",
|
||||||
|
"aiohttp.access": "WARNING",
|
||||||
|
"aiohttp.client": "WARNING",
|
||||||
|
"aiohttp.server": "WARNING",
|
||||||
|
"apscheduler": "WARNING",
|
||||||
|
"apscheduler.scheduler": "WARNING",
|
||||||
|
"apscheduler.executors.default": "WARNING",
|
||||||
|
"urllib3": "WARNING",
|
||||||
|
"asyncio": "WARNING",
|
||||||
|
"httpx": "WARNING",
|
||||||
|
"httpcore": "WARNING",
|
||||||
|
"PIL": "WARNING",
|
||||||
|
"uvicorn.access": "WARNING",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(
|
||||||
|
*,
|
||||||
|
level: str = "INFO",
|
||||||
|
fmt: str = "text",
|
||||||
|
per_module_levels: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Install the logging configuration. Safe to call more than once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Root log level (applied to ``notify_bridge_*`` loggers).
|
||||||
|
fmt: ``"text"`` (default) or ``"json"``.
|
||||||
|
per_module_levels: ``mod=LEVEL,mod2=LEVEL`` overrides. Wins over the
|
||||||
|
root level for the listed loggers.
|
||||||
|
"""
|
||||||
|
root_level = _normalize_level(level, "INFO")
|
||||||
|
overrides = parse_level_overrides(per_module_levels)
|
||||||
|
|
||||||
|
# Install the context-aware record factory (idempotent — setting the same
|
||||||
|
# factory twice is fine because ``_original_factory`` is captured at
|
||||||
|
# import time).
|
||||||
|
logging.setLogRecordFactory(_context_record_factory)
|
||||||
|
|
||||||
|
if fmt == "json":
|
||||||
|
formatters = {"default": {"()": f"{__name__}.JsonFormatter"}}
|
||||||
|
else:
|
||||||
|
formatters = {
|
||||||
|
"default": {
|
||||||
|
"format": _TEXT_FORMAT,
|
||||||
|
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Start with noisy-library defaults, then layer user overrides on top so
|
||||||
|
# the user can raise them to DEBUG when actually debugging.
|
||||||
|
loggers: dict[str, dict[str, Any]] = {}
|
||||||
|
for mod, lvl in _NOISY_LIBRARY_DEFAULTS.items():
|
||||||
|
loggers[mod] = {"level": lvl, "propagate": True}
|
||||||
|
# App loggers follow the root level unless overridden.
|
||||||
|
loggers["notify_bridge_server"] = {"level": root_level, "propagate": True}
|
||||||
|
loggers["notify_bridge_core"] = {"level": root_level, "propagate": True}
|
||||||
|
# User overrides win.
|
||||||
|
for mod, lvl in overrides.items():
|
||||||
|
loggers[mod] = {"level": lvl, "propagate": True}
|
||||||
|
|
||||||
|
config: dict[str, Any] = {
|
||||||
|
"version": 1,
|
||||||
|
"disable_existing_loggers": False,
|
||||||
|
"filters": {
|
||||||
|
"mask_secrets": {"()": f"{__name__}.SecretMaskingFilter"},
|
||||||
|
},
|
||||||
|
"formatters": formatters,
|
||||||
|
"handlers": {
|
||||||
|
"stderr": {
|
||||||
|
"class": "logging.StreamHandler",
|
||||||
|
"stream": sys.stderr,
|
||||||
|
"formatter": "default",
|
||||||
|
"filters": ["mask_secrets"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"level": root_level,
|
||||||
|
"handlers": ["stderr"],
|
||||||
|
},
|
||||||
|
"loggers": loggers,
|
||||||
|
}
|
||||||
|
logging.config.dictConfig(config)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_log_levels(
|
||||||
|
*,
|
||||||
|
level: str | None,
|
||||||
|
per_module_levels: str | None,
|
||||||
|
) -> None:
|
||||||
|
"""Update existing logger levels in-place without rebuilding handlers.
|
||||||
|
|
||||||
|
Called when an admin changes the log settings at runtime. Setting
|
||||||
|
``level`` to None leaves the root untouched; setting it to a valid
|
||||||
|
level applies to ``notify_bridge_server`` / ``notify_bridge_core``.
|
||||||
|
|
||||||
|
``per_module_levels`` is treated as an exclusive set — loggers that
|
||||||
|
previously had an override but aren't in the new string are reset
|
||||||
|
*toward* the root level so a removed override actually takes effect.
|
||||||
|
"""
|
||||||
|
if level:
|
||||||
|
lvl = _normalize_level(level, "INFO")
|
||||||
|
logging.getLogger("notify_bridge_server").setLevel(lvl)
|
||||||
|
logging.getLogger("notify_bridge_core").setLevel(lvl)
|
||||||
|
# NOTSET on root is almost never what you want — keep root where it is
|
||||||
|
# unless the caller explicitly set something.
|
||||||
|
logging.getLogger().setLevel(lvl)
|
||||||
|
|
||||||
|
if per_module_levels is not None:
|
||||||
|
overrides = parse_level_overrides(per_module_levels)
|
||||||
|
# Apply new overrides
|
||||||
|
for mod, lvl in overrides.items():
|
||||||
|
logging.getLogger(mod).setLevel(lvl)
|
||||||
|
# Reset noisy libs that aren't in the new overrides back to defaults
|
||||||
|
for mod, default_lvl in _NOISY_LIBRARY_DEFAULTS.items():
|
||||||
|
if mod not in overrides:
|
||||||
|
logging.getLogger(mod).setLevel(default_lvl)
|
||||||
@@ -9,13 +9,18 @@ from slowapi import _rate_limit_exceeded_handler
|
|||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
from slowapi.middleware import SlowAPIMiddleware
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
|
|
||||||
# Ensure app-level loggers are visible
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
from .config import settings as _log_cfg
|
from .config import settings as _log_cfg
|
||||||
_log_level = logging.DEBUG if _log_cfg.debug else logging.INFO
|
from .logging_setup import setup_logging
|
||||||
logging.getLogger("notify_bridge_server").setLevel(_log_level)
|
|
||||||
logging.getLogger("notify_bridge_core").setLevel(_log_level)
|
# Boot logging from env-based config. DB-backed AppSetting rows (``log_level`` /
|
||||||
|
# ``log_levels`` / ``log_format``) override this after migrations — see the
|
||||||
|
# lifespan block below.
|
||||||
|
setup_logging(
|
||||||
|
level="DEBUG" if _log_cfg.debug else _log_cfg.log_level,
|
||||||
|
fmt=_log_cfg.log_format,
|
||||||
|
per_module_levels=_log_cfg.log_levels,
|
||||||
|
)
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
from .database.engine import init_db
|
from .database.engine import init_db
|
||||||
from .database.models import * # noqa: F401,F403 — ensure all models registered
|
from .database.models import * # noqa: F401,F403 — ensure all models registered
|
||||||
@@ -47,13 +52,41 @@ from .api.webhook_logs import router as webhook_logs_router
|
|||||||
from .api.backup import router as backup_router
|
from .api.backup import router as backup_router
|
||||||
|
|
||||||
|
|
||||||
|
# Readiness flag — flipped to True once the scheduler has started and the
|
||||||
|
# app is fully initialized. Exposed via /api/ready for orchestrators.
|
||||||
|
_READY: bool = False
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
global _READY
|
||||||
await init_db()
|
await init_db()
|
||||||
# Run data migrations (idempotent)
|
# Run data migrations (idempotent)
|
||||||
from .database.engine import get_engine
|
from .database.engine import get_engine
|
||||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
|
from .database.migrations import (
|
||||||
|
migrate_schema,
|
||||||
|
migrate_tracker_targets,
|
||||||
|
migrate_entity_refactor,
|
||||||
|
migrate_template_slots,
|
||||||
|
migrate_target_receivers,
|
||||||
|
migrate_template_locale,
|
||||||
|
migrate_receivers_from_config,
|
||||||
|
migrate_command_slot_locale,
|
||||||
|
migrate_notification_slot_locale,
|
||||||
|
migrate_user_token_version,
|
||||||
|
migrate_performance_indexes,
|
||||||
|
migrate_schema_version,
|
||||||
|
)
|
||||||
|
from .database.snapshot import snapshot_and_prune
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
|
# Take a consistent DB snapshot BEFORE migrations run, so operators can
|
||||||
|
# roll back a bad upgrade by restoring one file. Best-effort — failures
|
||||||
|
# are logged, not raised.
|
||||||
|
await snapshot_and_prune(
|
||||||
|
engine,
|
||||||
|
_log_cfg.data_dir / "backups",
|
||||||
|
keep=_log_cfg.pre_migrate_snapshot_keep,
|
||||||
|
)
|
||||||
await migrate_schema(engine)
|
await migrate_schema(engine)
|
||||||
await migrate_tracker_targets(engine)
|
await migrate_tracker_targets(engine)
|
||||||
await migrate_entity_refactor(engine)
|
await migrate_entity_refactor(engine)
|
||||||
@@ -64,8 +97,28 @@ async def lifespan(app: FastAPI):
|
|||||||
await migrate_command_slot_locale(engine)
|
await migrate_command_slot_locale(engine)
|
||||||
await migrate_notification_slot_locale(engine)
|
await migrate_notification_slot_locale(engine)
|
||||||
await migrate_user_token_version(engine)
|
await migrate_user_token_version(engine)
|
||||||
|
await migrate_performance_indexes(engine)
|
||||||
|
await migrate_schema_version(engine)
|
||||||
from .database.seeds import seed_all
|
from .database.seeds import seed_all
|
||||||
await seed_all()
|
await seed_all()
|
||||||
|
# Apply DB-backed logging settings (override env-based boot config).
|
||||||
|
# log_format still needs a restart — changing it means swapping the
|
||||||
|
# handler formatter entirely.
|
||||||
|
try:
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession as _AS_log
|
||||||
|
from .api.app_settings import get_setting as _get_log_setting
|
||||||
|
from .logging_setup import apply_log_levels
|
||||||
|
async with _AS_log(engine) as _log_session:
|
||||||
|
db_level = await _get_log_setting(_log_session, "log_level")
|
||||||
|
db_levels = await _get_log_setting(_log_session, "log_levels")
|
||||||
|
apply_log_levels(level=db_level or None, per_module_levels=db_levels)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Logging initialized: level=%s overrides=%r format=%s",
|
||||||
|
db_level or _log_cfg.log_level, db_levels or _log_cfg.log_levels,
|
||||||
|
_log_cfg.log_format,
|
||||||
|
)
|
||||||
|
except Exception: # pragma: no cover — never let logging setup abort boot
|
||||||
|
_LOGGER.exception("Failed to apply DB-backed log settings; keeping env-based levels")
|
||||||
# Apply any pending restore staged via /api/backup/prepare-restore
|
# Apply any pending restore staged via /api/backup/prepare-restore
|
||||||
from .services.pending_restore import apply_pending_restore_if_any
|
from .services.pending_restore import apply_pending_restore_if_any
|
||||||
await apply_pending_restore_if_any()
|
await apply_pending_restore_if_any()
|
||||||
@@ -77,16 +130,28 @@ async def lifespan(app: FastAPI):
|
|||||||
set_webhook_secret(_secret or None)
|
set_webhook_secret(_secret or None)
|
||||||
from .services.scheduler import start_scheduler, get_scheduler
|
from .services.scheduler import start_scheduler, get_scheduler
|
||||||
await start_scheduler()
|
await start_scheduler()
|
||||||
|
_READY = True
|
||||||
yield
|
yield
|
||||||
# Graceful shutdown
|
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
|
||||||
from .services.http_session import close_http_session
|
# before we close their HTTP session. Then close the shared session and
|
||||||
await close_http_session()
|
# dispose the DB engine.
|
||||||
|
_READY = False
|
||||||
scheduler = get_scheduler()
|
scheduler = get_scheduler()
|
||||||
if scheduler.running:
|
if scheduler.running:
|
||||||
scheduler.shutdown()
|
scheduler.shutdown(wait=True)
|
||||||
|
from .services.http_session import close_http_session
|
||||||
|
await close_http_session()
|
||||||
|
from .database.engine import dispose_engine
|
||||||
|
await dispose_engine()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
|
try:
|
||||||
|
from importlib.metadata import version as _pkg_version
|
||||||
|
_APP_VERSION = _pkg_version("notify-bridge-server")
|
||||||
|
except Exception: # pragma: no cover — editable install edge cases
|
||||||
|
_APP_VERSION = "0.0.0+unknown"
|
||||||
|
|
||||||
|
app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan)
|
||||||
|
|
||||||
# --- Security headers ---
|
# --- Security headers ---
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
@@ -94,6 +159,24 @@ from starlette.requests import Request as StarletteRequest
|
|||||||
from starlette.responses import Response as StarletteResponse
|
from starlette.responses import Response as StarletteResponse
|
||||||
|
|
||||||
|
|
||||||
|
_CSP = (
|
||||||
|
"default-src 'self'; "
|
||||||
|
"img-src 'self' data: blob: https:; "
|
||||||
|
"style-src 'self' 'unsafe-inline'; "
|
||||||
|
# SvelteKit's static adapter emits an inline bootstrap <script> with the
|
||||||
|
# hydration payload, so 'self' alone blocks the SPA from starting.
|
||||||
|
# 'unsafe-inline' re-enables it; the app's primary XSS protection still
|
||||||
|
# comes from Svelte's template auto-escaping and frontend/sanitize.ts
|
||||||
|
# for the few {@html} paths that render user-controlled content.
|
||||||
|
"script-src 'self' 'unsafe-inline'; "
|
||||||
|
"connect-src 'self'; "
|
||||||
|
"font-src 'self' data:; "
|
||||||
|
"base-uri 'self'; "
|
||||||
|
"form-action 'self'; "
|
||||||
|
"frame-ancestors 'none'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: StarletteRequest, call_next):
|
async def dispatch(self, request: StarletteRequest, call_next):
|
||||||
response: StarletteResponse = await call_next(request)
|
response: StarletteResponse = await call_next(request)
|
||||||
@@ -101,6 +184,14 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
response.headers["X-Frame-Options"] = "DENY"
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
response.headers.setdefault("Content-Security-Policy", _CSP)
|
||||||
|
# HSTS only makes sense over HTTPS; set when the edge terminates TLS
|
||||||
|
# and forwards X-Forwarded-Proto=https.
|
||||||
|
if request.headers.get("x-forwarded-proto") == "https":
|
||||||
|
response.headers.setdefault(
|
||||||
|
"Strict-Transport-Security",
|
||||||
|
"max-age=31536000; includeSubDomains",
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@@ -153,7 +244,22 @@ app.include_router(backup_router)
|
|||||||
|
|
||||||
@app.get("/api/health")
|
@app.get("/api/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok"}
|
"""Liveness: process is up and responding. Always returns 200 once the
|
||||||
|
ASGI app has started. Keep this endpoint anonymous and trivially cheap."""
|
||||||
|
return {"status": "ok", "version": _APP_VERSION}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/ready")
|
||||||
|
async def ready():
|
||||||
|
"""Readiness: migrations and scheduler have started, app can serve traffic.
|
||||||
|
|
||||||
|
Returns 503 until the lifespan startup sequence has completed. Use this
|
||||||
|
for orchestrator readiness probes (Docker, Kubernetes).
|
||||||
|
"""
|
||||||
|
if not _READY:
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
return JSONResponse({"status": "starting"}, status_code=503)
|
||||||
|
return {"status": "ready", "version": _APP_VERSION}
|
||||||
|
|
||||||
|
|
||||||
# --- Serve frontend static files (production) ---
|
# --- Serve frontend static files (production) ---
|
||||||
@@ -186,4 +292,12 @@ if _cfg.static_dir and Path(_cfg.static_dir).is_dir():
|
|||||||
|
|
||||||
def run():
|
def run():
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host=_cfg.host, port=_cfg.port)
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=_cfg.host,
|
||||||
|
port=_cfg.port,
|
||||||
|
proxy_headers=True,
|
||||||
|
forwarded_allow_ips=_cfg.forwarded_allow_ips or "127.0.0.1",
|
||||||
|
timeout_graceful_shutdown=_cfg.graceful_shutdown_seconds,
|
||||||
|
access_log=not _cfg.debug,
|
||||||
|
)
|
||||||
|
|||||||
@@ -387,10 +387,9 @@ async def export_backup_to_file(
|
|||||||
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
||||||
filename = f"backup-{ts}.json"
|
filename = f"backup-{ts}.json"
|
||||||
filepath = backup_dir / filename
|
filepath = backup_dir / filename
|
||||||
filepath.write_text(
|
import asyncio as _asyncio
|
||||||
json.dumps(backup.model_dump(), indent=2, ensure_ascii=False),
|
payload = json.dumps(backup.model_dump(), indent=2, ensure_ascii=False)
|
||||||
encoding="utf-8",
|
await _asyncio.to_thread(filepath.write_text, payload, encoding="utf-8")
|
||||||
)
|
|
||||||
_LOGGER.info("Scheduled backup saved: %s", filepath)
|
_LOGGER.info("Scheduled backup saved: %s", filepath)
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
@@ -399,7 +398,13 @@ def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
|
|||||||
"""Delete oldest backup files exceeding `keep` count. Returns deleted filenames."""
|
"""Delete oldest backup files exceeding `keep` count. Returns deleted filenames."""
|
||||||
if not backup_dir.is_dir():
|
if not backup_dir.is_dir():
|
||||||
return []
|
return []
|
||||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||||
|
# timestamp format, which could change later without updating this code.
|
||||||
|
files = sorted(
|
||||||
|
backup_dir.glob("backup-*.json"),
|
||||||
|
key=lambda f: f.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
deleted = []
|
deleted = []
|
||||||
for old in files[keep:]:
|
for old in files[keep:]:
|
||||||
old.unlink()
|
old.unlink()
|
||||||
@@ -413,7 +418,13 @@ def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
|
|||||||
"""List backup files in the directory with metadata."""
|
"""List backup files in the directory with metadata."""
|
||||||
if not backup_dir.is_dir():
|
if not backup_dir.is_dir():
|
||||||
return []
|
return []
|
||||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||||
|
# timestamp format, which could change later without updating this code.
|
||||||
|
files = sorted(
|
||||||
|
backup_dir.glob("backup-*.json"),
|
||||||
|
key=lambda f: f.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
result = []
|
result = []
|
||||||
for f in files:
|
for f in files:
|
||||||
stat = f.stat()
|
stat = f.stat()
|
||||||
|
|||||||
@@ -11,15 +11,27 @@ Call ``close_http_session()`` once during application shutdown.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
|
||||||
|
|
||||||
_session: aiohttp.ClientSession | None = None
|
_session: aiohttp.ClientSession | None = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
|
||||||
async def get_http_session() -> aiohttp.ClientSession:
|
async def get_http_session() -> aiohttp.ClientSession:
|
||||||
"""Get or create the shared HTTP session."""
|
"""Get or create the shared HTTP session.
|
||||||
|
|
||||||
|
Concurrent first-callers are serialized through ``_lock`` so we never
|
||||||
|
leak a second ClientSession / connector pair. Once established, hot
|
||||||
|
callers skip the lock via the fast-path check.
|
||||||
|
"""
|
||||||
global _session
|
global _session
|
||||||
|
if _session is not None and not _session.closed:
|
||||||
|
return _session
|
||||||
|
async with _lock:
|
||||||
if _session is None or _session.closed:
|
if _session is None or _session.closed:
|
||||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||||
return _session
|
return _session
|
||||||
@@ -28,6 +40,7 @@ async def get_http_session() -> aiohttp.ClientSession:
|
|||||||
async def close_http_session() -> None:
|
async def close_http_session() -> None:
|
||||||
"""Close the shared HTTP session (call on app shutdown)."""
|
"""Close the shared HTTP session (call on app shutdown)."""
|
||||||
global _session
|
global _session
|
||||||
|
async with _lock:
|
||||||
if _session is not None and not _session.closed:
|
if _session is not None and not _session.closed:
|
||||||
await _session.close()
|
await _session.close()
|
||||||
_session = None
|
_session = None
|
||||||
|
|||||||
@@ -188,6 +188,7 @@ _SAMPLE_CONTEXT = {
|
|||||||
"current_time": "09:00",
|
"current_time": "09:00",
|
||||||
"current_datetime": "22.03.2026, 09:00 UTC",
|
"current_datetime": "22.03.2026, 09:00 UTC",
|
||||||
"weekday": "Monday",
|
"weekday": "Monday",
|
||||||
|
"timezone": "UTC",
|
||||||
"custom_vars": {"team": "Engineering", "message": "Time for standup!"},
|
"custom_vars": {"team": "Engineering", "message": "Time for standup!"},
|
||||||
"team": "Engineering",
|
"team": "Engineering",
|
||||||
"message": "Time for standup!",
|
"message": "Time for standup!",
|
||||||
|
|||||||
@@ -0,0 +1,242 @@
|
|||||||
|
"""Cron-fired scheduled / periodic / memory dispatch for Immich trackers.
|
||||||
|
|
||||||
|
The Immich provider exposes three notification slots that fire on a wall-clock
|
||||||
|
schedule rather than in response to album changes:
|
||||||
|
|
||||||
|
* ``scheduled_assets_message`` — random asset selection at fixed times of day
|
||||||
|
* ``periodic_summary_message`` — album stats summary at fixed times of day
|
||||||
|
* ``memory_mode_message`` — "On This Day" memories at fixed times of day
|
||||||
|
|
||||||
|
The fire times live on the tracker's default ``TrackingConfig`` as comma-
|
||||||
|
separated ``HH:MM`` strings (``scheduled_times`` / ``periodic_times`` /
|
||||||
|
``memory_times``) interpreted in the app-level IANA timezone
|
||||||
|
(``AppSetting.timezone``). The scheduler module wires the cron jobs; this
|
||||||
|
module owns the dispatch flow once a job fires.
|
||||||
|
|
||||||
|
Note on per-link tracking-config overrides: schedule *times* come from the
|
||||||
|
tracker's default config — a per-link override may disable the slot for that
|
||||||
|
link (via ``{kind}_enabled``) but cannot shift its fire time. Consistent with
|
||||||
|
the test-dispatch path in ``manual_dispatch``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from notify_bridge_core.models.events import EventType
|
||||||
|
from notify_bridge_core.notifications.dispatcher import (
|
||||||
|
NotificationDispatcher,
|
||||||
|
TargetConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..database.engine import get_engine
|
||||||
|
from ..database.models import (
|
||||||
|
EventLog,
|
||||||
|
NotificationTracker,
|
||||||
|
ServiceProvider,
|
||||||
|
TemplateSlot,
|
||||||
|
TrackingConfig,
|
||||||
|
)
|
||||||
|
from .dispatch_helpers import (
|
||||||
|
event_allowed_by_config,
|
||||||
|
get_app_timezone,
|
||||||
|
load_link_data,
|
||||||
|
)
|
||||||
|
from .manual_dispatch import _build_immich_event, _build_immich_periodic_event
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ScheduledKind = Literal["scheduled", "periodic", "memory"]
|
||||||
|
|
||||||
|
# Maps the dispatch kind to the DB slot name that holds its template.
|
||||||
|
# The dispatcher keys templates by ``event.event_type.value`` (always
|
||||||
|
# ``scheduled_message`` here), so we read the right ``TemplateSlot`` row and
|
||||||
|
# inject it under that single event-type key — same pattern as the test path.
|
||||||
|
_SLOT_MAP: dict[ScheduledKind, str] = {
|
||||||
|
"scheduled": "scheduled_assets_message",
|
||||||
|
"periodic": "periodic_summary_message",
|
||||||
|
"memory": "memory_mode_message",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def dispatch_scheduled_for_tracker(
|
||||||
|
tracker_id: int, kind: ScheduledKind
|
||||||
|
) -> None:
|
||||||
|
"""Build the slot's event for ``tracker_id`` and fan out to its links.
|
||||||
|
|
||||||
|
Skips silently when the tracker is disabled, the provider is not Immich,
|
||||||
|
the slot is disabled on the tracker's default tracking config, or no link
|
||||||
|
has a ``TemplateConfig`` with the corresponding slot row.
|
||||||
|
"""
|
||||||
|
engine = get_engine()
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
tracker = await session.get(NotificationTracker, tracker_id)
|
||||||
|
if not tracker or not tracker.enabled:
|
||||||
|
return
|
||||||
|
provider = await session.get(ServiceProvider, tracker.provider_id)
|
||||||
|
if not provider or provider.type != "immich":
|
||||||
|
return
|
||||||
|
|
||||||
|
default_tc: TrackingConfig | None = None
|
||||||
|
if tracker.default_tracking_config_id:
|
||||||
|
default_tc = await session.get(
|
||||||
|
TrackingConfig, tracker.default_tracking_config_id
|
||||||
|
)
|
||||||
|
# If the default config disables this kind, nothing to do — schedule
|
||||||
|
# rebuild only adds jobs when the flag is set, but a stale job from
|
||||||
|
# a previous DB state could still fire one tick before invalidation.
|
||||||
|
if default_tc is None or not getattr(default_tc, f"{kind}_enabled", False):
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Scheduled %s skipped for tracker %d: kind disabled on default config",
|
||||||
|
kind, tracker_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Snapshot every field we need outside the session — after the
|
||||||
|
# ``async with`` exits the instances are detached and lazy-load
|
||||||
|
# would fail. Cheaper than re-fetching, safer than touching
|
||||||
|
# attributes through a closed session.
|
||||||
|
provider_id = provider.id
|
||||||
|
provider_config = dict(provider.config)
|
||||||
|
provider_name = provider.name or provider.type
|
||||||
|
tracker_user_id = tracker.user_id
|
||||||
|
tracker_name = tracker.name or ""
|
||||||
|
collection_ids = list(tracker.collection_ids or [])
|
||||||
|
|
||||||
|
app_tz = await get_app_timezone(session)
|
||||||
|
link_data = await load_link_data(session, tracker_id)
|
||||||
|
|
||||||
|
if not link_data:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Scheduled %s for tracker %d: no enabled links, skipping",
|
||||||
|
kind, tracker_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if kind == "periodic":
|
||||||
|
event = await _build_immich_periodic_event(
|
||||||
|
provider_config=provider_config,
|
||||||
|
provider_name=provider_name,
|
||||||
|
tracker_name=tracker_name,
|
||||||
|
collection_ids=collection_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
event = await _build_immich_event(
|
||||||
|
provider_config=provider_config,
|
||||||
|
provider_name=provider_name,
|
||||||
|
tracker_name=tracker_name,
|
||||||
|
collection_ids=collection_ids,
|
||||||
|
test_type=kind,
|
||||||
|
tracking_config=default_tc,
|
||||||
|
)
|
||||||
|
if event is None:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Scheduled %s for tracker %d: provider returned no event",
|
||||||
|
kind, tracker_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip empty payloads for asset-bearing kinds — sending the bare
|
||||||
|
# "On this day:" / "Scheduled delivery —" header with no items below
|
||||||
|
# spams chats with title-only messages every day. ``periodic`` is
|
||||||
|
# different: it's a stats summary that's still meaningful with zero
|
||||||
|
# assets, so we let it through.
|
||||||
|
if kind in ("scheduled", "memory") and not event.added_assets:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Scheduled %s for tracker %d: 0 assets matched, skipping dispatch",
|
||||||
|
kind, tracker_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
slot_name = _SLOT_MAP[kind]
|
||||||
|
target_configs: list[TargetConfig] = []
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
for ld in link_data:
|
||||||
|
tc = ld["tracking_config"] or default_tc
|
||||||
|
tmpl = ld["template_config"]
|
||||||
|
if tc is not None:
|
||||||
|
# Per-link override may disable this kind even when the
|
||||||
|
# default has it on — honour that here.
|
||||||
|
if not getattr(tc, f"{kind}_enabled", True):
|
||||||
|
continue
|
||||||
|
if not event_allowed_by_config(event, tc, app_tz):
|
||||||
|
continue
|
||||||
|
if tmpl is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
slot_rows = (await session.exec(
|
||||||
|
select(TemplateSlot).where(
|
||||||
|
TemplateSlot.config_id == tmpl.id,
|
||||||
|
TemplateSlot.slot_name == slot_name,
|
||||||
|
)
|
||||||
|
)).all()
|
||||||
|
if not slot_rows:
|
||||||
|
continue
|
||||||
|
locale_map = {s.locale: s.template for s in slot_rows}
|
||||||
|
template_slots = {EventType.SCHEDULED_MESSAGE.value: locale_map}
|
||||||
|
|
||||||
|
target_configs.append(TargetConfig(
|
||||||
|
type=ld["target_type"],
|
||||||
|
config=ld["target_config"],
|
||||||
|
template_slots=template_slots,
|
||||||
|
date_format=tmpl.date_format,
|
||||||
|
date_only_format=(
|
||||||
|
tmpl.date_only_format or "%d.%m.%Y"
|
||||||
|
),
|
||||||
|
provider_api_key=provider_config.get("api_key"),
|
||||||
|
provider_internal_url=provider_config.get("url", ""),
|
||||||
|
provider_external_url=provider_config.get("external_domain", ""),
|
||||||
|
receivers=ld["receivers"],
|
||||||
|
))
|
||||||
|
|
||||||
|
if not target_configs:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Scheduled %s for tracker %d: no targets after filtering",
|
||||||
|
kind, tracker_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Lazy import to break the watcher↔scheduler↔scheduled_dispatch cycle.
|
||||||
|
from .watcher import _get_telegram_caches
|
||||||
|
from .http_session import get_http_session
|
||||||
|
|
||||||
|
url_cache, asset_cache = await _get_telegram_caches()
|
||||||
|
http_session = await get_http_session()
|
||||||
|
dispatcher = NotificationDispatcher(
|
||||||
|
url_cache=url_cache, asset_cache=asset_cache, session=http_session,
|
||||||
|
)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Dispatching scheduled %s for tracker %d to %d link(s)",
|
||||||
|
kind, tracker_id, len(target_configs),
|
||||||
|
)
|
||||||
|
results = await dispatcher.dispatch(event, target_configs)
|
||||||
|
|
||||||
|
# Mirror the watcher's audit trail: surface scheduled fires in EventLog so
|
||||||
|
# the dashboard shows *why* a notification arrived (otherwise these would
|
||||||
|
# be invisible to the activity feed).
|
||||||
|
successes = sum(1 for r in results if isinstance(r, dict) and r.get("success"))
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
session.add(EventLog(
|
||||||
|
user_id=tracker_user_id,
|
||||||
|
tracker_id=tracker_id,
|
||||||
|
tracker_name=tracker_name,
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_name=provider_name,
|
||||||
|
event_type=event.event_type.value,
|
||||||
|
collection_id=event.collection_id,
|
||||||
|
collection_name=event.collection_name,
|
||||||
|
assets_count=event.added_count or 0,
|
||||||
|
details={
|
||||||
|
"kind": kind,
|
||||||
|
"slot": slot_name,
|
||||||
|
"trigger": "cron",
|
||||||
|
"timezone": app_tz,
|
||||||
|
"targets_dispatched": len(target_configs),
|
||||||
|
"targets_succeeded": successes,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
@@ -3,11 +3,39 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||||
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
|
||||||
|
"""Resolve an IANA tz string to a ZoneInfo, falling back to UTC on any error.
|
||||||
|
|
||||||
|
Kept local to avoid importing from api/dispatch layers inside the scheduler
|
||||||
|
module (which is loaded at startup, before the API routers).
|
||||||
|
"""
|
||||||
|
if not tz_name:
|
||||||
|
return ZoneInfo("UTC")
|
||||||
|
try:
|
||||||
|
return ZoneInfo(tz_name)
|
||||||
|
except (ZoneInfoNotFoundError, ValueError):
|
||||||
|
_LOGGER.warning("Unknown timezone %r; falling back to UTC", tz_name)
|
||||||
|
return ZoneInfo("UTC")
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_app_timezone() -> ZoneInfo:
|
||||||
|
"""Load the admin-configured app timezone from AppSetting (falls back to UTC)."""
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from ..api.app_settings import get_setting
|
||||||
|
from ..database.engine import get_engine
|
||||||
|
|
||||||
|
async with AsyncSession(get_engine()) as session:
|
||||||
|
tz_name = await get_setting(session, "timezone")
|
||||||
|
return _resolve_zoneinfo(tz_name)
|
||||||
|
|
||||||
_scheduler: AsyncIOScheduler | None = None
|
_scheduler: AsyncIOScheduler | None = None
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -57,7 +85,21 @@ def _compute_jitter(interval_seconds: int) -> int:
|
|||||||
def get_scheduler() -> AsyncIOScheduler:
|
def get_scheduler() -> AsyncIOScheduler:
|
||||||
global _scheduler
|
global _scheduler
|
||||||
if _scheduler is None:
|
if _scheduler is None:
|
||||||
_scheduler = AsyncIOScheduler()
|
# Sensible production defaults applied to every job unless overridden:
|
||||||
|
# * coalesce — collapse a queue of missed runs into one firing after
|
||||||
|
# a restart / pause, instead of bursting to catch up.
|
||||||
|
# * misfire_grace_time — accept firings up to 5 min late without
|
||||||
|
# dropping them silently.
|
||||||
|
# * max_instances=1 — never run two copies of the same tracker tick
|
||||||
|
# concurrently; the scheduler already enforces this on add_job,
|
||||||
|
# but we also set it as the default for safety.
|
||||||
|
_scheduler = AsyncIOScheduler(
|
||||||
|
job_defaults={
|
||||||
|
"coalesce": True,
|
||||||
|
"misfire_grace_time": 300,
|
||||||
|
"max_instances": 1,
|
||||||
|
},
|
||||||
|
)
|
||||||
return _scheduler
|
return _scheduler
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +111,7 @@ async def start_scheduler() -> None:
|
|||||||
|
|
||||||
await _load_tracker_jobs()
|
await _load_tracker_jobs()
|
||||||
await _load_action_jobs()
|
await _load_action_jobs()
|
||||||
|
await _load_immich_dispatch_jobs()
|
||||||
|
|
||||||
# Start Telegram bot polling for bots with active command listeners
|
# Start Telegram bot polling for bots with active command listeners
|
||||||
from .telegram_poller import start_command_listener_polling
|
from .telegram_poller import start_command_listener_polling
|
||||||
@@ -251,21 +294,38 @@ async def _refresh_telegram_chat_titles() -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def _cleanup_old_events() -> None:
|
async def _cleanup_old_events() -> None:
|
||||||
"""Delete EventLog entries older than 90 days."""
|
"""Delete EventLog / WebhookPayloadLog / ActionExecution rows older than the
|
||||||
|
configured retention window. A retention of 0 disables the job.
|
||||||
|
"""
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from sqlmodel import delete
|
from sqlmodel import delete
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
from ..database.engine import get_engine
|
from ..database.engine import get_engine
|
||||||
from ..database.models import EventLog
|
from ..database.models import ActionExecution, EventLog, WebhookPayloadLog
|
||||||
|
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
|
days = settings.event_log_retention_days
|
||||||
|
if days <= 0:
|
||||||
|
_LOGGER.debug("Event log retention disabled (days=0); skipping cleanup")
|
||||||
|
return
|
||||||
|
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
|
||||||
engine = get_engine()
|
engine = get_engine()
|
||||||
async with AsyncSession(engine) as session:
|
async with AsyncSession(engine) as session:
|
||||||
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
|
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
|
||||||
|
await session.exec(
|
||||||
|
delete(WebhookPayloadLog).where(WebhookPayloadLog.created_at < cutoff)
|
||||||
|
)
|
||||||
|
await session.exec(
|
||||||
|
delete(ActionExecution).where(ActionExecution.started_at < cutoff)
|
||||||
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
|
_LOGGER.info(
|
||||||
|
"Cleaned event_log / webhook_payload_log / action_execution older than %s",
|
||||||
|
cutoff.date(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _load_tracker_jobs() -> None:
|
async def _load_tracker_jobs() -> None:
|
||||||
@@ -293,6 +353,8 @@ async def _load_tracker_jobs() -> None:
|
|||||||
)
|
)
|
||||||
provider_types = {p.id: p.type for p in provider_result.all()}
|
provider_types = {p.id: p.type for p in provider_result.all()}
|
||||||
|
|
||||||
|
tz = await _load_app_timezone()
|
||||||
|
|
||||||
for tracker in trackers:
|
for tracker in trackers:
|
||||||
job_id = f"tracker_{tracker.id}"
|
job_id = f"tracker_{tracker.id}"
|
||||||
if scheduler.get_job(job_id):
|
if scheduler.get_job(job_id):
|
||||||
@@ -306,7 +368,7 @@ async def _load_tracker_jobs() -> None:
|
|||||||
cron_expr = filters.get("cron_expression", "")
|
cron_expr = filters.get("cron_expression", "")
|
||||||
if cron_expr:
|
if cron_expr:
|
||||||
try:
|
try:
|
||||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name)
|
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_LOGGER.error(
|
_LOGGER.error(
|
||||||
@@ -337,10 +399,18 @@ def _add_cron_job(
|
|||||||
tracker_id: int,
|
tracker_id: int,
|
||||||
cron_expression: str,
|
cron_expression: str,
|
||||||
tracker_name: str,
|
tracker_name: str,
|
||||||
|
tz: ZoneInfo,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a cron-triggered job for a scheduler-type tracker."""
|
"""Add a cron-triggered job for a scheduler-type tracker.
|
||||||
|
|
||||||
|
``tz`` is the user-configured app timezone; without it APScheduler
|
||||||
|
interprets the crontab in the host's local timezone, which surfaces as
|
||||||
|
events firing at the "wrong" wall-clock time for operators in a non-UTC
|
||||||
|
zone (see the companion fix in ``update_settings`` which reschedules on
|
||||||
|
timezone changes).
|
||||||
|
"""
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
trigger = CronTrigger.from_crontab(cron_expression)
|
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_poll_tracker,
|
_poll_tracker,
|
||||||
trigger,
|
trigger,
|
||||||
@@ -349,7 +419,10 @@ def _add_cron_job(
|
|||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
max_instances=1,
|
max_instances=1,
|
||||||
)
|
)
|
||||||
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
|
_LOGGER.info(
|
||||||
|
"Scheduled tracker %d (%s) with cron: %s [tz=%s]",
|
||||||
|
tracker_id, tracker_name, cron_expression, tz.key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def schedule_tracker(
|
async def schedule_tracker(
|
||||||
@@ -371,7 +444,8 @@ async def schedule_tracker(
|
|||||||
|
|
||||||
if cron_expression:
|
if cron_expression:
|
||||||
try:
|
try:
|
||||||
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}")
|
tz = await _load_app_timezone()
|
||||||
|
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}", tz)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_LOGGER.error("Invalid cron for tracker %d: %s — using interval", tracker_id, e)
|
_LOGGER.error("Invalid cron for tracker %d: %s — using interval", tracker_id, e)
|
||||||
@@ -506,6 +580,8 @@ async def _load_action_jobs() -> None:
|
|||||||
)
|
)
|
||||||
actions = result.all()
|
actions = result.all()
|
||||||
|
|
||||||
|
tz = await _load_app_timezone()
|
||||||
|
|
||||||
for action in actions:
|
for action in actions:
|
||||||
job_id = f"action_{action.id}"
|
job_id = f"action_{action.id}"
|
||||||
if scheduler.get_job(job_id):
|
if scheduler.get_job(job_id):
|
||||||
@@ -514,7 +590,7 @@ async def _load_action_jobs() -> None:
|
|||||||
if action.schedule_type == "cron" and action.schedule_cron:
|
if action.schedule_type == "cron" and action.schedule_cron:
|
||||||
try:
|
try:
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
trigger = CronTrigger.from_crontab(action.schedule_cron)
|
trigger = CronTrigger.from_crontab(action.schedule_cron, timezone=tz)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_action,
|
_run_action,
|
||||||
trigger,
|
trigger,
|
||||||
@@ -523,8 +599,8 @@ async def _load_action_jobs() -> None:
|
|||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"Scheduled action %d (%s) with cron: %s",
|
"Scheduled action %d (%s) with cron: %s [tz=%s]",
|
||||||
action.id, action.name, action.schedule_cron,
|
action.id, action.name, action.schedule_cron, tz.key,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -563,7 +639,8 @@ async def schedule_action(
|
|||||||
if schedule_type == "cron" and cron_expression:
|
if schedule_type == "cron" and cron_expression:
|
||||||
try:
|
try:
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
trigger = CronTrigger.from_crontab(cron_expression)
|
tz = await _load_app_timezone()
|
||||||
|
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_run_action,
|
_run_action,
|
||||||
trigger,
|
trigger,
|
||||||
@@ -571,7 +648,10 @@ async def schedule_action(
|
|||||||
args=[action_id],
|
args=[action_id],
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
)
|
)
|
||||||
_LOGGER.info("Scheduled action %d with cron: %s", action_id, cron_expression)
|
_LOGGER.info(
|
||||||
|
"Scheduled action %d with cron: %s [tz=%s]",
|
||||||
|
action_id, cron_expression, tz.key,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_LOGGER.error("Invalid cron for action %d: %s — using interval", action_id, e)
|
_LOGGER.error("Invalid cron for action %d: %s — using interval", action_id, e)
|
||||||
@@ -596,6 +676,96 @@ async def unschedule_action(action_id: int) -> None:
|
|||||||
_LOGGER.info("Unscheduled action %d", action_id)
|
_LOGGER.info("Unscheduled action %d", action_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def reschedule_cron_jobs_for_timezone_change() -> None:
|
||||||
|
"""Re-add every cron-triggered tracker/action job under the new app timezone.
|
||||||
|
|
||||||
|
Called by the admin settings endpoint after the ``timezone`` AppSetting is
|
||||||
|
updated. APScheduler's ``CronTrigger`` freezes its timezone at construction
|
||||||
|
time, so a timezone change has no effect on jobs already in the scheduler
|
||||||
|
— we have to rebuild those jobs. Interval-triggered jobs are tz-agnostic
|
||||||
|
and are left alone.
|
||||||
|
"""
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from ..database.engine import get_engine
|
||||||
|
from ..database.models import Action, NotificationTracker, ServiceProvider as ServiceProviderModel
|
||||||
|
|
||||||
|
engine = get_engine()
|
||||||
|
scheduler = get_scheduler()
|
||||||
|
tz = await _load_app_timezone()
|
||||||
|
rescheduled = 0
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
# Trackers with cron scheduling (scheduler provider + schedule_type=cron).
|
||||||
|
trackers = (await session.exec(
|
||||||
|
select(NotificationTracker).where(NotificationTracker.enabled == True) # noqa: E712
|
||||||
|
)).all()
|
||||||
|
provider_ids = list({t.provider_id for t in trackers})
|
||||||
|
provider_types: dict[int, str] = {}
|
||||||
|
if provider_ids:
|
||||||
|
rows = await session.exec(
|
||||||
|
select(ServiceProviderModel).where(ServiceProviderModel.id.in_(provider_ids))
|
||||||
|
)
|
||||||
|
provider_types = {p.id: p.type for p in rows.all()}
|
||||||
|
|
||||||
|
for tracker in trackers:
|
||||||
|
if provider_types.get(tracker.provider_id) != "scheduler":
|
||||||
|
continue
|
||||||
|
filters = tracker.filters or {}
|
||||||
|
if filters.get("schedule_type") != "cron":
|
||||||
|
continue
|
||||||
|
cron_expr = filters.get("cron_expression", "")
|
||||||
|
if not cron_expr:
|
||||||
|
continue
|
||||||
|
job_id = f"tracker_{tracker.id}"
|
||||||
|
if scheduler.get_job(job_id):
|
||||||
|
scheduler.remove_job(job_id)
|
||||||
|
try:
|
||||||
|
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
|
||||||
|
rescheduled += 1
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
_LOGGER.error(
|
||||||
|
"Failed to re-apply cron for tracker %d on tz change: %s",
|
||||||
|
tracker.id, e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actions with cron schedules.
|
||||||
|
actions = (await session.exec(
|
||||||
|
select(Action).where(Action.enabled == True) # noqa: E712
|
||||||
|
)).all()
|
||||||
|
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
for action in actions:
|
||||||
|
if action.schedule_type != "cron" or not action.schedule_cron:
|
||||||
|
continue
|
||||||
|
job_id = f"action_{action.id}"
|
||||||
|
if scheduler.get_job(job_id):
|
||||||
|
scheduler.remove_job(job_id)
|
||||||
|
try:
|
||||||
|
scheduler.add_job(
|
||||||
|
_run_action,
|
||||||
|
CronTrigger.from_crontab(action.schedule_cron, timezone=tz),
|
||||||
|
id=job_id,
|
||||||
|
args=[action.id],
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
rescheduled += 1
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
_LOGGER.error(
|
||||||
|
"Failed to re-apply cron for action %d on tz change: %s",
|
||||||
|
action.id, e,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER.info(
|
||||||
|
"Rescheduled %d cron job(s) for new app timezone %s", rescheduled, tz.key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Immich scheduled/periodic/memory jobs are also CronTrigger-based and
|
||||||
|
# carry the same frozen-tz problem — rebuild them under the new tz.
|
||||||
|
await reschedule_immich_dispatch_jobs()
|
||||||
|
|
||||||
|
|
||||||
async def _run_action(action_id: int) -> None:
|
async def _run_action(action_id: int) -> None:
|
||||||
"""Run an action (called by APScheduler)."""
|
"""Run an action (called by APScheduler)."""
|
||||||
from .action_runner import run_action
|
from .action_runner import run_action
|
||||||
@@ -605,6 +775,155 @@ async def _run_action(action_id: int) -> None:
|
|||||||
_LOGGER.error("Error running action %d: %s", action_id, e)
|
_LOGGER.error("Error running action %d: %s", action_id, e)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Immich scheduled / periodic / memory dispatch (cron-fired)
|
||||||
|
#
|
||||||
|
# These three slots fire on wall-clock schedules taken from the tracker's
|
||||||
|
# default ``TrackingConfig`` (``scheduled_times``, ``periodic_times``,
|
||||||
|
# ``memory_times`` — comma-separated ``HH:MM`` strings) interpreted in the
|
||||||
|
# app-level IANA timezone. The dispatch flow lives in
|
||||||
|
# ``services.scheduled_dispatch``; this section just owns scheduling.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_IMMICH_DISPATCH_KINDS = ("scheduled", "periodic", "memory")
|
||||||
|
_IMMICH_DISPATCH_PREFIX = "immich_dispatch_"
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_hhmm_list(raw: str) -> list[tuple[int, int]]:
|
||||||
|
"""Parse ``"09:00,18:30"`` → ``[(9, 0), (18, 30)]``, skipping bad entries.
|
||||||
|
|
||||||
|
A typo in one slot must not prevent the others from scheduling — we log
|
||||||
|
and move on rather than raising.
|
||||||
|
"""
|
||||||
|
out: list[tuple[int, int]] = []
|
||||||
|
for part in (raw or "").split(","):
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
h_str, m_str = part.split(":", 1)
|
||||||
|
hour, minute = int(h_str), int(m_str)
|
||||||
|
except ValueError:
|
||||||
|
_LOGGER.warning("Skipping invalid time literal %r", part)
|
||||||
|
continue
|
||||||
|
if not (0 <= hour <= 23 and 0 <= minute <= 59):
|
||||||
|
_LOGGER.warning("Skipping out-of-range time %r", part)
|
||||||
|
continue
|
||||||
|
out.append((hour, minute))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_immich_dispatch(tracker_id: int, kind: str) -> None:
|
||||||
|
"""APScheduler entry point — wraps the dispatch helper to swallow errors."""
|
||||||
|
from .scheduled_dispatch import dispatch_scheduled_for_tracker
|
||||||
|
try:
|
||||||
|
await dispatch_scheduled_for_tracker(tracker_id, kind) # type: ignore[arg-type]
|
||||||
|
except Exception as err: # noqa: BLE001
|
||||||
|
_LOGGER.error(
|
||||||
|
"Immich %s dispatch for tracker %d failed: %s", kind, tracker_id, err,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_immich_dispatch_jobs() -> None:
|
||||||
|
"""Schedule cron jobs for every (tracker, kind, time) where the kind is on.
|
||||||
|
|
||||||
|
Reads each enabled Immich tracker's *default* tracking config — per-link
|
||||||
|
overrides only gate dispatch (handled in ``scheduled_dispatch``), they do
|
||||||
|
not influence the fire schedule.
|
||||||
|
"""
|
||||||
|
from sqlmodel import select
|
||||||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
|
|
||||||
|
from ..database.engine import get_engine
|
||||||
|
from ..database.models import (
|
||||||
|
NotificationTracker,
|
||||||
|
ServiceProvider as ServiceProviderModel,
|
||||||
|
TrackingConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = get_engine()
|
||||||
|
scheduler = get_scheduler()
|
||||||
|
tz = await _load_app_timezone()
|
||||||
|
|
||||||
|
async with AsyncSession(engine) as session:
|
||||||
|
trackers = (await session.exec(
|
||||||
|
select(NotificationTracker).where(NotificationTracker.enabled == True) # noqa: E712
|
||||||
|
)).all()
|
||||||
|
if not trackers:
|
||||||
|
return
|
||||||
|
|
||||||
|
provider_ids = list({t.provider_id for t in trackers})
|
||||||
|
provider_types: dict[int, str] = {}
|
||||||
|
if provider_ids:
|
||||||
|
rows = await session.exec(
|
||||||
|
select(ServiceProviderModel).where(
|
||||||
|
ServiceProviderModel.id.in_(provider_ids)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider_types = {p.id: p.type for p in rows.all()}
|
||||||
|
|
||||||
|
tc_ids = list({
|
||||||
|
t.default_tracking_config_id for t in trackers
|
||||||
|
if t.default_tracking_config_id
|
||||||
|
})
|
||||||
|
tc_map: dict[int, TrackingConfig] = {}
|
||||||
|
if tc_ids:
|
||||||
|
rows = await session.exec(
|
||||||
|
select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
|
||||||
|
)
|
||||||
|
tc_map = {tc.id: tc for tc in rows.all()}
|
||||||
|
|
||||||
|
scheduled = 0
|
||||||
|
for tracker in trackers:
|
||||||
|
if provider_types.get(tracker.provider_id) != "immich":
|
||||||
|
continue
|
||||||
|
tc = tc_map.get(tracker.default_tracking_config_id) if tracker.default_tracking_config_id else None
|
||||||
|
if tc is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for kind in _IMMICH_DISPATCH_KINDS:
|
||||||
|
if not getattr(tc, f"{kind}_enabled", False):
|
||||||
|
continue
|
||||||
|
times_raw = getattr(tc, f"{kind}_times", "") or ""
|
||||||
|
for hour, minute in _parse_hhmm_list(times_raw):
|
||||||
|
job_id = f"{_IMMICH_DISPATCH_PREFIX}{kind}_{tracker.id}_{hour:02d}{minute:02d}"
|
||||||
|
scheduler.add_job(
|
||||||
|
_run_immich_dispatch,
|
||||||
|
CronTrigger(hour=hour, minute=minute, timezone=tz),
|
||||||
|
id=job_id,
|
||||||
|
args=[tracker.id, kind],
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1,
|
||||||
|
)
|
||||||
|
scheduled += 1
|
||||||
|
_LOGGER.info(
|
||||||
|
"Scheduled Immich %s for tracker %d at %02d:%02d [tz=%s]",
|
||||||
|
kind, tracker.id, hour, minute, tz.key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if scheduled:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Loaded %d Immich scheduled/periodic/memory job(s) [tz=%s]",
|
||||||
|
scheduled, tz.key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def reschedule_immich_dispatch_jobs() -> None:
|
||||||
|
"""Drop and rebuild all Immich scheduled/periodic/memory jobs.
|
||||||
|
|
||||||
|
Cheap to call on every relevant mutation — a typical install has only a
|
||||||
|
handful of trackers. Called from the tracker, link, and tracking-config
|
||||||
|
CRUD endpoints, and from ``reschedule_cron_jobs_for_timezone_change``.
|
||||||
|
"""
|
||||||
|
scheduler = get_scheduler()
|
||||||
|
for job in list(scheduler.get_jobs()):
|
||||||
|
if job.id.startswith(_IMMICH_DISPATCH_PREFIX):
|
||||||
|
scheduler.remove_job(job.id)
|
||||||
|
await _load_immich_dispatch_jobs()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Scheduled backup
|
# Scheduled backup
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -11,11 +11,13 @@ CommandTrackerListeners with enabled CommandTrackers.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from notify_bridge_core.log_context import bind_log_context
|
||||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||||
|
|
||||||
from ..database.engine import get_engine
|
from ..database.engine import get_engine
|
||||||
@@ -125,7 +127,14 @@ async def stop_bot_if_unused(bot_id: int) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def schedule_bot_polling(bot_id: int) -> None:
|
def schedule_bot_polling(bot_id: int) -> None:
|
||||||
"""Add a polling job for a bot (idempotent)."""
|
"""Add a polling job for a bot (idempotent).
|
||||||
|
|
||||||
|
We schedule at a 30 s interval, but each tick calls ``getUpdates`` with
|
||||||
|
``timeout=25`` — Telegram holds the connection open until either an
|
||||||
|
update arrives or the timeout elapses, so in practice the bot streams
|
||||||
|
updates with sub-second latency while consuming ~2 API calls / minute
|
||||||
|
per bot (down from 20 under the old 3 s short-poll).
|
||||||
|
"""
|
||||||
scheduler = get_scheduler()
|
scheduler = get_scheduler()
|
||||||
job_id = f"telegram_poll_{bot_id}"
|
job_id = f"telegram_poll_{bot_id}"
|
||||||
if scheduler.get_job(job_id):
|
if scheduler.get_job(job_id):
|
||||||
@@ -133,13 +142,13 @@ def schedule_bot_polling(bot_id: int) -> None:
|
|||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
_poll_bot,
|
_poll_bot,
|
||||||
"interval",
|
"interval",
|
||||||
seconds=3,
|
seconds=30,
|
||||||
id=job_id,
|
id=job_id,
|
||||||
args=[bot_id],
|
args=[bot_id],
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
max_instances=1,
|
max_instances=1,
|
||||||
)
|
)
|
||||||
_LOGGER.info("Started polling for bot %d", bot_id)
|
_LOGGER.info("Started polling for bot %d (long-poll, 25s timeout)", bot_id)
|
||||||
|
|
||||||
|
|
||||||
def unschedule_bot_polling(bot_id: int) -> None:
|
def unschedule_bot_polling(bot_id: int) -> None:
|
||||||
@@ -231,8 +240,10 @@ async def _poll_bot(bot_id: int) -> None:
|
|||||||
from .http_session import get_http_session
|
from .http_session import get_http_session
|
||||||
http = await get_http_session()
|
http = await get_http_session()
|
||||||
client = TelegramClient(http, bot_token)
|
client = TelegramClient(http, bot_token)
|
||||||
|
# Long-poll: hold connection open until an update arrives or 25 s
|
||||||
|
# elapse. Drastically cuts API calls vs. 3 s short-poll.
|
||||||
result = await client.get_updates(
|
result = await client.get_updates(
|
||||||
offset=offset + 1 if offset else None, limit=50,
|
offset=offset + 1 if offset else None, limit=50, timeout=25,
|
||||||
)
|
)
|
||||||
if not result.get("success"):
|
if not result.get("success"):
|
||||||
err_text = str(result.get("error") or "")
|
err_text = str(result.get("error") or "")
|
||||||
@@ -289,6 +300,18 @@ async def _poll_bot(bot_id: int) -> None:
|
|||||||
|
|
||||||
# Dispatch commands (only if chat has commands enabled)
|
# Dispatch commands (only if chat has commands enabled)
|
||||||
if text and text.startswith("/"):
|
if text and text.startswith("/"):
|
||||||
|
from ..commands.parser import parse_command
|
||||||
|
cmd_name, _, _ = parse_command(text)
|
||||||
|
update_id = update.get("update_id")
|
||||||
|
message_id = message.get("message_id")
|
||||||
|
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||||
|
with bind_log_context(
|
||||||
|
request_id=request_id,
|
||||||
|
command=cmd_name or "-",
|
||||||
|
chat_id=chat_id,
|
||||||
|
bot_id=bot_obj.id,
|
||||||
|
):
|
||||||
|
started = time.monotonic()
|
||||||
try:
|
try:
|
||||||
async with AsyncSession(engine) as cmd_session:
|
async with AsyncSession(engine) as cmd_session:
|
||||||
chat_row = (await cmd_session.exec(
|
chat_row = (await cmd_session.exec(
|
||||||
@@ -298,20 +321,43 @@ async def _poll_bot(bot_id: int) -> None:
|
|||||||
)
|
)
|
||||||
)).first()
|
)).first()
|
||||||
if not chat_row or not chat_row.commands_enabled:
|
if not chat_row or not chat_row.commands_enabled:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command ignored — commands disabled (poll) for bot=%s chat=%s",
|
||||||
|
bot_obj.id, chat_id,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
effective_lang = chat_row.language_override or msg_language
|
effective_lang = chat_row.language_override or msg_language
|
||||||
message_id = message.get("message_id")
|
_LOGGER.info("Command received (poll): /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||||
async with telegram_chat_action(
|
async with telegram_chat_action(
|
||||||
bot_token, chat_id, classify_command_chat_action(text),
|
bot_token, chat_id, classify_command_chat_action(text),
|
||||||
):
|
):
|
||||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||||
if responses:
|
if not responses:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command produced no response (cmd=%r, poll) after %.0f ms",
|
||||||
|
cmd_name, (time.monotonic() - started) * 1000,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
text_count = sum(1 for r in responses if r.text)
|
||||||
|
media_count = sum(len(r.media or []) for r in responses)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||||
|
len(responses), text_count, media_count,
|
||||||
|
)
|
||||||
for resp in responses:
|
for resp in responses:
|
||||||
if resp.text:
|
if resp.text:
|
||||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||||
if resp.media:
|
if resp.media:
|
||||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||||
|
_LOGGER.info(
|
||||||
|
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||||
|
cmd_name, (time.monotonic() - started) * 1000,
|
||||||
|
len(responses), media_count,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
_LOGGER.exception(
|
||||||
|
"Error handling command /%s from bot %d after %.0f ms",
|
||||||
|
cmd_name, bot_id, (time.monotonic() - started) * 1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
|||||||
name=provider_name,
|
name=provider_name,
|
||||||
tracker_name=tracker_name,
|
tracker_name=tracker_name,
|
||||||
custom_variables=custom_vars,
|
custom_variables=custom_vars,
|
||||||
|
timezone_name=app_tz,
|
||||||
)
|
)
|
||||||
events, new_state = await sched.poll(collection_ids, state_dict)
|
events, new_state = await sched.poll(collection_ids, state_dict)
|
||||||
elif provider_type == "nut":
|
elif provider_type == "nut":
|
||||||
@@ -317,6 +318,26 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
|||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
assets_count = event.added_count or event.removed_count or 0
|
assets_count = event.added_count or event.removed_count or 0
|
||||||
|
details: dict[str, Any] = {
|
||||||
|
"added_count": event.added_count,
|
||||||
|
"removed_count": event.removed_count,
|
||||||
|
"provider_type": event.provider_type.value,
|
||||||
|
}
|
||||||
|
# Scheduler/periodic events carry the schedule context in ``extra``
|
||||||
|
# (cron expression, interval, timezone, fire count). Surface that
|
||||||
|
# in the event log so the dashboard and audit queries can show
|
||||||
|
# *why* the event fired, not just that it did.
|
||||||
|
if event.event_type.value == "scheduled_message":
|
||||||
|
sched_type = tracker_filters.get("schedule_type", "interval")
|
||||||
|
details["schedule_type"] = sched_type
|
||||||
|
if sched_type == "cron":
|
||||||
|
details["cron_expression"] = tracker_filters.get("cron_expression", "")
|
||||||
|
else:
|
||||||
|
details["interval_seconds"] = tracker.scan_interval
|
||||||
|
details["timezone"] = app_tz
|
||||||
|
fire_count = event.extra.get("fire_count") if event.extra else None
|
||||||
|
if fire_count is not None:
|
||||||
|
details["fire_count"] = fire_count
|
||||||
log = EventLog(
|
log = EventLog(
|
||||||
user_id=tracker.user_id,
|
user_id=tracker.user_id,
|
||||||
tracker_id=tracker_id,
|
tracker_id=tracker_id,
|
||||||
@@ -327,11 +348,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
|||||||
collection_id=event.collection_id,
|
collection_id=event.collection_id,
|
||||||
collection_name=event.collection_name,
|
collection_name=event.collection_name,
|
||||||
assets_count=assets_count,
|
assets_count=assets_count,
|
||||||
details={
|
details=details,
|
||||||
"added_count": event.added_count,
|
|
||||||
"removed_count": event.removed_count,
|
|
||||||
"provider_type": event.provider_type.value,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
session.add(log)
|
session.add(log)
|
||||||
|
|
||||||
@@ -352,7 +369,13 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
|||||||
|
|
||||||
if events and link_data:
|
if events and link_data:
|
||||||
url_cache, asset_cache = await _get_telegram_caches()
|
url_cache, asset_cache = await _get_telegram_caches()
|
||||||
dispatcher = NotificationDispatcher(url_cache=url_cache, asset_cache=asset_cache)
|
from .http_session import get_http_session
|
||||||
|
shared_session = await get_http_session()
|
||||||
|
dispatcher = NotificationDispatcher(
|
||||||
|
url_cache=url_cache,
|
||||||
|
asset_cache=asset_cache,
|
||||||
|
session=shared_session,
|
||||||
|
)
|
||||||
for event in events:
|
for event in events:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"Dispatching event %s for %s (added=%d removed=%d)",
|
"Dispatching event %s for %s (added=%d removed=%d)",
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
"""Shared pytest fixtures.
|
||||||
|
|
||||||
|
We set the required env vars before any ``notify_bridge_server`` module is
|
||||||
|
imported so ``Settings()`` passes its startup validation and opens the DB
|
||||||
|
in a writable temp directory instead of the production ``/data`` default.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Provision a writable temp data dir BEFORE the server package is imported —
|
||||||
|
# Settings() materializes at import time, so env-var overrides have to land
|
||||||
|
# here (conftest.py) to be effective.
|
||||||
|
_TMP = Path(tempfile.mkdtemp(prefix="notify-bridge-tests-"))
|
||||||
|
os.environ["NOTIFY_BRIDGE_DATA_DIR"] = str(_TMP)
|
||||||
|
|
||||||
|
os.environ.setdefault(
|
||||||
|
"NOTIFY_BRIDGE_SECRET_KEY",
|
||||||
|
"pytest-secret-key-" + "x" * 40,
|
||||||
|
)
|
||||||
|
os.environ.setdefault("NOTIFY_BRIDGE_DEBUG", "false")
|
||||||
|
os.environ.setdefault("NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS", "http://localhost:8420")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def tmp_data_dir() -> Path:
|
||||||
|
"""Expose the already-provisioned temp dir to tests that want the path."""
|
||||||
|
return _TMP
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""Unit tests for the Settings validator."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_server.config import Settings, _FORBIDDEN_SECRETS
|
||||||
|
|
||||||
|
|
||||||
|
_GOOD = "a" * 40
|
||||||
|
|
||||||
|
|
||||||
|
def _make(**kwargs) -> Settings:
|
||||||
|
defaults = dict(secret_key=_GOOD, cors_allowed_origins="http://localhost:8420")
|
||||||
|
defaults.update(kwargs)
|
||||||
|
return Settings(**defaults)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSecretKey:
|
||||||
|
def test_accepts_long_random_key(self) -> None:
|
||||||
|
_make()
|
||||||
|
|
||||||
|
def test_rejects_default(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="SECURITY"):
|
||||||
|
_make(secret_key="change-me-in-production")
|
||||||
|
|
||||||
|
def test_rejects_known_dev_keys(self) -> None:
|
||||||
|
for bad in _FORBIDDEN_SECRETS:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_make(secret_key=bad)
|
||||||
|
|
||||||
|
def test_rejects_short_key(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="32 characters"):
|
||||||
|
_make(secret_key="short")
|
||||||
|
|
||||||
|
|
||||||
|
class TestCors:
|
||||||
|
def test_rejects_wildcard(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="wildcard"):
|
||||||
|
_make(cors_allowed_origins="*")
|
||||||
|
|
||||||
|
def test_rejects_missing_scheme(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="scheme"):
|
||||||
|
_make(cors_allowed_origins="example.com")
|
||||||
|
|
||||||
|
def test_accepts_multiple(self) -> None:
|
||||||
|
cfg = _make(cors_allowed_origins="http://localhost:8420,https://example.com")
|
||||||
|
assert "http://localhost:8420" in cfg.cors_allowed_origins
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumericValidation:
|
||||||
|
def test_rejects_zero_access_token_expiry(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_make(access_token_expire_minutes=0)
|
||||||
|
|
||||||
|
def test_rejects_invalid_port(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_make(port=0)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_make(port=70000)
|
||||||
|
|
||||||
|
def test_rejects_negative_retention(self) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_make(event_log_retention_days=-1)
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Discord client 429-retry bounding."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aioresponses import aioresponses
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||||
|
|
||||||
|
|
||||||
|
WEBHOOK = "https://discord.com/api/webhooks/123/abc"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bounded_retries_on_persistent_429() -> None:
|
||||||
|
"""If every response is 429, the client gives up after _MAX_RETRIES."""
|
||||||
|
with aioresponses() as mocked:
|
||||||
|
mocked.post(WEBHOOK, status=429, headers={"Retry-After": "0.001"}, repeat=True)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = DiscordClient(sess)
|
||||||
|
result = await client.send(WEBHOOK, "hello")
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
# Either the custom "Rate limited" message or the bare HTTP 429 from the
|
||||||
|
# final attempt — both indicate bounded retries without infinite recursion.
|
||||||
|
assert "429" in result["error"] or "Rate limited" in result["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_caps_retry_after() -> None:
|
||||||
|
"""A malicious Retry-After: 99999 must not pin the task for hours."""
|
||||||
|
with aioresponses() as mocked:
|
||||||
|
# First call: absurd Retry-After. Second call: success.
|
||||||
|
mocked.post(WEBHOOK, status=429, headers={"Retry-After": "99999"})
|
||||||
|
mocked.post(WEBHOOK, status=204)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as sess:
|
||||||
|
client = DiscordClient(sess)
|
||||||
|
# Override the cap to something trivial so the test completes fast.
|
||||||
|
client._MAX_RETRY_AFTER = 0.001 # type: ignore[attr-defined]
|
||||||
|
result = await client.send(WEBHOOK, "hello")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
"""Smoke test: app imports, /api/health returns 200, version string present."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_endpoint(tmp_data_dir) -> None: # noqa: ARG001 — fixture applies env
|
||||||
|
from notify_bridge_server.main import app
|
||||||
|
|
||||||
|
# TestClient runs the lifespan on enter/exit, so migrations run once
|
||||||
|
# against the temp data dir — a genuine integration smoke check.
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["status"] == "ok"
|
||||||
|
assert body["version"] != "0.0.0+unknown"
|
||||||
|
assert "." in body["version"] # looks like a real version
|
||||||
|
|
||||||
|
|
||||||
|
def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001
|
||||||
|
from notify_bridge_server.main import app
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/ready")
|
||||||
|
# By the time TestClient yields, lifespan startup has completed.
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "ready"
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_is_anonymous(tmp_data_dir) -> None: # noqa: ARG001
|
||||||
|
"""/api/health must not require auth — the Docker healthcheck depends on it."""
|
||||||
|
from notify_bridge_server.main import app
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
@@ -0,0 +1,74 @@
|
|||||||
|
"""JWT encode/decode round-trips."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_server.auth.jwt import (
|
||||||
|
ALGORITHM,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
decode_token,
|
||||||
|
)
|
||||||
|
from notify_bridge_server.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def test_access_token_round_trip() -> None:
|
||||||
|
token = create_access_token(user_id=1, role="admin", token_version=3)
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert payload["sub"] == "1"
|
||||||
|
assert payload["type"] == "access"
|
||||||
|
assert payload["role"] == "admin"
|
||||||
|
assert payload["ver"] == 3
|
||||||
|
assert payload["iss"] == settings.jwt_issuer
|
||||||
|
assert payload["aud"] == settings.jwt_audience
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_token_round_trip() -> None:
|
||||||
|
token = create_refresh_token(user_id=7, token_version=2)
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert payload["type"] == "refresh"
|
||||||
|
assert payload["sub"] == "7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_rejects_wrong_audience() -> None:
|
||||||
|
"""A token signed with our key but for a different audience is rejected."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
forged = pyjwt.encode(
|
||||||
|
{
|
||||||
|
"iss": settings.jwt_issuer,
|
||||||
|
"aud": "other-service",
|
||||||
|
"sub": "1",
|
||||||
|
"type": "access",
|
||||||
|
"ver": 1,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(minutes=5),
|
||||||
|
},
|
||||||
|
settings.secret_key,
|
||||||
|
algorithm=ALGORITHM,
|
||||||
|
)
|
||||||
|
with pytest.raises(pyjwt.InvalidAudienceError):
|
||||||
|
decode_token(forged)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_rejects_none_alg() -> None:
|
||||||
|
"""An ``alg: none`` token must never be accepted."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
forged = pyjwt.encode(
|
||||||
|
{
|
||||||
|
"iss": settings.jwt_issuer,
|
||||||
|
"aud": settings.jwt_audience,
|
||||||
|
"sub": "1",
|
||||||
|
"type": "access",
|
||||||
|
"ver": 1,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + timedelta(minutes=5),
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
algorithm="none",
|
||||||
|
)
|
||||||
|
with pytest.raises(pyjwt.InvalidAlgorithmError):
|
||||||
|
decode_token(forged)
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
"""Pre-migration snapshot: atomic copy + retention pruning."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
from notify_bridge_server.database.snapshot import (
|
||||||
|
prune_old_snapshots,
|
||||||
|
snapshot_and_prune,
|
||||||
|
snapshot_database,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def sqlite_engine(tmp_path: Path):
|
||||||
|
"""Tiny SQLite DB with one table + one row, closed cleanly after the test."""
|
||||||
|
db_path = tmp_path / "app.db"
|
||||||
|
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)"))
|
||||||
|
await conn.execute(text("INSERT INTO t (v) VALUES ('seed')"))
|
||||||
|
yield engine, db_path, tmp_path / "backups"
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
class TestSnapshot:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_creates_consistent_copy(self, sqlite_engine) -> None:
|
||||||
|
engine, _db, backups = sqlite_engine
|
||||||
|
dest = await snapshot_database(engine, backups)
|
||||||
|
assert dest is not None
|
||||||
|
assert dest.exists()
|
||||||
|
# Can open the snapshot and see the seed row — proves it's a real DB copy.
|
||||||
|
copy = create_async_engine(f"sqlite+aiosqlite:///{dest}")
|
||||||
|
async with copy.connect() as c:
|
||||||
|
result = await c.execute(text("SELECT v FROM t"))
|
||||||
|
rows = result.all()
|
||||||
|
await copy.dispose()
|
||||||
|
assert rows == [("seed",)]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_skips_when_db_missing(self, tmp_path: Path) -> None:
|
||||||
|
# Engine pointing at a path that doesn't exist yet.
|
||||||
|
engine = create_async_engine(
|
||||||
|
f"sqlite+aiosqlite:///{tmp_path / 'does-not-exist.db'}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
dest = await snapshot_database(engine, tmp_path / "backups")
|
||||||
|
finally:
|
||||||
|
await engine.dispose()
|
||||||
|
assert dest is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_unsafe_label(self, sqlite_engine) -> None:
|
||||||
|
engine, _db, backups = sqlite_engine
|
||||||
|
dest = await snapshot_database(engine, backups, label="bad'; DROP TABLE t;--")
|
||||||
|
assert dest is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrune:
|
||||||
|
def _make_snapshot(self, backups: Path, age_seconds: int) -> Path:
|
||||||
|
backups.mkdir(parents=True, exist_ok=True)
|
||||||
|
ts = datetime.now(timezone.utc) - timedelta(seconds=age_seconds)
|
||||||
|
name = f"pre-migrate-{ts.strftime('%Y-%m-%dT%H-%M-%S')}.db"
|
||||||
|
p = backups / name
|
||||||
|
p.write_bytes(b"x")
|
||||||
|
mtime = ts.timestamp()
|
||||||
|
import os
|
||||||
|
os.utime(p, (mtime, mtime))
|
||||||
|
return p
|
||||||
|
|
||||||
|
def test_keeps_n_newest(self, tmp_path: Path) -> None:
|
||||||
|
backups = tmp_path / "backups"
|
||||||
|
for age in (100, 80, 60, 40, 20, 0):
|
||||||
|
self._make_snapshot(backups, age)
|
||||||
|
|
||||||
|
deleted = prune_old_snapshots(backups, keep=3)
|
||||||
|
remaining = sorted(backups.glob("pre-migrate-*.db"))
|
||||||
|
assert len(deleted) == 3
|
||||||
|
assert len(remaining) == 3
|
||||||
|
|
||||||
|
def test_keep_zero_deletes_all(self, tmp_path: Path) -> None:
|
||||||
|
backups = tmp_path / "backups"
|
||||||
|
for age in (30, 20, 10):
|
||||||
|
self._make_snapshot(backups, age)
|
||||||
|
prune_old_snapshots(backups, keep=0)
|
||||||
|
assert list(backups.glob("pre-migrate-*.db")) == []
|
||||||
|
|
||||||
|
def test_missing_dir_is_noop(self, tmp_path: Path) -> None:
|
||||||
|
assert prune_old_snapshots(tmp_path / "never-created", keep=5) == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSnapshotAndPrune:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keep_zero_disables(self, sqlite_engine) -> None:
|
||||||
|
engine, _db, backups = sqlite_engine
|
||||||
|
result = await snapshot_and_prune(engine, backups, keep=0)
|
||||||
|
assert result is None
|
||||||
|
assert not backups.exists() or list(backups.glob("*.db")) == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_to_end(self, sqlite_engine) -> None:
|
||||||
|
engine, _db, backups = sqlite_engine
|
||||||
|
# Run twice — second run should keep both snapshots (keep=5).
|
||||||
|
a = await snapshot_and_prune(engine, backups, keep=5)
|
||||||
|
# Guarantee distinct filenames (timestamp has second resolution).
|
||||||
|
await asyncio.sleep(1.05)
|
||||||
|
b = await snapshot_and_prune(engine, backups, keep=5)
|
||||||
|
assert a and b and a != b
|
||||||
|
assert a.exists() and b.exists()
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
"""SSRF guard regression tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from notify_bridge_core.notifications.ssrf import (
|
||||||
|
UnsafeURLError,
|
||||||
|
avalidate_outbound_url,
|
||||||
|
validate_outbound_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestScheme:
|
||||||
|
def test_rejects_file_scheme(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url("file:///etc/passwd")
|
||||||
|
|
||||||
|
def test_rejects_gopher(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url("gopher://example.com/")
|
||||||
|
|
||||||
|
def test_accepts_https(self) -> None:
|
||||||
|
# A well-known public host — validated via real DNS so this test is
|
||||||
|
# skipped when offline.
|
||||||
|
try:
|
||||||
|
validate_outbound_url("https://example.com/")
|
||||||
|
except UnsafeURLError as err:
|
||||||
|
if "DNS" in str(err):
|
||||||
|
pytest.skip("No DNS in test environment")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlockedRanges:
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"url",
|
||||||
|
[
|
||||||
|
"http://127.0.0.1/",
|
||||||
|
"http://10.0.0.1/",
|
||||||
|
"http://192.168.1.1/",
|
||||||
|
"http://169.254.169.254/latest/meta-data/",
|
||||||
|
"http://[::1]/",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_rejects_literal_private(self, url: str) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
validate_outbound_url(url)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncValidator:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_rejects_loopback(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
await avalidate_outbound_url("http://127.0.0.1/")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_rejects_bad_scheme(self) -> None:
|
||||||
|
with pytest.raises(UnsafeURLError):
|
||||||
|
await avalidate_outbound_url("file:///etc/passwd")
|
||||||
@@ -24,7 +24,7 @@ fi
|
|||||||
|
|
||||||
# Start backend
|
# Start backend
|
||||||
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
||||||
export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
|
export NOTIFY_BRIDGE_SECRET_KEY=dev-only-pwIOUsKmfn4CYWQ9hCRs5lmI3GgrVlXSu2nqFzGW
|
||||||
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
|
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
|
||||||
# guard rejects private addresses by default, which would make trackers fail.
|
# guard rejects private addresses by default, which would make trackers fail.
|
||||||
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||||
|
|||||||
Reference in New Issue
Block a user