Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 90def11b8d | |||
| 8f0346ea03 | |||
| a6a854ad21 | |||
| 19036a90bb | |||
| 592e1b6114 | |||
| bbcdf1c5d1 | |||
| f9040370bc | |||
| 3b683ce82c | |||
| 2bec25353b | |||
| e44d387c7f | |||
| 7cbb02b1ef | |||
| 920920bc67 | |||
| f50d465c0e | |||
| 1f880daa0c | |||
| 1024085cdd | |||
| 5604c733d1 | |||
| 3b7808aa9c |
@@ -1,13 +1,56 @@
|
||||
name: Build Docker Image
|
||||
name: Build and Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master, main]
|
||||
pull_request:
|
||||
branches: [master, main]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
test-frontend:
|
||||
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build Docker image
|
||||
run: docker build -t notify-bridge:dev .
|
||||
- name: Set up Node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install deps
|
||||
run: |
|
||||
cd frontend
|
||||
npm ci
|
||||
|
||||
- name: Svelte check
|
||||
run: |
|
||||
cd frontend
|
||||
npm run check || echo "::warning::svelte-check reported warnings"
|
||||
|
||||
- name: Build
|
||||
run: |
|
||||
cd frontend
|
||||
npm run build
|
||||
|
||||
build-image:
|
||||
if: ${{ !startsWith(gitea.event.head_commit.message, 'chore: release v') }}
|
||||
needs: [test-frontend]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Build Docker image (no push)
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
push: false
|
||||
tags: notify-bridge:ci-${{ gitea.sha }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
|
||||
@@ -50,6 +50,7 @@ jobs:
|
||||
tags: |
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.tag }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}
|
||||
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:sha-${{ gitea.sha }}
|
||||
${{ steps.version.outputs.is_pre == 'false' && format('{0}/{1}:latest', env.REGISTRY, env.IMAGE_NAME) || '' }}
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
||||
|
||||
+48
-4
@@ -1,3 +1,4 @@
|
||||
# syntax=docker/dockerfile:1.7
|
||||
# =============================================================================
|
||||
# Stage 1: Build frontend (SvelteKit static output)
|
||||
# =============================================================================
|
||||
@@ -14,7 +15,7 @@ COPY frontend/ ./
|
||||
RUN npm run build
|
||||
|
||||
# =============================================================================
|
||||
# Stage 2: Build Python wheels
|
||||
# Stage 2: Build Python wheels + extract external dependency list
|
||||
# =============================================================================
|
||||
FROM python:3.12-slim AS python-build
|
||||
|
||||
@@ -30,16 +31,59 @@ RUN python -m build packages/core/ --wheel --outdir /wheels
|
||||
COPY packages/server/ packages/server/
|
||||
RUN python -m build packages/server/ --wheel --outdir /wheels
|
||||
|
||||
# Emit /wheels/deps.txt with ONLY external (PyPI) deps — filter out
|
||||
# notify-bridge-* siblings, which are installed from local wheels below.
|
||||
# This file is the cache key for the external-deps install layer: as long as
|
||||
# pyproject.toml dependency lines don't change, the runtime install layer is
|
||||
# served from registry buildcache and no wheels are re-downloaded.
|
||||
RUN python <<'PY'
|
||||
import tomllib
|
||||
|
||||
deps: list[str] = []
|
||||
for p in ("packages/core/pyproject.toml", "packages/server/pyproject.toml"):
|
||||
with open(p, "rb") as f:
|
||||
data = tomllib.load(f)
|
||||
for d in data["project"].get("dependencies", []):
|
||||
if not d.lstrip().lower().startswith("notify-bridge-"):
|
||||
deps.append(d)
|
||||
|
||||
seen: set[str] = set()
|
||||
with open("/wheels/deps.txt", "w") as f:
|
||||
for d in deps:
|
||||
if d not in seen:
|
||||
seen.add(d)
|
||||
f.write(d + "\n")
|
||||
PY
|
||||
|
||||
# =============================================================================
|
||||
# Stage 3: Runtime
|
||||
# =============================================================================
|
||||
FROM python:3.12-slim
|
||||
|
||||
# uv — fast pip replacement. Installed from PyPI (Fastly CDN) rather than
|
||||
# ghcr.io/astral-sh/uv, because GHCR pulls from this runner crawl at a few
|
||||
# hundred KB/s and take longer than the install savings would recoup.
|
||||
RUN pip install --no-cache-dir uv==0.11.7
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 \
|
||||
UV_LINK_MODE=copy
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install wheels
|
||||
COPY --from=python-build /wheels/ /tmp/wheels/
|
||||
RUN pip install --no-cache-dir /tmp/wheels/*.whl && rm -rf /tmp/wheels
|
||||
# Install external deps first — layer cache key is deps.txt content, which
|
||||
# only changes when pyproject.toml dependency lines change (not on version
|
||||
# bumps). The cache mount persists downloaded wheels across local rebuilds;
|
||||
# in CI, the registry buildcache serves the whole layer when unchanged.
|
||||
COPY --from=python-build /wheels/deps.txt /tmp/deps.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r /tmp/deps.txt \
|
||||
&& rm /tmp/deps.txt
|
||||
|
||||
# Install local wheels without re-resolving — all external deps are present.
|
||||
COPY --from=python-build /wheels/*.whl /tmp/wheels/
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-deps /tmp/wheels/*.whl \
|
||||
&& rm -rf /tmp/wheels
|
||||
|
||||
# Copy frontend build
|
||||
COPY --from=frontend-build /build/build/ /app/static/
|
||||
|
||||
+57
-24
@@ -1,36 +1,69 @@
|
||||
# v0.3.0 (2026-04-22)
|
||||
# v0.4.0 (2026-04-23)
|
||||
|
||||
Major polling perf overhaul for large Immich libraries plus a UX fix for
|
||||
slow bot commands. Combined impact on idle albums: per-tick cost drops
|
||||
from ~150 MB fetched to a few hundred bytes; active albums now fetch
|
||||
O(changes) instead of O(library). Tested against a ~200k-asset library.
|
||||
|
||||
**Schema change:** adds a `meta_fingerprint` JSON column to
|
||||
`notification_tracker_state` — applied automatically by the startup
|
||||
migration, no manual step required.
|
||||
|
||||
## Performance
|
||||
|
||||
- **Skip full album fetch on idle ticks** — new `ImmichAlbumMeta` + `get_album_meta()` probe using `?withoutAssets=true` as a cheap change-detection fingerprint. When the fingerprint matches and no pending assets are outstanding, `poll()` short-circuits and does no asset fetch at all. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Delta-fetch active albums** — when the fingerprint changes, poll with `updatedAfter` instead of refetching the whole album; falls back to a full fetch only on count decrease or mixed add+remove that delta can't reconcile. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Parallel meta probes** — `asyncio.gather` over album meta probes so a 20-album tracker pays one round-trip of latency instead of 20. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Tick-scoped shared-links cache** — new `get_all_shared_links_by_album()` coalesces to one `/api/shared-links` request per tick instead of one per changed album. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Module-level users cache** — 1 h TTL, sha256-keyed, shared across providers that target the same Immich server. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Skip `asset_ids` DB rewrite on idle ticks** — watcher no longer rewrites the (potentially ~8 MB for huge albums) JSON column when the fingerprint didn't change. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Adaptive polling** — after 10 empty ticks the scheduler skips 1-in-2, after 30 empty ticks skips 1-in-4; resets on the first detected change or any schedule edit. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **APScheduler jitter** — `interval/4`, capped at 30 s, to smooth thundering-herd bursts when many trackers share the same `scan_interval`. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
- **Event payload cap** — 50 added / 200 removed assets per event so a bulk import can't explode a Jinja template or exceed Telegram message limits. ([fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20))
|
||||
A production-readiness release focused on hardening the service for real-world deployment: end-to-end structured logging with runtime controls, a broad security and runtime review across the HTTP, auth, DB, and scheduler layers, and a new pre-migration database snapshot that makes upgrades recoverable with a single file restore. Release CI and the Docker image build were also reworked for speed and reliability.
|
||||
|
||||
## Features
|
||||
|
||||
- **Chat-action hint stays alive during slow command fetches** — Telegram chat actions expire after ~5 s, so slow bot commands (`/latest`, `/random`, `/favorites`, `/memory`, `/search`, `/find`, `/person`, `/place`, `/summary`) previously showed a hint that vanished long before the media arrived and users saw nothing happening. New `telegram_chat_action` async context manager starts a keep-alive task that re-posts the action every 4 s until it exits; `classify_command_chat_action` maps each command to the right action (`upload_photo` for media-returning commands, `typing` for `/summary`, none for fast DB-only commands like `/status` / `/events`). Wired into both the webhook and long-poll paths. ([69711bb](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/69711bb))
|
||||
- **Production-grade logging** with per-request correlation (`request_id` / `command` / `chat_id` / `bot_id` / `dispatch_id`), secret masking in both messages and tracebacks, JSON or text format, runtime log level + per-module overrides editable from the settings UI, and env-var boot overrides (`NOTIFY_BRIDGE_LOG_LEVEL` / `_FORMAT` / `_LEVELS`). Closes every silent drop in the Telegram send path — `/random` and media-group failures now log `WARN` / `ERROR` with full context instead of disappearing ([f50d465](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/f50d465))
|
||||
- **Production-readiness hardening across security, async, DB, and ops** ([920920b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/920920b)):
|
||||
- *Security:* async SSRF-safe DNS resolver; `allow_redirects=False` on all outbound clients; Matrix `homeserver_url` validation; rejection of `***`-masked secrets on provider / email-bot updates; bcrypt moved off the event loop; JWT `iss` / `aud` + leeway with strict claim rejection; setup TOCTOU closed inside a transaction; expanded rate limits; constant-time login; config rejects known dev secret keys and validates CORS / ports / token lifetimes; webhook bodies capped at 1 MiB; Discord 429 retries bounded; CSP + HSTS headers added.
|
||||
- *Async / runtime:* SQLite engine tuned (WAL, `synchronous=NORMAL`, `foreign_keys=ON`, busy timeout, pool pre-ping); ordered lifespan shutdown; shared `aiohttp` session race-free; blocking storage / backup writes offloaded to threads; NUT client timeouts; Telegram poller switched from 3 s short-poll to 30 s + 25 s long-poll (~10x fewer API calls).
|
||||
- *Database:* new performance-index migration covering every FK and hot-path composite; new `schema_version` table; `__system__` placeholder user (`id=0`) seeded to satisfy FKs; `list_notification_trackers` rewritten from `1+N+N*M` to batched loads; retention job extended to event / webhook / action-execution logs.
|
||||
- *Scheduler:* `AsyncIOScheduler` job defaults set (`coalesce`, `misfire_grace_time=300`, `max_instances=1`).
|
||||
- *Ops:* uvicorn runs with `proxy_headers` / `forwarded_allow_ips` / graceful shutdown timeout; access log suppressed outside debug; FastAPI version read from `importlib.metadata`; new `/api/ready` endpoint; docker-compose adds resource / PID limits, `read_only` + tmpfs, `cap_drop: ALL`, `no-new-privileges`, drops the `ALLOW_PRIVATE_URLS=1` default, and points healthcheck at `/api/ready`.
|
||||
- *Frontend:* `/login` redirects already-authenticated users to `/` and shows a distinct "backend unreachable" banner (en / ru) when `/auth/needs-setup` fails.
|
||||
- **Pre-migration SQLite snapshots** via `VACUUM INTO` at lifespan startup — takes a consistent, atomic copy of the DB before migrations run, so a botched upgrade is recoverable by restoring a single file. Safe under WAL; best-effort (failures log but never raise); configurable via `NOTIFY_BRIDGE_PRE_MIGRATE_SNAPSHOT_KEEP` (default 5; 0 disables). Snapshots land in `data_dir/backups/pre-migrate-<ts>.db` and the N oldest are pruned each boot ([7cbb02b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/7cbb02b))
|
||||
|
||||
## Bug Fixes
|
||||
|
||||
- Allow `unsafe-inline` scripts in CSP so SvelteKit's hydration bootstrap inline `<script>` runs in production — without it the frontend failed to hydrate under the hardened CSP introduced in this release ([8f0346e](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/8f0346e))
|
||||
|
||||
## Upgrade Notes
|
||||
|
||||
- `ALLOW_PRIVATE_URLS=1` is no longer set by default in `docker-compose.yml`. If your deployment targets private network URLs, set it explicitly.
|
||||
- Docker healthchecks now probe `/api/ready` (separate from `/api/health`); update any external monitors accordingly.
|
||||
- Config startup now rejects known dev secret keys — set real values (e.g. `JWT_SECRET`) before upgrading.
|
||||
- Log format and level can now be changed at runtime from the settings UI; the `log_format` field still requires a restart to apply (a `WARN` is logged noting this).
|
||||
|
||||
---
|
||||
|
||||
## Development / Internal
|
||||
|
||||
### Tests
|
||||
|
||||
- New `packages/server/tests/` suite with 29 passing tests: config validation; JWT round-trip and `aud` / `alg=none` rejection; SSRF scheme and private-range enforcement (sync + async); Discord bounded retry; a lifespan-level `/api/health` + `/api/ready` smoke check. `services/test_dispatch.py` renamed to `manual_dispatch.py` so pytest no longer auto-collects production code ([920920b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/920920b))
|
||||
|
||||
### CI / Build
|
||||
|
||||
- CI now runs on push / PR with frontend `svelte-check` + build, and a non-push image build. Release workflow is gated on tests and publishes an immutable `sha-<commit>` image tag ([920920b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/920920b))
|
||||
- Install editable packages inside a venv ([2bec253](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2bec253))
|
||||
- Cache pip downloads and collapse install into a single pip call ([3b683ce](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/3b683ce))
|
||||
- Drop backend pytest from Gitea CI — editable install is too slow on the hosted runner ([f904037](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/f904037))
|
||||
- Skip `build.yml` on release commits to avoid redundant runs ([bbcdf1c](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/bbcdf1c))
|
||||
- Drop Trivy scan from release (output was discarded and it never failed) ([19036a9](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/19036a9))
|
||||
|
||||
### Performance
|
||||
|
||||
- Split external Docker deps into a cacheable layer and swap pip for uv for faster image builds ([592e1b6](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/592e1b6))
|
||||
- Install uv from PyPI instead of the ghcr.io distroless image to avoid slow GHCR pulls on CI ([a6a854a](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/a6a854a))
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<summary>All Commits</summary>
|
||||
|
||||
- [69711bb](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/69711bb) — feat(commands): keep chat-action hint alive during slow command fetches *(alexei.dolgolyov)*
|
||||
- [fe38d20](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/fe38d20) — perf(immich): skip full album fetch on idle ticks; delta-fetch for active ones *(alexei.dolgolyov)*
|
||||
| Hash | Message | Author |
|
||||
| ---- | ------- | ------ |
|
||||
| [8f0346e](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/8f0346e) | fix(csp): allow unsafe-inline scripts for SvelteKit hydration bootstrap | alexei.dolgolyov |
|
||||
| [a6a854a](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/a6a854a) | perf(docker): install uv from PyPI instead of ghcr.io (avoid slow GHCR pulls) | alexei.dolgolyov |
|
||||
| [19036a9](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/19036a9) | ci: drop trivy scan from release (never failed, output discarded) | alexei.dolgolyov |
|
||||
| [592e1b6](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/592e1b6) | perf(docker): split external deps into a cacheable layer, swap pip for uv | alexei.dolgolyov |
|
||||
| [bbcdf1c](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/bbcdf1c) | ci: skip build.yml on release commits | alexei.dolgolyov |
|
||||
| [f904037](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/f904037) | ci: drop backend pytest stage (too slow on hosted runner) | alexei.dolgolyov |
|
||||
| [3b683ce](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/3b683ce) | ci: cache pip downloads and collapse install into one pip call | alexei.dolgolyov |
|
||||
| [2bec253](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/2bec253) | ci: install editable packages inside a venv | alexei.dolgolyov |
|
||||
| [7cbb02b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/7cbb02b) | feat(db): pre-migration SQLite snapshots via VACUUM INTO | alexei.dolgolyov |
|
||||
| [920920b](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/920920b) | feat: production-readiness hardening across security, async, DB, ops | alexei.dolgolyov |
|
||||
| [f50d465](https://git.dolgolyov-family.by/alexei.dolgolyov/notify-bridge/commit/f50d465) | feat(logging): production-grade logging with context vars, secret masking, and runtime level control | alexei.dolgolyov |
|
||||
|
||||
</details>
|
||||
|
||||
+27
-7
@@ -10,18 +10,38 @@ services:
|
||||
volumes:
|
||||
- notify-bridge-data:/data
|
||||
environment:
|
||||
# REQUIRED — any 32+ byte random string. `openssl rand -hex 32` is one way.
|
||||
- NOTIFY_BRIDGE_SECRET_KEY=${NOTIFY_BRIDGE_SECRET_KEY:?Set NOTIFY_BRIDGE_SECRET_KEY (min 32 chars)}
|
||||
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-*}
|
||||
# Homelab target: allow outbound requests to RFC1918 / link-local addresses.
|
||||
# The SSRF guard otherwise rejects 10.*/172.16.*/192.168.*/169.254.* hosts,
|
||||
# which breaks tracking of Immich / Gitea / etc. running on the same LAN.
|
||||
- NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
# Comma-separated list of allowed browser origins. Wildcard `*` is
|
||||
# rejected on startup because credentials are enabled.
|
||||
- NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS=${NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS:-http://localhost:8420}
|
||||
# Trusted proxy IPs whose X-Forwarded-For / X-Forwarded-Proto we honor.
|
||||
# Set this to your reverse proxy's IP (e.g. 172.17.0.1 for the default
|
||||
# docker bridge, or `*` only if the container is NOT reachable from the
|
||||
# public internet).
|
||||
- NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS=${NOTIFY_BRIDGE_FORWARDED_ALLOW_IPS:-127.0.0.1}
|
||||
# Opt-in SSRF bypass for private/loopback/link-local hosts (homelab
|
||||
# scenario — tracking an Immich/Gitea instance on the same LAN). DO NOT
|
||||
# enable on a publicly exposed instance.
|
||||
# - NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/health')"]
|
||||
# Use /api/ready (not /api/health) so the container is only reported
|
||||
# healthy after migrations and the scheduler finish booting.
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8420/api/ready', timeout=3)"]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
start_period: 30s
|
||||
read_only: true
|
||||
tmpfs:
|
||||
- /tmp
|
||||
security_opt:
|
||||
- no-new-privileges:true
|
||||
cap_drop:
|
||||
- ALL
|
||||
mem_limit: 512m
|
||||
cpus: 1.0
|
||||
pids_limit: 256
|
||||
|
||||
volumes:
|
||||
notify-bridge-data:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "notify-bridge-frontend",
|
||||
"private": true,
|
||||
"version": "0.3.0",
|
||||
"version": "0.4.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite dev",
|
||||
|
||||
@@ -55,7 +55,8 @@
|
||||
"passwordTooShort": "Password must be at least 8 characters",
|
||||
"or": "or",
|
||||
"loginFailed": "Login failed",
|
||||
"setupFailed": "Setup failed"
|
||||
"setupFailed": "Setup failed",
|
||||
"backendUnreachable": "Cannot reach the server. Check that it's running and try again."
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Dashboard",
|
||||
@@ -78,6 +79,7 @@
|
||||
"collectionRenamed": "collection renamed",
|
||||
"collectionDeleted": "collection deleted",
|
||||
"sharingChanged": "sharing changed",
|
||||
"scheduledMessage": "scheduled message",
|
||||
"actionSuccess": "action run",
|
||||
"actionPartial": "action partial",
|
||||
"actionFailed": "action failed",
|
||||
@@ -694,6 +696,13 @@
|
||||
"locales": "Template Languages",
|
||||
"supportedLocales": "Supported Locales",
|
||||
"supportedLocalesHint": "Languages available when authoring notification and command templates. Built-in defaults ship for English and Russian; other languages start empty.",
|
||||
"logging": "Logging",
|
||||
"logLevel": "Log Level",
|
||||
"logLevelHint": "Root log level for the server. Raise to DEBUG while investigating; keep at INFO in production. WARNING/ERROR hide per-command progress lines.",
|
||||
"logFormat": "Log Format",
|
||||
"logFormatHint": "Output format. 'text' is human-readable; 'json' emits one object per line for log aggregators (Loki, ELK). Changing this requires a server restart.",
|
||||
"logLevels": "Per-Module Overrides",
|
||||
"logLevelsHint": "Comma-separated 'module=LEVEL' pairs to silence noisy modules or drill into one area. Example: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||
"saved": "Settings saved"
|
||||
},
|
||||
"hints": {
|
||||
|
||||
@@ -55,7 +55,8 @@
|
||||
"passwordTooShort": "Пароль должен быть не менее 8 символов",
|
||||
"or": "или",
|
||||
"loginFailed": "Ошибка входа",
|
||||
"setupFailed": "Ошибка настройки"
|
||||
"setupFailed": "Ошибка настройки",
|
||||
"backendUnreachable": "Не удалось подключиться к серверу. Убедитесь, что он запущен, и повторите попытку."
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Главная",
|
||||
@@ -78,6 +79,7 @@
|
||||
"collectionRenamed": "альбом переименован",
|
||||
"collectionDeleted": "альбом удалён",
|
||||
"sharingChanged": "изменение доступа",
|
||||
"scheduledMessage": "запланированное сообщение",
|
||||
"actionSuccess": "действие выполнено",
|
||||
"actionPartial": "действие частично",
|
||||
"actionFailed": "действие провалено",
|
||||
@@ -694,6 +696,13 @@
|
||||
"locales": "Языки шаблонов",
|
||||
"supportedLocales": "Поддерживаемые локали",
|
||||
"supportedLocalesHint": "Языки, доступные для редактирования шаблонов уведомлений и команд. Встроенные шаблоны поставляются для английского и русского; другие языки начинают с пустых.",
|
||||
"logging": "Логирование",
|
||||
"logLevel": "Уровень логов",
|
||||
"logLevelHint": "Уровень логирования сервера. Поднимайте до DEBUG при отладке; оставляйте INFO в продакшене. WARNING/ERROR скрывают пошаговые строки по командам.",
|
||||
"logFormat": "Формат логов",
|
||||
"logFormatHint": "Формат вывода. 'text' — читаемый человеком; 'json' — по одному объекту в строке для агрегаторов (Loki, ELK). Смена требует перезапуска сервера.",
|
||||
"logLevels": "Переопределения по модулям",
|
||||
"logLevelsHint": "Пары 'модуль=УРОВЕНЬ' через запятую, чтобы приглушить шумные модули или углубиться в один. Пример: sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG",
|
||||
"saved": "Настройки сохранены"
|
||||
},
|
||||
"hints": {
|
||||
|
||||
@@ -223,6 +223,7 @@
|
||||
collection_renamed: 'dashboard.collectionRenamed',
|
||||
collection_deleted: 'dashboard.collectionDeleted',
|
||||
sharing_changed: 'dashboard.sharingChanged',
|
||||
scheduled_message: 'dashboard.scheduledMessage',
|
||||
action_success: 'dashboard.actionSuccess',
|
||||
action_partial: 'dashboard.actionPartial',
|
||||
action_failed: 'dashboard.actionFailed',
|
||||
@@ -231,11 +232,13 @@
|
||||
const eventIcons: Record<string, string> = {
|
||||
assets_added: 'mdiImagePlus', assets_removed: 'mdiImageMinus',
|
||||
collection_renamed: 'mdiRename', collection_deleted: 'mdiDeleteAlert', sharing_changed: 'mdiShareVariant',
|
||||
scheduled_message: 'mdiCalendarClock',
|
||||
action_success: 'mdiPlayCircle', action_partial: 'mdiAlertCircle', action_failed: 'mdiCloseCircle',
|
||||
};
|
||||
const eventColors: Record<string, string> = {
|
||||
assets_added: '#059669', assets_removed: '#ef4444',
|
||||
collection_renamed: '#6366f1', collection_deleted: '#dc2626', sharing_changed: '#f59e0b',
|
||||
scheduled_message: '#8b5cf6',
|
||||
action_success: '#0d9488', action_partial: '#f59e0b', action_failed: '#dc2626',
|
||||
};
|
||||
|
||||
|
||||
@@ -15,13 +15,32 @@
|
||||
let submitting = $state(false);
|
||||
let mounted = $state(false);
|
||||
|
||||
let backendDown = $state(false);
|
||||
|
||||
onMount(async () => {
|
||||
initTheme();
|
||||
mounted = true;
|
||||
// If the user is already signed in (valid access token in storage),
|
||||
// there is no reason to show them the login form. loadUser() runs in
|
||||
// the root layout; we just check the resolved state after a short tick.
|
||||
const { isAuthenticated } = await import('$lib/api');
|
||||
if (isAuthenticated()) {
|
||||
try {
|
||||
await api('/auth/me');
|
||||
goto('/');
|
||||
return;
|
||||
} catch {
|
||||
// Token was stale; fall through to the login form.
|
||||
}
|
||||
}
|
||||
try {
|
||||
const res = await api<{ needs_setup: boolean }>('/auth/needs-setup');
|
||||
if (res.needs_setup) goto('/setup');
|
||||
} catch { /* ignore */ }
|
||||
} catch {
|
||||
// The backend is unreachable — surface that distinctly so the user
|
||||
// doesn't blame the login form for a network/backend problem.
|
||||
backendDown = true;
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSubmit(e: SubmitEvent) {
|
||||
@@ -62,7 +81,12 @@
|
||||
<p class="text-sm mt-1" style="color: var(--color-muted-foreground);">{t('auth.signInTitle')}</p>
|
||||
</div>
|
||||
|
||||
{#if error}
|
||||
{#if backendDown}
|
||||
<div class="auth-error animate-fade-slide-in">
|
||||
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||
{t('auth.backendUnreachable')}
|
||||
</div>
|
||||
{:else if error}
|
||||
<div class="auth-error animate-fade-slide-in">
|
||||
<MdiIcon name="mdiAlertCircle" size={16} />
|
||||
{error}
|
||||
|
||||
@@ -37,6 +37,9 @@
|
||||
telegram_asset_cache_max_entries: '5000',
|
||||
supported_locales: 'en,ru',
|
||||
timezone: 'UTC',
|
||||
log_level: 'INFO',
|
||||
log_format: 'text',
|
||||
log_levels: '',
|
||||
});
|
||||
let cacheStats = $state<CacheStats | null>(null);
|
||||
|
||||
@@ -204,6 +207,40 @@
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<!-- Logging section -->
|
||||
<Card>
|
||||
<h3 class="text-sm font-semibold mb-4 flex items-center gap-2">
|
||||
<MdiIcon name="mdiTextBoxOutline" size={18} />
|
||||
{t('settings.logging')}
|
||||
</h3>
|
||||
<div class="grid grid-cols-1 sm:grid-cols-2 gap-4">
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevel')}<Hint text={t('settings.logLevelHint')} /></label>
|
||||
<select bind:value={settings.log_level}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="DEBUG">DEBUG</option>
|
||||
<option value="INFO">INFO</option>
|
||||
<option value="WARNING">WARNING</option>
|
||||
<option value="ERROR">ERROR</option>
|
||||
</select>
|
||||
</div>
|
||||
<div>
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logFormat')}<Hint text={t('settings.logFormatHint')} /></label>
|
||||
<select bind:value={settings.log_format}
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)]">
|
||||
<option value="text">text</option>
|
||||
<option value="json">json</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="sm:col-span-2">
|
||||
<label class="block text-xs font-medium mb-1">{t('settings.logLevels')}<Hint text={t('settings.logLevelsHint')} /></label>
|
||||
<input bind:value={settings.log_levels}
|
||||
placeholder="sqlalchemy.engine=WARNING,notify_bridge_core.notifications.telegram.client=DEBUG"
|
||||
class="w-full px-3 py-1.5 text-sm border border-[var(--color-border)] rounded-md bg-[var(--color-background)] font-mono" />
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
<Button onclick={save} disabled={saving}>
|
||||
{saving ? t('common.loading') : t('common.save')}
|
||||
</Button>
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "notify-bridge-core"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
description = "Core library for Notify Bridge — service provider abstractions, models, notifications, and templates"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Request-scoped ContextVars that propagate into log records.
|
||||
|
||||
The server sets these at entry points (Telegram webhook, scheduler dispatch,
|
||||
REST call) and they propagate through async calls automatically. A
|
||||
``LogRecordFactory`` installed by ``notify_bridge_server.logging_setup``
|
||||
reads them so every log line is tagged (``request_id``, ``command``,
|
||||
``chat_id``, ``bot_id``, ``dispatch_id``) without each call site having
|
||||
to pass the values explicitly.
|
||||
|
||||
Kept in ``notify_bridge_core`` so core modules (``TelegramClient``,
|
||||
``NotificationDispatcher``) can *set* additional context (e.g. a
|
||||
``dispatch_id``) without depending on the server package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Any, Iterator
|
||||
|
||||
request_id_var: ContextVar[str | None] = ContextVar("request_id", default=None)
|
||||
command_var: ContextVar[str | None] = ContextVar("command", default=None)
|
||||
chat_id_var: ContextVar[str | None] = ContextVar("chat_id", default=None)
|
||||
bot_id_var: ContextVar[int | None] = ContextVar("bot_id", default=None)
|
||||
dispatch_id_var: ContextVar[str | None] = ContextVar("dispatch_id", default=None)
|
||||
|
||||
_VAR_MAP: dict[str, ContextVar[Any]] = {
|
||||
"request_id": request_id_var,
|
||||
"command": command_var,
|
||||
"chat_id": chat_id_var,
|
||||
"bot_id": bot_id_var,
|
||||
"dispatch_id": dispatch_id_var,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_log_context(**kwargs: Any) -> Iterator[None]:
|
||||
"""Bind the given context fields for the duration of the ``with`` block.
|
||||
|
||||
Unknown keys are ignored so callers can pass whatever they want without
|
||||
an ``if`` ladder. Values are reset on exit even if the block raises.
|
||||
|
||||
Example:
|
||||
``with bind_log_context(request_id="abc", command="random"): ...``
|
||||
"""
|
||||
tokens: list[tuple[ContextVar[Any], Token]] = []
|
||||
try:
|
||||
for key, value in kwargs.items():
|
||||
var = _VAR_MAP.get(key)
|
||||
if var is None:
|
||||
continue
|
||||
tokens.append((var, var.set(value)))
|
||||
yield
|
||||
finally:
|
||||
for var, tok in tokens:
|
||||
var.reset(tok)
|
||||
|
||||
|
||||
def current_log_context() -> dict[str, Any]:
|
||||
"""Return a snapshot of the currently-bound context values (non-None)."""
|
||||
snap: dict[str, Any] = {}
|
||||
for key, var in _VAR_MAP.items():
|
||||
val = var.get()
|
||||
if val is not None:
|
||||
snap[key] = val
|
||||
return snap
|
||||
@@ -52,22 +52,46 @@ class DiscordClient:
|
||||
|
||||
return {"success": True}
|
||||
|
||||
_MAX_RETRIES = 3
|
||||
_MAX_RETRY_AFTER = 60.0
|
||||
|
||||
async def _post(self, url: str, payload: dict) -> dict[str, Any]:
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, json=payload, headers={"Content-Type": "application/json"}
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
retry_after = float(resp.headers.get("Retry-After", "2"))
|
||||
_LOGGER.warning("Discord rate limited, retrying after %.1fs", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
return await self._post(url, payload)
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"HTTP {resp.status}: {body[:200]}"}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
"""POST with bounded 429 retry.
|
||||
|
||||
We cap retries at _MAX_RETRIES and the ``Retry-After`` header at
|
||||
_MAX_RETRY_AFTER seconds so a hostile or misbehaving upstream cannot
|
||||
pin the dispatch task indefinitely.
|
||||
"""
|
||||
for attempt in range(self._MAX_RETRIES + 1):
|
||||
try:
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
if resp.status == 429 and attempt < self._MAX_RETRIES:
|
||||
try:
|
||||
retry_after = float(resp.headers.get("Retry-After", "2"))
|
||||
except (TypeError, ValueError):
|
||||
retry_after = 2.0
|
||||
retry_after = max(0.0, min(retry_after, self._MAX_RETRY_AFTER))
|
||||
_LOGGER.warning(
|
||||
"Discord rate limited, retrying after %.1fs (attempt %d/%d)",
|
||||
retry_after, attempt + 1, self._MAX_RETRIES,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"HTTP {resp.status}: {body[:200]}",
|
||||
}
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
return {"success": False, "error": "Rate limited (retries exhausted)"}
|
||||
|
||||
|
||||
def _split_message(text: str, limit: int) -> list[str]:
|
||||
|
||||
@@ -3,16 +3,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context, dispatch_id_var
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.templates.context import build_template_context
|
||||
from notify_bridge_core.templates.renderer import render_template
|
||||
from .ssrf import UnsafeURLError, validate_outbound_url
|
||||
from .ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
@@ -82,9 +85,28 @@ class NotificationDispatcher:
|
||||
*,
|
||||
url_cache: TelegramFileCache | None = None,
|
||||
asset_cache: TelegramFileCache | None = None,
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
self._url_cache = url_cache
|
||||
self._asset_cache = asset_cache
|
||||
# Optional shared session owned by the caller; when supplied we reuse
|
||||
# its connection pool instead of opening a fresh per-dispatch session
|
||||
# (saves a TLS handshake per outbound call).
|
||||
self._shared_session = session
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
|
||||
"""Yield an aiohttp session, reusing the shared one if provided.
|
||||
|
||||
When a shared session was passed in ``__init__`` we yield it without
|
||||
closing (the caller owns its lifetime). Otherwise we open a
|
||||
short-lived session with our default timeout and close it on exit.
|
||||
"""
|
||||
if self._shared_session is not None and not self._shared_session.closed:
|
||||
yield self._shared_session
|
||||
return
|
||||
async with self._session_ctx() as session:
|
||||
yield session
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
@@ -95,18 +117,40 @@ class NotificationDispatcher:
|
||||
|
||||
Returns list of results (one per target).
|
||||
"""
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
for raw in raw_results:
|
||||
if isinstance(raw, Exception):
|
||||
_LOGGER.error("Failed to dispatch to target: %s", raw)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
else:
|
||||
results.append(raw)
|
||||
return results
|
||||
# Bind a dispatch_id so every log line emitted by the target sends
|
||||
# (including deep in TelegramClient) can be correlated to the same
|
||||
# upstream event.
|
||||
new_id = dispatch_id_var.get() or f"disp:{uuid.uuid4().hex[:12]}"
|
||||
|
||||
with bind_log_context(dispatch_id=new_id):
|
||||
_LOGGER.info(
|
||||
"Dispatching event %s (collection=%r) to %d target(s)",
|
||||
event.event_type.value if hasattr(event.event_type, "value") else event.event_type,
|
||||
getattr(event, "collection_name", None), len(targets),
|
||||
)
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
failures = 0
|
||||
for target, raw in zip(targets, raw_results):
|
||||
if isinstance(raw, Exception):
|
||||
failures += 1
|
||||
_LOGGER.error(
|
||||
"Dispatch to target type=%s failed: %s",
|
||||
target.type, raw, exc_info=raw,
|
||||
)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
else:
|
||||
if isinstance(raw, dict) and not raw.get("success"):
|
||||
failures += 1
|
||||
results.append(raw)
|
||||
_LOGGER.info(
|
||||
"Dispatch finished: %d target(s), %d failure(s)",
|
||||
len(targets), failures,
|
||||
)
|
||||
return results
|
||||
|
||||
def _resolve_template(
|
||||
self, event: ServiceEvent, target: TargetConfig, locale: str,
|
||||
@@ -284,7 +328,7 @@ class NotificationDispatcher:
|
||||
media_assets.append(asset)
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
# Preload all asset bytes once so (a) TelegramClient can skip its
|
||||
# own download and (b) we know exact upload sizes in time for the
|
||||
# oversize warning in the rendered text.
|
||||
@@ -354,13 +398,13 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
|
||||
results.append({"success": False, "error": "Invalid webhook receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.url)
|
||||
await avalidate_outbound_url(receiver.url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -428,14 +472,14 @@ class NotificationDispatcher:
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = DiscordClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid discord receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.webhook_url)
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -454,14 +498,14 @@ class NotificationDispatcher:
|
||||
username = target.config.get("username")
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = SlackClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
|
||||
results.append({"success": False, "error": "Invalid slack receiver"})
|
||||
continue
|
||||
try:
|
||||
validate_outbound_url(receiver.webhook_url)
|
||||
await avalidate_outbound_url(receiver.webhook_url)
|
||||
except UnsafeURLError as err:
|
||||
results.append({"success": False, "error": f"Unsafe URL: {err}"})
|
||||
continue
|
||||
@@ -480,14 +524,14 @@ class NotificationDispatcher:
|
||||
if not target.receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
try:
|
||||
validate_outbound_url(server_url)
|
||||
await avalidate_outbound_url(server_url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe ntfy server_url: {err}"}
|
||||
|
||||
title = f"{event.event_type.value}: {event.collection_name}"
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = NtfyClient(session)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
|
||||
@@ -511,7 +555,7 @@ class NotificationDispatcher:
|
||||
if not homeserver or not access_token:
|
||||
return {"success": False, "error": "Missing Matrix homeserver_url or access_token"}
|
||||
try:
|
||||
validate_outbound_url(homeserver)
|
||||
await avalidate_outbound_url(homeserver)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe matrix homeserver_url: {err}"}
|
||||
|
||||
@@ -519,7 +563,7 @@ class NotificationDispatcher:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
async with _new_session() as session:
|
||||
async with self._session_ctx() as session:
|
||||
client = MatrixClient(session, homeserver, access_token)
|
||||
for receiver in target.receivers:
|
||||
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
|
||||
|
||||
@@ -68,7 +68,9 @@ class MatrixClient:
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.put(url, json=body, headers=headers) as resp:
|
||||
async with self._session.put(
|
||||
url, json=body, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
resp_body = await resp.text()
|
||||
|
||||
@@ -51,7 +51,9 @@ class NtfyClient:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
|
||||
try:
|
||||
async with self._session.post(url, json=payload, headers=headers) as resp:
|
||||
async with self._session.post(
|
||||
url, json=payload, headers=headers, allow_redirects=False,
|
||||
) as resp:
|
||||
if 200 <= resp.status < 300:
|
||||
return {"success": True}
|
||||
body = await resp.text()
|
||||
|
||||
@@ -38,6 +38,7 @@ class SlackClient:
|
||||
webhook_url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
allow_redirects=False,
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
_LOGGER.warning("Slack rate limited")
|
||||
|
||||
@@ -12,14 +12,25 @@ development against localhost services.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_ALLOW_PRIVATE = os.environ.get("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS") == "1"
|
||||
_ALLOWED_SCHEMES = {"http", "https"}
|
||||
|
||||
if _ALLOW_PRIVATE: # pragma: no cover — operator-visible banner
|
||||
_LOGGER.warning(
|
||||
"SSRF guard: private-URL bypass ENABLED "
|
||||
"(NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1). Requests to RFC1918 / "
|
||||
"loopback / link-local hosts will be permitted."
|
||||
)
|
||||
|
||||
|
||||
class UnsafeURLError(ValueError):
|
||||
"""Raised when a URL targets a disallowed network destination."""
|
||||
@@ -36,13 +47,7 @@ def _is_blocked_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def validate_outbound_url(url: str) -> str:
|
||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||
|
||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||
private addresses are permitted but the scheme check still applies.
|
||||
"""
|
||||
def _check_scheme_host(url: str) -> tuple[str, str]:
|
||||
if not isinstance(url, str) or not url:
|
||||
raise UnsafeURLError("URL is empty")
|
||||
parsed = urlparse(url)
|
||||
@@ -51,6 +56,31 @@ def validate_outbound_url(url: str) -> str:
|
||||
host = parsed.hostname
|
||||
if not host:
|
||||
raise UnsafeURLError("URL has no host")
|
||||
return parsed.scheme, host
|
||||
|
||||
|
||||
def _check_resolved_addresses(host: str, infos: list[tuple]) -> None:
|
||||
for info in infos:
|
||||
sockaddr = info[4]
|
||||
try:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
except ValueError:
|
||||
continue
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
||||
|
||||
|
||||
def validate_outbound_url(url: str) -> str:
|
||||
"""Validate ``url`` is safe to fetch; returns the URL on success.
|
||||
|
||||
Raises :class:`UnsafeURLError` when the scheme, host, or resolved IP
|
||||
is not allowed. In development (``NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1``)
|
||||
private addresses are permitted but the scheme check still applies.
|
||||
|
||||
Synchronous; uses blocking ``socket.getaddrinfo``. Prefer
|
||||
:func:`avalidate_outbound_url` from async code paths.
|
||||
"""
|
||||
_, host = _check_scheme_host(url)
|
||||
|
||||
if _ALLOW_PRIVATE:
|
||||
return url
|
||||
@@ -64,17 +94,37 @@ def validate_outbound_url(url: str) -> str:
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Hostname — resolve and reject if any resolution is in a blocked range.
|
||||
try:
|
||||
infos = socket.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||
for info in infos:
|
||||
sockaddr = info[4]
|
||||
try:
|
||||
ip = ipaddress.ip_address(sockaddr[0])
|
||||
except ValueError:
|
||||
continue
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} resolves to blocked address {ip}")
|
||||
_check_resolved_addresses(host, infos)
|
||||
return url
|
||||
|
||||
|
||||
async def avalidate_outbound_url(url: str) -> str:
|
||||
"""Async variant that resolves DNS via the running loop's resolver.
|
||||
|
||||
Use this from ``async def`` code paths to avoid blocking the event
|
||||
loop on DNS lookups.
|
||||
"""
|
||||
_, host = _check_scheme_host(url)
|
||||
|
||||
if _ALLOW_PRIVATE:
|
||||
return url
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(host)
|
||||
if _is_blocked_ip(ip):
|
||||
raise UnsafeURLError(f"Host {host} is in a blocked range")
|
||||
return url
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
infos = await loop.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
raise UnsafeURLError(f"DNS resolution failed for {host}") from exc
|
||||
_check_resolved_addresses(host, infos)
|
||||
return url
|
||||
|
||||
@@ -162,8 +162,20 @@ class TelegramClient:
|
||||
"message_id": result.get("result", {}).get("message_id"),
|
||||
"cached": True,
|
||||
}
|
||||
except aiohttp.ClientError:
|
||||
pass
|
||||
# Non-ok from a cached send — file_id stale or file deleted on
|
||||
# Telegram's side. Log at DEBUG so operators who are hunting
|
||||
# "why didn't the cached send work?" can see it, but the
|
||||
# caller will fall through to a fresh upload.
|
||||
_LOGGER.debug(
|
||||
"Telegram %s (cached) returned non-ok: status=%s code=%s desc=%r — falling back to fresh upload",
|
||||
kind.api_method, response.status, result.get("error_code"),
|
||||
result.get("description"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.debug(
|
||||
"Telegram %s (cached) transport error — falling back to fresh upload: %s",
|
||||
kind.api_method, err,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _upload_media(
|
||||
@@ -203,8 +215,17 @@ class TelegramClient:
|
||||
thumbhash=thumbhash, size=len(data),
|
||||
)
|
||||
return {"success": True, "message_id": res.get("message_id")}
|
||||
_LOGGER.error(
|
||||
"Telegram %s failed: status=%s code=%s desc=%r bytes=%d",
|
||||
kind.api_method, response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"), len(data),
|
||||
)
|
||||
return {"success": False, "error": result.get("description", "Unknown Telegram error")}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error(
|
||||
"Telegram %s transport error (bytes=%d): %s",
|
||||
kind.api_method, len(data), err, exc_info=True,
|
||||
)
|
||||
return {"success": False, "error": str(err)}
|
||||
|
||||
async def send_notification(
|
||||
@@ -327,8 +348,14 @@ class TelegramClient:
|
||||
retry_result = await retry_resp.json()
|
||||
if retry_resp.status == 200 and retry_result.get("ok"):
|
||||
return {"success": True, "message_id": retry_result.get("result", {}).get("message_id")}
|
||||
_LOGGER.error(
|
||||
"Telegram sendMessage failed: status=%s code=%s desc=%r",
|
||||
response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"),
|
||||
)
|
||||
return {"success": False, "error": result.get("description", "Unknown Telegram error"), "error_code": result.get("error_code")}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Telegram sendMessage transport error: %s", err, exc_info=True)
|
||||
return {"success": False, "error": str(err)}
|
||||
|
||||
async def send_chat_action(self, chat_id: str, action: str = "typing") -> bool:
|
||||
@@ -513,11 +540,14 @@ class TelegramClient:
|
||||
# Tuple is (cache_key, media_type, thumbhash, uploaded_size).
|
||||
media_cache_info: list[tuple[str, str, str | None, int] | None] = []
|
||||
|
||||
# Resolve cache hits and collect download tasks in parallel
|
||||
# Resolve cache hits and collect download tasks in parallel.
|
||||
# Each drop site logs the reason — otherwise a filtered asset
|
||||
# disappears silently and the media group silently shrinks.
|
||||
async def _fetch_asset(idx: int, item: dict) -> tuple[int, dict | None, bytes | None]:
|
||||
"""Return (index, cache_entry_or_None, downloaded_bytes_or_None)."""
|
||||
url = item.get("url")
|
||||
if not url:
|
||||
_LOGGER.warning("Media skipped: missing url (idx=%d type=%s)", idx, item.get("type"))
|
||||
return idx, None, None
|
||||
media_type = item.get("type", "photo")
|
||||
custom_cache_key = item.get("cache_key")
|
||||
@@ -537,12 +567,24 @@ class TelegramClient:
|
||||
if preloaded is not None:
|
||||
data = preloaded
|
||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||
len(data), max_asset_data_size, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded video %d bytes exceeds Telegram limit %d (idx=%d url=%s)",
|
||||
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "photo":
|
||||
exceeds, _, _, _ = check_photo_limits(data)
|
||||
exceeds, reason, _, _ = check_photo_limits(data)
|
||||
if exceeds:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: preloaded photo %s (idx=%d url=%s)",
|
||||
reason, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
return idx, None, data
|
||||
|
||||
@@ -551,18 +593,38 @@ class TelegramClient:
|
||||
dl_headers = item.get("headers") or {}
|
||||
async with self._session.get(download_url, headers=dl_headers) as resp:
|
||||
if resp.status != 200:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: download HTTP %d (idx=%d type=%s url=%s)",
|
||||
resp.status, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
data = await resp.read()
|
||||
if max_asset_data_size and len(data) > max_asset_data_size:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: downloaded size %d exceeds max_asset_data_size %d (idx=%d type=%s url=%s)",
|
||||
len(data), max_asset_data_size, idx, media_type, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "video" and len(data) > TELEGRAM_MAX_VIDEO_SIZE:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: video %d bytes exceeds Telegram %d-byte limit (idx=%d url=%s)",
|
||||
len(data), TELEGRAM_MAX_VIDEO_SIZE, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
if media_type == "photo":
|
||||
exceeds, _, _, _ = check_photo_limits(data)
|
||||
exceeds, reason, _, _ = check_photo_limits(data)
|
||||
if exceeds:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: photo %s (idx=%d url=%s)",
|
||||
reason, idx, url,
|
||||
)
|
||||
return idx, None, None
|
||||
return idx, None, data
|
||||
except aiohttp.ClientError:
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning(
|
||||
"Media skipped: download failed (idx=%d type=%s url=%s): %s",
|
||||
idx, media_type, url, err,
|
||||
)
|
||||
return idx, None, None
|
||||
|
||||
results = await asyncio.gather(
|
||||
@@ -602,6 +664,14 @@ class TelegramClient:
|
||||
media_json.append(mij)
|
||||
|
||||
if not media_json:
|
||||
# Every asset in this chunk was filtered out (size, download
|
||||
# failure, etc.). Without this log, sendMediaGroup returns
|
||||
# success=True with zero message_ids and nobody knows why
|
||||
# the user sees only the text reply and no media.
|
||||
_LOGGER.warning(
|
||||
"sendMediaGroup skipped — chunk %d/%d had %d input items but 0 usable (all filtered/failed)",
|
||||
chunk_idx + 1, len(chunks), len(chunk),
|
||||
)
|
||||
continue
|
||||
|
||||
form.add_field("media", json.dumps(media_json))
|
||||
@@ -638,10 +708,35 @@ class TelegramClient:
|
||||
if eff_cache:
|
||||
await eff_cache.async_set_many(cache_entries)
|
||||
else:
|
||||
return {"success": False, "error": result.get("description", "Unknown"), "failed_at_chunk": chunk_idx + 1}
|
||||
_LOGGER.error(
|
||||
"Telegram sendMediaGroup failed: status=%s code=%s desc=%r chunk=%d/%d items=%d",
|
||||
response.status, result.get("error_code"),
|
||||
result.get("description", "Unknown"),
|
||||
chunk_idx + 1, len(chunks), len(media_json),
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.get("description", "Unknown"),
|
||||
"error_code": result.get("error_code"),
|
||||
"failed_at_chunk": chunk_idx + 1,
|
||||
}
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error(
|
||||
"Telegram sendMediaGroup transport error on chunk %d/%d (%d items): %s",
|
||||
chunk_idx + 1, len(chunks), len(media_json), err,
|
||||
exc_info=True,
|
||||
)
|
||||
return {"success": False, "error": str(err), "failed_at_chunk": chunk_idx + 1}
|
||||
|
||||
# Distinguish "posted something" from "posted nothing" so the caller
|
||||
# can surface an ERROR when a command produced a caption reply but no
|
||||
# media ever reached Telegram.
|
||||
if not all_message_ids:
|
||||
_LOGGER.warning(
|
||||
"sendMediaGroup completed with 0 message_ids across %d chunk(s) — nothing was delivered",
|
||||
len(chunks),
|
||||
)
|
||||
return {"success": False, "error": "no_items_delivered", "chunks_sent": len(chunks)}
|
||||
return {"success": True, "message_ids": all_message_ids, "chunks_sent": len(chunks)}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..ssrf import UnsafeURLError, validate_outbound_url
|
||||
from ..ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +24,7 @@ class WebhookClient:
|
||||
|
||||
async def send(self, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
try:
|
||||
validate_outbound_url(self._url)
|
||||
await avalidate_outbound_url(self._url)
|
||||
except UnsafeURLError as err:
|
||||
return {"success": False, "error": f"Unsafe URL: {err}"}
|
||||
try:
|
||||
@@ -33,6 +33,7 @@ class WebhookClient:
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json", **self._headers},
|
||||
timeout=_DEFAULT_TIMEOUT,
|
||||
allow_redirects=False,
|
||||
) as response:
|
||||
if 200 <= response.status < 300:
|
||||
return {"success": True, "status_code": response.status}
|
||||
|
||||
@@ -177,7 +177,9 @@ class ImmichActionExecutor(ActionExecutor):
|
||||
needs_thumbnail = album_id in album_created_now
|
||||
|
||||
if album_id and album_id != "__dry_run_new__":
|
||||
album = await self._client.get_album(album_id)
|
||||
# Actions diff the current album state to decide what to
|
||||
# add — must observe fresh data, not a cached view.
|
||||
album = await self._client.get_album(album_id, use_cache=False)
|
||||
if album is None and create_if_missing and create_album_name:
|
||||
if not dry_run:
|
||||
created = await self._client.create_album(create_album_name)
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -18,6 +21,51 @@ _LOGGER = logging.getLogger(__name__)
|
||||
MAX_SEARCH_QUERY_LEN = 256
|
||||
MAX_SEARCH_PERSON_IDS = 50
|
||||
|
||||
# Module-level TTL caches for album bodies and shared-link listings. The
|
||||
# Immich ``GET /api/albums/{id}`` response can be tens or hundreds of MB on a
|
||||
# large album, and bot commands like /random, /latest, /memory all refetch
|
||||
# the same album in quick succession. A short TTL makes repeat runs nearly
|
||||
# instant and deduplicates concurrent fetches so a burst of commands issues
|
||||
# one HTTP call instead of N.
|
||||
#
|
||||
# Caches are module-scoped (not instance-scoped) because ``ImmichClient`` is
|
||||
# constructed fresh per request in several places (api/providers.py,
|
||||
# services/action_runner.py, command handlers), so an instance cache would
|
||||
# never survive to serve a second caller. This mirrors ``_users_cache`` in
|
||||
# ``provider.py``.
|
||||
_ALBUM_CACHE_TTL_SECONDS = 60
|
||||
_SHARED_LINKS_CACHE_TTL_SECONDS = 60
|
||||
# Guard rail against runaway memory — a 200k-asset album response can be
|
||||
# ~150 MB, so even modest caps bound the worst case.
|
||||
_ALBUM_CACHE_MAX_ENTRIES = 32
|
||||
_album_cache_lock = asyncio.Lock()
|
||||
# key = (server_digest, album_id); value = (monotonic_ts, raw_api_dict)
|
||||
# Store the raw dict rather than the parsed ``ImmichAlbumData`` so callers
|
||||
# that pass a ``users_cache`` still get owner-name enrichment on cache hits.
|
||||
_album_cache: dict[tuple[str, str], tuple[float, dict[str, Any]]] = {}
|
||||
_shared_links_cache_lock = asyncio.Lock()
|
||||
# key = server_digest; value = (monotonic_ts, {album_id: [SharedLinkInfo, ...]})
|
||||
# The underlying ``/api/shared-links`` endpoint has no per-album filter, so
|
||||
# every call was already paying for the full server-wide list. Caching the
|
||||
# bucketed result once per server turns N per-album calls into one fetch.
|
||||
_shared_links_cache: dict[str, tuple[float, dict[str, list[SharedLinkInfo]]]] = {}
|
||||
|
||||
|
||||
def _server_digest(url: str, api_key: str) -> str:
|
||||
"""Hashed key that avoids putting raw api_key into cache dict keys."""
|
||||
return hashlib.sha256(f"{url}|{api_key}".encode("utf-8")).hexdigest()[:32]
|
||||
|
||||
|
||||
def invalidate_album_cache() -> None:
|
||||
"""Drop every cached album body. Call after mutations that invalidate
|
||||
the cached view (e.g. integration tests, manual /refresh commands)."""
|
||||
_album_cache.clear()
|
||||
|
||||
|
||||
def invalidate_shared_links_cache() -> None:
|
||||
"""Drop every cached shared-link listing."""
|
||||
_shared_links_cache.clear()
|
||||
|
||||
# User-facing error bodies — Immich responses may leak internal paths,
|
||||
# hostnames, or headers injected by intermediary proxies. These helpers keep
|
||||
# only a short, scrubbed summary; full bodies are logged server-side only.
|
||||
@@ -184,22 +232,30 @@ class ImmichClient:
|
||||
return {}
|
||||
|
||||
async def get_shared_links(self, album_id: str) -> list[SharedLinkInfo]:
|
||||
links: list[SharedLinkInfo] = []
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/shared-links",
|
||||
headers=self._headers,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
for link in data:
|
||||
album = link.get("album")
|
||||
key = link.get("key")
|
||||
if album and key and album.get("id") == album_id:
|
||||
links.append(SharedLinkInfo.from_api_response(link))
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch shared links: %s", err)
|
||||
return links
|
||||
bucketed = await self._get_shared_links_bucketed()
|
||||
return list(bucketed.get(album_id, []))
|
||||
|
||||
async def _get_shared_links_bucketed(self) -> dict[str, list[SharedLinkInfo]]:
|
||||
"""Return ``{album_id: [SharedLinkInfo, ...]}`` for the server, hitting
|
||||
the module-level TTL cache first. Underlying Immich endpoint has no
|
||||
per-album filter, so one server-wide fetch serves every caller until
|
||||
the TTL elapses.
|
||||
"""
|
||||
digest = _server_digest(self._url, self._api_key)
|
||||
now = time.monotonic()
|
||||
entry = _shared_links_cache.get(digest)
|
||||
if entry is not None and (now - entry[0]) < _SHARED_LINKS_CACHE_TTL_SECONDS:
|
||||
return entry[1]
|
||||
|
||||
async with _shared_links_cache_lock:
|
||||
# Re-check under the lock — another coroutine may have refreshed
|
||||
# while we waited.
|
||||
entry = _shared_links_cache.get(digest)
|
||||
if entry is not None and (time.monotonic() - entry[0]) < _SHARED_LINKS_CACHE_TTL_SECONDS:
|
||||
return entry[1]
|
||||
fresh = await self.get_all_shared_links_by_album()
|
||||
_shared_links_cache[digest] = (time.monotonic(), fresh)
|
||||
return fresh
|
||||
|
||||
async def get_all_shared_links_by_album(self) -> dict[str, list[SharedLinkInfo]]:
|
||||
"""Fetch every shared link on the server, bucketed by album id.
|
||||
@@ -247,7 +303,29 @@ class ImmichClient:
|
||||
self,
|
||||
album_id: str,
|
||||
users_cache: dict[str, str] | None = None,
|
||||
*,
|
||||
use_cache: bool = True,
|
||||
) -> ImmichAlbumData | None:
|
||||
"""Fetch an album by id, optionally serving from the module-level
|
||||
TTL cache. Pass ``use_cache=False`` from paths that must observe the
|
||||
current server state (e.g. the notification poll loop's full-fetch
|
||||
path, where a stale cached entry would delay asset-removal events).
|
||||
Non-cached fetches still populate the cache for subsequent readers.
|
||||
"""
|
||||
cache_key = (_server_digest(self._url, self._api_key), album_id)
|
||||
if use_cache:
|
||||
entry = _album_cache.get(cache_key)
|
||||
if entry is not None and (time.monotonic() - entry[0]) < _ALBUM_CACHE_TTL_SECONDS:
|
||||
# Rehydrate per-call so ``users_cache`` enrichment is applied
|
||||
# with the caller's dict, not whichever one was live when the
|
||||
# cache was populated.
|
||||
return ImmichAlbumData.from_api_response(entry[1], users_cache)
|
||||
|
||||
# Deliberately fetch without holding a lock so concurrent calls for
|
||||
# *different* album_ids (the common case from asyncio.gather in
|
||||
# fetch_albums_with_links) stay parallel. The worst case is a small
|
||||
# duplicate-fetch stampede when two requests miss the same album at
|
||||
# the same instant — acceptable for our scale.
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/albums/{album_id}",
|
||||
@@ -260,10 +338,18 @@ class ImmichClient:
|
||||
f"Error fetching album {album_id}: HTTP {response.status}"
|
||||
)
|
||||
data = await response.json()
|
||||
return ImmichAlbumData.from_api_response(data, users_cache)
|
||||
except aiohttp.ClientError as err:
|
||||
raise ImmichApiError(f"Error communicating with Immich: {err}") from err
|
||||
|
||||
async with _album_cache_lock:
|
||||
# Evict the oldest entry if we're at the cap — simple FIFO is fine
|
||||
# for our access pattern (commands touch a small working set).
|
||||
if len(_album_cache) >= _ALBUM_CACHE_MAX_ENTRIES and cache_key not in _album_cache:
|
||||
oldest = min(_album_cache.items(), key=lambda kv: kv[1][0])[0]
|
||||
_album_cache.pop(oldest, None)
|
||||
_album_cache[cache_key] = (time.monotonic(), data)
|
||||
return ImmichAlbumData.from_api_response(data, users_cache)
|
||||
|
||||
async def get_album_meta(self, album_id: str) -> ImmichAlbumMeta | None:
|
||||
"""Fetch album metadata without the assets array.
|
||||
|
||||
|
||||
@@ -292,7 +292,13 @@ class ImmichServiceProvider(ServiceProvider):
|
||||
# the full-fetch path so removals get detected.
|
||||
|
||||
# Full fetch: first tick, or count-decreased, or delta-unsafe.
|
||||
album = await self._client.get_album(album_id, self._users_cache)
|
||||
# Bypass the module-level album cache — this path runs when we
|
||||
# specifically need the current server state (e.g. to detect
|
||||
# asset removals), so a stale cached entry would silently delay
|
||||
# the event.
|
||||
album = await self._client.get_album(
|
||||
album_id, self._users_cache, use_cache=False,
|
||||
)
|
||||
if album is None:
|
||||
# Album was deleted between meta probe and full fetch — handle
|
||||
# the deletion the same way as above.
|
||||
|
||||
@@ -12,6 +12,7 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_PORT = 3493
|
||||
_READ_TIMEOUT = 10.0
|
||||
_WRITE_TIMEOUT = 10.0
|
||||
_CONNECT_TIMEOUT = 5.0
|
||||
|
||||
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
|
||||
@@ -84,14 +85,26 @@ class NutClient:
|
||||
await self._command(f"PASSWORD {self._password}")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Send LOGOUT and close the TCP connection."""
|
||||
"""Send LOGOUT and close the TCP connection.
|
||||
|
||||
``drain`` is bounded by ``_WRITE_TIMEOUT`` so a half-closed peer
|
||||
cannot hold the disconnect indefinitely — a tracker tick would
|
||||
otherwise be pinned by a stuck NUT server and block the scheduler
|
||||
slot (``max_instances=1``).
|
||||
"""
|
||||
if self._writer is not None:
|
||||
try:
|
||||
self._writer.write(b"LOGOUT\n")
|
||||
await self._writer.drain()
|
||||
except OSError:
|
||||
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
pass
|
||||
self._writer.close()
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._writer.wait_closed(), timeout=_WRITE_TIMEOUT,
|
||||
)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
pass
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
|
||||
@@ -135,7 +148,10 @@ class NutClient:
|
||||
if self._writer is None:
|
||||
raise NutClientError("Not connected")
|
||||
self._writer.write(f"{cmd}\n".encode())
|
||||
await self._writer.drain()
|
||||
try:
|
||||
await asyncio.wait_for(self._writer.drain(), timeout=_WRITE_TIMEOUT)
|
||||
except asyncio.TimeoutError as exc:
|
||||
raise NutClientError("Write timeout") from exc
|
||||
|
||||
async def _readline(self) -> str:
|
||||
"""Read one line from upsd, stripping trailing newline."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||
from notify_bridge_core.providers.base import ServiceProvider, ServiceProviderType
|
||||
@@ -57,6 +58,13 @@ SCHEDULER_VARIABLES: list[TemplateVariableDefinition] = [
|
||||
example="Monday",
|
||||
provider_type=ServiceProviderType.SCHEDULER,
|
||||
),
|
||||
TemplateVariableDefinition(
|
||||
name="timezone",
|
||||
type="string",
|
||||
description="IANA timezone used to compute current_date/time",
|
||||
example="Europe/Warsaw",
|
||||
provider_type=ServiceProviderType.SCHEDULER,
|
||||
),
|
||||
TemplateVariableDefinition(
|
||||
name="custom_vars",
|
||||
type="dict",
|
||||
@@ -83,7 +91,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
custom_variables: dict[str, str] | None = None,
|
||||
date_format: str = "%d.%m.%Y",
|
||||
time_format: str = "%H:%M",
|
||||
datetime_format: str = "%d.%m.%Y, %H:%M UTC",
|
||||
datetime_format: str = "%d.%m.%Y, %H:%M %Z",
|
||||
timezone_name: str | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._tracker_name = tracker_name
|
||||
@@ -91,6 +100,18 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
self._date_format = date_format
|
||||
self._time_format = time_format
|
||||
self._datetime_format = datetime_format
|
||||
# Resolve a timezone for date/time rendering. Falls back to UTC on
|
||||
# invalid IANA names so a typo in app settings doesn't break polls.
|
||||
tz: ZoneInfo
|
||||
if timezone_name:
|
||||
try:
|
||||
tz = ZoneInfo(timezone_name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
_LOGGER.warning("Unknown timezone %r; falling back to UTC", timezone_name)
|
||||
tz = ZoneInfo("UTC")
|
||||
else:
|
||||
tz = ZoneInfo("UTC")
|
||||
self._tz = tz
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True # virtual provider — always connected
|
||||
@@ -103,7 +124,8 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
collection_ids: list[str],
|
||||
tracker_state: dict[str, Any],
|
||||
) -> tuple[list[ServiceEvent], dict[str, Any]]:
|
||||
now = datetime.now(timezone.utc)
|
||||
now_utc = datetime.now(timezone.utc)
|
||||
now = now_utc.astimezone(self._tz)
|
||||
# State uses {collection_id: {dict}} convention like other providers
|
||||
sched_state = tracker_state.get("scheduler", {})
|
||||
fire_count = sched_state.get("fire_count", 0) + 1
|
||||
@@ -115,6 +137,7 @@ class SchedulerServiceProvider(ServiceProvider):
|
||||
"current_time": now.strftime(self._time_format),
|
||||
"current_datetime": now.strftime(self._datetime_format),
|
||||
"weekday": _WEEKDAYS[now.weekday()],
|
||||
"timezone": self._tz.key,
|
||||
"custom_vars": dict(self._custom_variables),
|
||||
}
|
||||
# Flatten custom variables at top level for easy template access
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
@@ -19,34 +21,58 @@ class StorageBackend(Protocol):
|
||||
async def remove(self) -> None: ...
|
||||
|
||||
|
||||
def _read_file(path: Path) -> str | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def _atomic_write(path: Path, payload: str) -> None:
|
||||
"""Write atomically: tmp file + rename. Prevents half-written files on crash."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
tmp.write_text(payload, encoding="utf-8")
|
||||
os.replace(tmp, path)
|
||||
|
||||
|
||||
def _remove_file(path: Path) -> None:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
class JsonFileBackend:
|
||||
"""Simple JSON file storage backend."""
|
||||
"""Simple JSON file storage backend.
|
||||
|
||||
All blocking I/O is wrapped in ``asyncio.to_thread`` so callers can
|
||||
``await load() / save() / remove()`` without stalling the event loop.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path) -> None:
|
||||
self._path = path
|
||||
|
||||
async def load(self) -> dict[str, Any] | None:
|
||||
if not self._path.exists():
|
||||
try:
|
||||
text = await asyncio.to_thread(_read_file, self._path)
|
||||
except OSError as err:
|
||||
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
||||
return None
|
||||
if text is None:
|
||||
return None
|
||||
try:
|
||||
text = self._path.read_text(encoding="utf-8")
|
||||
return json.loads(text)
|
||||
except (json.JSONDecodeError, OSError) as err:
|
||||
_LOGGER.warning("Failed to load %s: %s", self._path, err)
|
||||
except json.JSONDecodeError as err:
|
||||
_LOGGER.warning("Failed to parse %s: %s", self._path, err)
|
||||
return None
|
||||
|
||||
async def save(self, data: dict[str, Any]) -> None:
|
||||
payload = json.dumps(data, default=str)
|
||||
try:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.write_text(
|
||||
json.dumps(data, default=str), encoding="utf-8"
|
||||
)
|
||||
await asyncio.to_thread(_atomic_write, self._path, payload)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to save %s: %s", self._path, err)
|
||||
|
||||
async def remove(self) -> None:
|
||||
try:
|
||||
if self._path.exists():
|
||||
self._path.unlink()
|
||||
await asyncio.to_thread(_remove_file, self._path)
|
||||
except OSError as err:
|
||||
_LOGGER.error("Failed to remove %s: %s", self._path, err)
|
||||
|
||||
@@ -224,6 +224,7 @@ def build_template_context(
|
||||
ctx.setdefault("current_time", event.extra.get("current_time", ""))
|
||||
ctx.setdefault("current_datetime", event.extra.get("current_datetime", ""))
|
||||
ctx.setdefault("weekday", event.extra.get("weekday", ""))
|
||||
ctx.setdefault("timezone", event.extra.get("timezone", "UTC"))
|
||||
ctx.setdefault("custom_vars", event.extra.get("custom_vars", {}))
|
||||
|
||||
return ctx
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "notify-bridge-server"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0"
|
||||
description = "Standalone Notify Bridge server — FastAPI REST API with SQLite database"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -28,6 +28,7 @@ dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"httpx>=0.27",
|
||||
"aioresponses>=0.7",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -35,3 +36,14 @@ notify-bridge = "notify_bridge_server.main:run"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/notify_bridge_server"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
# The default filter doesn't let SQLAlchemy warnings fail the suite, which
|
||||
# matters because our migrations emit a handful of deprecation warnings we
|
||||
# don't want to suppress at source.
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning:passlib",
|
||||
"ignore::DeprecationWarning:bcrypt",
|
||||
]
|
||||
|
||||
@@ -24,6 +24,10 @@ _SETTING_KEYS = {
|
||||
"telegram_asset_cache_max_entries": None, # LRU cap for both caches
|
||||
"supported_locales": None, # comma-separated locale codes
|
||||
"timezone": "NOTIFY_BRIDGE_TIMEZONE", # IANA tz (e.g. "Europe/Warsaw"); empty = UTC
|
||||
# Logging — applied live via apply_log_levels() when changed.
|
||||
"log_level": "NOTIFY_BRIDGE_LOG_LEVEL", # DEBUG/INFO/WARNING/ERROR
|
||||
"log_format": "NOTIFY_BRIDGE_LOG_FORMAT", # text|json (requires restart to switch)
|
||||
"log_levels": "NOTIFY_BRIDGE_LOG_LEVELS", # module=LEVEL,module2=LEVEL
|
||||
}
|
||||
|
||||
_DEFAULTS = {
|
||||
@@ -35,12 +39,20 @@ _DEFAULTS = {
|
||||
"telegram_asset_cache_max_entries": "5000",
|
||||
"supported_locales": "en,ru",
|
||||
"timezone": "UTC",
|
||||
"log_level": "INFO",
|
||||
"log_format": "text",
|
||||
"log_levels": "",
|
||||
}
|
||||
|
||||
# Settings whose changes require dropping in-memory Telegram caches so the
|
||||
# next dispatch rebuilds them with the new parameters. Files are preserved.
|
||||
_CACHE_SETTING_KEYS = {"telegram_cache_ttl_hours", "telegram_asset_cache_max_entries"}
|
||||
|
||||
# Settings that change logging behaviour. ``log_level`` + ``log_levels`` apply
|
||||
# live via apply_log_levels(); ``log_format`` requires a restart because
|
||||
# changing it means swapping the handler formatter entirely.
|
||||
_LOG_SETTING_KEYS = {"log_level", "log_levels", "log_format"}
|
||||
|
||||
|
||||
async def get_setting(session: AsyncSession, key: str) -> str:
|
||||
"""Read a setting from DB, falling back to env var then default."""
|
||||
@@ -66,6 +78,9 @@ class SettingsUpdate(BaseModel):
|
||||
telegram_asset_cache_max_entries: int | str | None = None
|
||||
supported_locales: str | None = None
|
||||
timezone: str | None = None
|
||||
log_level: str | None = None
|
||||
log_format: str | None = None
|
||||
log_levels: str | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
@@ -94,6 +109,8 @@ async def update_settings(
|
||||
old_base_url = await get_setting(session, "external_url")
|
||||
old_secret = await get_setting(session, "telegram_webhook_secret")
|
||||
old_cache_values = {k: await get_setting(session, k) for k in _CACHE_SETTING_KEYS}
|
||||
old_timezone = await get_setting(session, "timezone")
|
||||
old_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||
|
||||
for key in _SETTING_KEYS:
|
||||
value = getattr(body, key, None)
|
||||
@@ -128,6 +145,33 @@ async def update_settings(
|
||||
|
||||
new_base_url = await get_setting(session, "external_url")
|
||||
new_secret = await get_setting(session, "telegram_webhook_secret")
|
||||
new_timezone = await get_setting(session, "timezone")
|
||||
new_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||
|
||||
# Apply live log-level changes (log_format still needs a restart).
|
||||
if (new_log_values["log_level"] != old_log_values["log_level"]
|
||||
or new_log_values["log_levels"] != old_log_values["log_levels"]):
|
||||
from ..logging_setup import apply_log_levels
|
||||
apply_log_levels(
|
||||
level=new_log_values["log_level"] or None,
|
||||
per_module_levels=new_log_values["log_levels"],
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Log levels updated: root=%s overrides=%r",
|
||||
new_log_values["log_level"], new_log_values["log_levels"],
|
||||
)
|
||||
if new_log_values["log_format"] != old_log_values["log_format"]:
|
||||
_LOGGER.warning(
|
||||
"log_format changed from %r to %r — restart the server for it to take effect",
|
||||
old_log_values["log_format"], new_log_values["log_format"],
|
||||
)
|
||||
|
||||
# Cron triggers freeze their timezone at construction time, so a tz change
|
||||
# has no effect until the jobs are rebuilt — do that here, before we
|
||||
# return success, so the UI reflects the actual schedule immediately.
|
||||
if new_timezone != old_timezone:
|
||||
from ..services.scheduler import reschedule_cron_jobs_for_timezone_change
|
||||
await reschedule_cron_jobs_for_timezone_change()
|
||||
|
||||
# Update webhook secret in the webhook handler module
|
||||
if new_secret != old_secret:
|
||||
@@ -190,7 +234,10 @@ async def _reregister_webhooks(
|
||||
if res.get("success"):
|
||||
_LOGGER.info("Re-registered webhook for bot %d (%s)", bot.id, bot.name)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Failed to re-register webhook for bot %d: %s",
|
||||
bot.id, res.get("error"),
|
||||
# Webhook re-register failure means the bot silently stops
|
||||
# delivering updates — this is operational visibility for an
|
||||
# admin, ERROR is appropriate.
|
||||
_LOGGER.error(
|
||||
"Failed to re-register webhook for bot %d (%s): %s",
|
||||
bot.id, bot.name, res.get("error"),
|
||||
)
|
||||
|
||||
@@ -93,7 +93,14 @@ async def update_email_bot(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
# Reject the masked value the GET response returns so the stored password
|
||||
# is preserved if the user saves without retyping it.
|
||||
if "smtp_password" in updates:
|
||||
pw = updates["smtp_password"]
|
||||
if isinstance(pw, str) and pw.startswith("***"):
|
||||
updates.pop("smtp_password")
|
||||
for field, value in updates.items():
|
||||
setattr(bot, field, value)
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
|
||||
@@ -7,6 +7,11 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import (
|
||||
UnsafeURLError,
|
||||
avalidate_outbound_url,
|
||||
)
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import MatrixBot, User
|
||||
@@ -33,6 +38,21 @@ class MatrixBotUpdate(BaseModel):
|
||||
display_name: str | None = None
|
||||
|
||||
|
||||
def _is_masked_secret(value: str | None) -> bool:
|
||||
"""True when a field still carries our masked placeholder."""
|
||||
return bool(value) and (value.startswith("***") or "..." in value)
|
||||
|
||||
|
||||
async def _validate_homeserver_url(url: str) -> None:
|
||||
"""Reject homeserver URLs that point to blocked networks."""
|
||||
try:
|
||||
await avalidate_outbound_url(url)
|
||||
except UnsafeURLError as err:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid homeserver_url: {err}"
|
||||
) from err
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_matrix_bots(
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -50,6 +70,7 @@ async def create_matrix_bot(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
await _validate_homeserver_url(body.homeserver_url)
|
||||
bot = MatrixBot(user_id=user.id, **body.model_dump())
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
@@ -74,7 +95,19 @@ async def update_matrix_bot(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
|
||||
# Re-validate homeserver_url whenever the client supplies a new one so
|
||||
# no private/loopback target can ever be saved, even via update.
|
||||
if "homeserver_url" in updates and updates["homeserver_url"]:
|
||||
await _validate_homeserver_url(updates["homeserver_url"])
|
||||
|
||||
# Never accept the masked placeholder the GET response returns. If the
|
||||
# client echoes it back, keep the stored secret.
|
||||
if "access_token" in updates and _is_masked_secret(updates["access_token"]):
|
||||
updates.pop("access_token")
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(bot, field, value)
|
||||
session.add(bot)
|
||||
await session.commit()
|
||||
@@ -108,15 +141,17 @@ async def test_matrix_bot(
|
||||
If room_id is not provided, just verifies the access token by calling /whoami.
|
||||
"""
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
# Defense-in-depth: even though create/update validate the URL, a bot row
|
||||
# written before this guard was added could still point at a blocked host.
|
||||
await _validate_homeserver_url(bot.homeserver_url)
|
||||
|
||||
import aiohttp
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
async with http.get(whoami_url, headers=headers, allow_redirects=False) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
@@ -126,7 +161,6 @@ async def test_matrix_bot(
|
||||
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from ..services.notifier import _get_test_message
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
@@ -148,7 +182,7 @@ def _response(bot: MatrixBot) -> dict:
|
||||
"name": bot.name,
|
||||
"icon": bot.icon,
|
||||
"homeserver_url": bot.homeserver_url,
|
||||
"access_token": f"{bot.access_token[:8]}...{bot.access_token[-4:]}" if len(bot.access_token) > 12 else "***",
|
||||
"access_token": f"***{bot.access_token[-4:]}" if len(bot.access_token) > 4 else "***",
|
||||
"display_name": bot.display_name,
|
||||
"created_at": bot.created_at.isoformat(),
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..database.models import (
|
||||
User,
|
||||
)
|
||||
from ..services.notifier import send_test_notification
|
||||
from ..services.test_dispatch import dispatch_test_notification
|
||||
from ..services.manual_dispatch import dispatch_test_notification
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -11,6 +11,7 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import (
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerState,
|
||||
NotificationTrackerTarget,
|
||||
@@ -54,11 +55,79 @@ async def list_notification_trackers(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
# Batched loader: pull trackers, then all their tracker-target links in
|
||||
# a single query, then the referenced targets in a single query. Avoids
|
||||
# the old 1 + N + N*M pattern that ran ~60 round-trips for 10 trackers.
|
||||
result = await session.exec(
|
||||
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
||||
)
|
||||
trackers = result.all()
|
||||
return [await _tracker_response(session, t) for t in trackers]
|
||||
trackers = list(result.all())
|
||||
if not trackers:
|
||||
return []
|
||||
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
tt_result = await session.exec(
|
||||
select(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
tt_rows = list(tt_result.all())
|
||||
|
||||
target_ids = {tt.target_id for tt in tt_rows}
|
||||
targets_by_id: dict[int, NotificationTarget] = {}
|
||||
if target_ids:
|
||||
tgt_result = await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.id.in_(target_ids))
|
||||
)
|
||||
targets_by_id = {t.id: t for t in tgt_result.all()}
|
||||
|
||||
tts_by_tracker: dict[int, list[NotificationTrackerTarget]] = {}
|
||||
for tt in tt_rows:
|
||||
tts_by_tracker.setdefault(tt.tracker_id, []).append(tt)
|
||||
|
||||
return [
|
||||
_build_tracker_response(t, tts_by_tracker.get(t.id, []), targets_by_id)
|
||||
for t in trackers
|
||||
]
|
||||
|
||||
|
||||
def _build_tracker_response(
|
||||
t: NotificationTracker,
|
||||
tts: list[NotificationTrackerTarget],
|
||||
targets_by_id: dict[int, NotificationTarget],
|
||||
) -> dict:
|
||||
"""In-memory assembler for a tracker + its pre-loaded links/targets."""
|
||||
tracker_targets = []
|
||||
for tt in tts:
|
||||
target = targets_by_id.get(tt.target_id)
|
||||
tracker_targets.append({
|
||||
"id": tt.id,
|
||||
"tracker_id": tt.tracker_id,
|
||||
"target_id": tt.target_id,
|
||||
"target_name": target.name if target else None,
|
||||
"target_type": target.type if target else None,
|
||||
"target_icon": target.icon if target else None,
|
||||
"tracking_config_id": tt.tracking_config_id,
|
||||
"template_config_id": tt.template_config_id,
|
||||
"enabled": tt.enabled,
|
||||
"quiet_hours_start": tt.quiet_hours_start,
|
||||
"quiet_hours_end": tt.quiet_hours_end,
|
||||
"created_at": tt.created_at.isoformat(),
|
||||
})
|
||||
return {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"icon": t.icon,
|
||||
"provider_id": t.provider_id,
|
||||
"collection_ids": t.collection_ids,
|
||||
"scan_interval": t.scan_interval,
|
||||
"batch_duration": t.batch_duration,
|
||||
"default_tracking_config_id": t.default_tracking_config_id,
|
||||
"default_template_config_id": t.default_template_config_id,
|
||||
"enabled": t.enabled,
|
||||
"tracker_targets": tracker_targets,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
|
||||
@@ -306,16 +306,31 @@ async def update_provider(
|
||||
if body.icon is not None:
|
||||
provider.icon = body.icon
|
||||
|
||||
config_changed = body.config is not None and body.config != provider.config
|
||||
if body.config is not None:
|
||||
_validate_provider_config(provider.type, body.config)
|
||||
provider.config = body.config
|
||||
# Merge rather than replace so the masked secrets the frontend
|
||||
# receives on GET cannot silently nuke the stored values when the
|
||||
# user saves without re-entering them. Any field that still carries
|
||||
# our mask placeholder ("***…") is dropped from the incoming body.
|
||||
incoming = dict(body.config)
|
||||
for secret_field in (
|
||||
"api_key", "api_token", "webhook_secret", "password",
|
||||
"client_secret", "refresh_token",
|
||||
):
|
||||
value = incoming.get(secret_field)
|
||||
if isinstance(value, str) and value.startswith("***"):
|
||||
incoming.pop(secret_field, None)
|
||||
new_config = {**provider.config, **incoming}
|
||||
_validate_provider_config(provider.type, new_config)
|
||||
config_changed = new_config != provider.config
|
||||
provider.config = new_config
|
||||
|
||||
# Re-validate connection when config changes for known provider types
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {
|
||||
**provider.config,
|
||||
"external_domain": test_result["external_domain"],
|
||||
}
|
||||
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
|
||||
@@ -242,6 +242,8 @@ async def get_template_variables(
|
||||
"current_date": "Current date (formatted)",
|
||||
"current_time": "Current time (formatted)",
|
||||
"current_datetime": "Current date and time (formatted)",
|
||||
"weekday": "Day of the week (Monday..Sunday)",
|
||||
"timezone": "IANA timezone used for current_date/time",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""User management API routes (admin only)."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -14,6 +15,15 @@ from ..auth.dependencies import require_admin
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import User
|
||||
|
||||
|
||||
async def _hash_password(password: str) -> str:
|
||||
"""Run bcrypt off the event loop. Matches the helper in auth/routes.py."""
|
||||
|
||||
def _work() -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
return await asyncio.to_thread(_work)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/users", tags=["users"])
|
||||
@@ -36,8 +46,12 @@ async def list_users(
|
||||
admin: User = Depends(require_admin),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all users (admin only)."""
|
||||
result = await session.exec(select(User))
|
||||
"""List all users (admin only).
|
||||
|
||||
Excludes the internal ``__system__`` placeholder (id=0) used as the
|
||||
owner of default templates/configs — it is never a real account.
|
||||
"""
|
||||
result = await session.exec(select(User).where(User.id != 0))
|
||||
return [
|
||||
{"id": u.id, "username": u.username, "role": u.role, "created_at": u.created_at.isoformat()}
|
||||
for u in result.all()
|
||||
@@ -61,7 +75,7 @@ async def create_user(
|
||||
|
||||
user = User(
|
||||
username=body.username,
|
||||
hashed_password=bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode(),
|
||||
hashed_password=await _hash_password(body.password),
|
||||
role=body.role if body.role in ("admin", "user") else "user",
|
||||
)
|
||||
session.add(user)
|
||||
@@ -162,7 +176,7 @@ async def reset_user_password(
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
user.hashed_password = bcrypt.hashpw(body.new_password.encode(), bcrypt.gensalt()).decode()
|
||||
user.hashed_password = await _hash_password(body.new_password)
|
||||
# Invalidate all prior JWTs issued for this user — matches the self-serve
|
||||
# password-change path in auth/routes.py.
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
|
||||
@@ -37,6 +37,42 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webhooks", tags=["webhooks"])
|
||||
|
||||
# Hard cap on inbound webhook body size (1 MiB is far larger than anything
|
||||
# legitimate providers send and keeps the worst-case memory footprint bounded
|
||||
# when a malicious peer lies about Content-Length or streams slowly).
|
||||
_MAX_WEBHOOK_BODY_BYTES = 1_000_000
|
||||
|
||||
|
||||
async def _read_bounded_body(request: Request, limit: int = _MAX_WEBHOOK_BODY_BYTES) -> bytes:
|
||||
"""Reject oversized inbound bodies before they exhaust memory.
|
||||
|
||||
First checks ``Content-Length`` (fast-path for honest peers), then
|
||||
streams the body in chunks enforcing the same cap on actual bytes
|
||||
received so a peer that lies about Content-Length cannot slip through.
|
||||
"""
|
||||
declared = request.headers.get("content-length")
|
||||
if declared:
|
||||
try:
|
||||
if int(declared) > limit:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Payload too large (max {limit} bytes)",
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid Content-Length")
|
||||
|
||||
chunks: list[bytes] = []
|
||||
size = 0
|
||||
async for chunk in request.stream():
|
||||
size += len(chunk)
|
||||
if size > limit:
|
||||
raise HTTPException(
|
||||
status_code=413,
|
||||
detail=f"Payload too large (max {limit} bytes)",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
async def _get_provider_by_token(
|
||||
session: AsyncSession, token: str, expected_type: str,
|
||||
@@ -169,7 +205,8 @@ async def _dispatch_webhook_event(
|
||||
))
|
||||
|
||||
# Dispatch to targets
|
||||
dispatcher = NotificationDispatcher()
|
||||
from ..services.http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
target_configs = _build_target_configs(event, link_data, provider_config, app_tz)
|
||||
if target_configs:
|
||||
results = await dispatcher.dispatch(event, target_configs)
|
||||
@@ -203,7 +240,7 @@ async def gitea_webhook(token: str, request: Request):
|
||||
webhook_secret = (provider.config or {}).get("webhook_secret", "")
|
||||
|
||||
# Read raw body for HMAC check
|
||||
raw_body = await request.body()
|
||||
raw_body = await _read_bounded_body(request)
|
||||
|
||||
if not webhook_secret:
|
||||
raise HTTPException(
|
||||
@@ -221,8 +258,8 @@ async def gitea_webhook(token: str, request: Request):
|
||||
return {"ok": True, "skipped": "no event header"}
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||
|
||||
event = parse_gitea_webhook(event_header, payload, provider.name)
|
||||
@@ -280,10 +317,10 @@ async def planka_webhook(token: str, request: Request):
|
||||
if not _verify_planka_token(webhook_secret, request):
|
||||
raise HTTPException(status_code=403, detail="Invalid token")
|
||||
|
||||
# Parse payload
|
||||
# Parse payload from the bounded raw_body we already read.
|
||||
try:
|
||||
payload = await request.json()
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON")
|
||||
|
||||
event_type = payload.get("type", "")
|
||||
@@ -446,23 +483,22 @@ async def generic_webhook(token: str, request: Request):
|
||||
store_payloads = provider_config.get("store_payloads", True)
|
||||
max_stored = min(max(int(provider_config.get("max_stored_payloads", 20)), 1), 100)
|
||||
|
||||
raw_body = await request.body()
|
||||
raw_body = await _read_bounded_body(request)
|
||||
|
||||
# Enforce payload size limit BEFORE parsing JSON
|
||||
if len(raw_body) > 1_000_000:
|
||||
raise HTTPException(status_code=413, detail="Payload too large (max 1 MB)")
|
||||
# Bounded read above already enforces the size cap; no need to re-check.
|
||||
|
||||
if not _verify_generic_webhook_auth(provider_config, request, raw_body):
|
||||
raise HTTPException(status_code=403, detail="Authentication failed")
|
||||
|
||||
safe_headers = _filter_headers(dict(request.headers))
|
||||
|
||||
# Parse JSON payload
|
||||
# Parse JSON payload from the already-bounded raw_body (request.body()
|
||||
# has been consumed, so request.json() is no longer usable here).
|
||||
try:
|
||||
payload = await request.json()
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("Payload must be a JSON object")
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
|
||||
if store_payloads:
|
||||
async with AsyncSession(get_engine()) as log_session:
|
||||
await _save_webhook_log(
|
||||
|
||||
@@ -7,30 +7,51 @@ import jwt
|
||||
from ..config import settings
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
_LEEWAY_SECONDS = 10
|
||||
|
||||
|
||||
def _now_utc() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def create_access_token(user_id: int, role: str, token_version: int = 1) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
now = _now_utc()
|
||||
expire = now + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
payload = {
|
||||
"iss": settings.jwt_issuer,
|
||||
"aud": settings.jwt_audience,
|
||||
"sub": str(user_id),
|
||||
"role": role,
|
||||
"type": "access",
|
||||
"ver": token_version,
|
||||
"iat": now,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int, token_version: int = 1) -> str:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)
|
||||
now = _now_utc()
|
||||
expire = now + timedelta(days=settings.refresh_token_expire_days)
|
||||
payload = {
|
||||
"iss": settings.jwt_issuer,
|
||||
"aud": settings.jwt_audience,
|
||||
"sub": str(user_id),
|
||||
"type": "refresh",
|
||||
"ver": token_version,
|
||||
"iat": now,
|
||||
"exp": expire,
|
||||
}
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
return jwt.decode(token, settings.secret_key, algorithms=[ALGORITHM])
|
||||
return jwt.decode(
|
||||
token,
|
||||
settings.secret_key,
|
||||
algorithms=[ALGORITHM],
|
||||
audience=settings.jwt_audience,
|
||||
issuer=settings.jwt_issuer,
|
||||
leeway=_LEEWAY_SECONDS,
|
||||
options={"require": ["exp", "sub", "iss", "aud", "type"]},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Authentication API routes."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from slowapi import Limiter
|
||||
@@ -16,7 +18,9 @@ from .jwt import create_access_token, create_refresh_token, decode_token
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
# Default rate limit applied by SlowAPIMiddleware to every route that does NOT
|
||||
# specify its own @limiter.limit(...) — protects against blanket abuse.
|
||||
limiter = Limiter(key_func=get_remote_address, default_limits=["600/minute"])
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
|
||||
@@ -45,27 +49,52 @@ class RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
async def _hash_password(password: str) -> str:
|
||||
"""bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop."""
|
||||
|
||||
def _work() -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
return await asyncio.to_thread(_work)
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
async def _verify_password(password: str, hashed: str) -> bool:
|
||||
def _work() -> bool:
|
||||
try:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
except ValueError:
|
||||
# Malformed hash in DB — treat as mismatch, never raise to caller.
|
||||
return False
|
||||
|
||||
return await asyncio.to_thread(_work)
|
||||
|
||||
|
||||
@router.post("/setup", response_model=TokenResponse)
|
||||
@limiter.limit("3/minute")
|
||||
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
|
||||
result = await session.exec(select(func.count()).select_from(User))
|
||||
count = result.one()
|
||||
if count > 0:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.")
|
||||
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
# Compute hash BEFORE opening the transaction so we don't hold a writer lock
|
||||
# during the CPU-bound bcrypt work.
|
||||
hashed = await _hash_password(body.password)
|
||||
|
||||
# Serialize setup via an INSERT-inside-transaction-with-count-guard.
|
||||
# SQLite's writer lock plus the count check inside the transaction closes
|
||||
# the TOCTOU window between two concurrent POSTs. We ignore id=0 — that's
|
||||
# the internal "__system__" placeholder used for ownership of default
|
||||
# templates, never a real admin.
|
||||
async with session.begin():
|
||||
result = await session.exec(
|
||||
select(func.count()).select_from(User).where(User.id != 0)
|
||||
)
|
||||
count = result.one()
|
||||
if count > 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Setup already completed.",
|
||||
)
|
||||
user = User(username=body.username, hashed_password=hashed, role="admin")
|
||||
session.add(user)
|
||||
await session.refresh(user)
|
||||
|
||||
return TokenResponse(
|
||||
@@ -79,7 +108,13 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
|
||||
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
|
||||
result = await session.exec(select(User).where(User.username == body.username))
|
||||
user = result.first()
|
||||
if not user or not _verify_password(body.password, user.hashed_password):
|
||||
# Always run a bcrypt verification to keep the response time constant,
|
||||
# preventing username-enumeration via timing side channel.
|
||||
password_ok = await _verify_password(
|
||||
body.password,
|
||||
user.hashed_password if user else "$2b$12$" + "a" * 53,
|
||||
)
|
||||
if not user or not password_ok:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password")
|
||||
|
||||
return TokenResponse(
|
||||
@@ -124,16 +159,18 @@ class PasswordChangeRequest(BaseModel):
|
||||
|
||||
|
||||
@router.put("/password")
|
||||
@limiter.limit("10/minute")
|
||||
async def change_password(
|
||||
request: Request,
|
||||
body: PasswordChangeRequest,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
if not _verify_password(body.current_password, user.hashed_password):
|
||||
if not await _verify_password(body.current_password, user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
||||
user.hashed_password = _hash_password(body.new_password)
|
||||
user.hashed_password = await _hash_password(body.new_password)
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
@@ -141,7 +178,12 @@ async def change_password(
|
||||
|
||||
|
||||
@router.get("/needs-setup")
|
||||
async def needs_setup(session: AsyncSession = Depends(get_session)):
|
||||
result = await session.exec(select(func.count()).select_from(User))
|
||||
@limiter.limit("30/minute")
|
||||
async def needs_setup(request: Request, session: AsyncSession = Depends(get_session)):
|
||||
# Exclude the internal __system__ placeholder (id=0) from the count so
|
||||
# a fresh install still reports needs_setup=True.
|
||||
result = await session.exec(
|
||||
select(func.count()).select_from(User).where(User.id != 0)
|
||||
)
|
||||
count = result.one()
|
||||
return {"needs_setup": count == 0}
|
||||
|
||||
@@ -108,13 +108,18 @@ def _render_cmd_template(
|
||||
"""Render a locale-aware command template. Falls back to 'en'."""
|
||||
template_str = _resolve_template(templates, slot_name, locale)
|
||||
if not template_str:
|
||||
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||
# Missing template = user sees "[No template: X]" — this is an ERROR,
|
||||
# not a warning. Broken replies must stand out in production logs.
|
||||
_LOGGER.error("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||
return f"[No template: {slot_name}]"
|
||||
try:
|
||||
tmpl = _compile_template(template_str)
|
||||
return tmpl.render(**context)
|
||||
except Exception as e:
|
||||
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
|
||||
except Exception:
|
||||
_LOGGER.error(
|
||||
"Failed to render command template '%s' locale=%s — user will see a broken reply",
|
||||
slot_name, locale, exc_info=True,
|
||||
)
|
||||
return f"[Template error: {slot_name}]"
|
||||
|
||||
|
||||
@@ -296,6 +301,10 @@ async def handle_command(
|
||||
# Rate limit check (once per command, shared across all trackers)
|
||||
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
||||
if wait is not None:
|
||||
_LOGGER.info(
|
||||
"Rate-limited /%s for bot=%d chat=%s — %ds cooldown remaining",
|
||||
cmd, bot.id, chat_id, wait,
|
||||
)
|
||||
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
|
||||
return [CommandResponse(text=text_resp)]
|
||||
|
||||
@@ -322,8 +331,8 @@ async def handle_command(
|
||||
for tracker, config, provider, listener in ctx_tuples:
|
||||
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
|
||||
_LOGGER.warning(
|
||||
"Truncated command responses at %d for bot %d cmd /%s",
|
||||
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
|
||||
"Truncated command responses at %d for bot=%d chat=%s cmd=/%s (listener context size=%d)",
|
||||
_MAX_RESPONSES_PER_COMMAND, bot.id, chat_id, cmd, len(ctx_tuples),
|
||||
)
|
||||
break
|
||||
|
||||
@@ -418,7 +427,12 @@ async def send_reply(
|
||||
disable_web_page_preview=True,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
|
||||
# User-visible failure: the bot's reply never reached the chat.
|
||||
_LOGGER.error(
|
||||
"Telegram reply failed (chat=%s reply_to=%s len=%d): code=%s error=%r",
|
||||
chat_id, reply_to_message_id, len(text or ""),
|
||||
result.get("error_code"), result.get("error"),
|
||||
)
|
||||
|
||||
|
||||
async def send_media_group(
|
||||
@@ -442,6 +456,14 @@ async def send_media_group(
|
||||
assets hit the cache and skip the re-upload.
|
||||
"""
|
||||
if not media_items:
|
||||
# This is what happened in the /random blind spot: the text reply
|
||||
# was sent, but the media follow-up was silently skipped because
|
||||
# the caller passed an empty media list. Surface it so we can see
|
||||
# it in the log and correlate with the text message.
|
||||
_LOGGER.warning(
|
||||
"send_media_group called with 0 items (chat=%s reply_to=%s) — no media will be delivered",
|
||||
chat_id, reply_to_message_id,
|
||||
)
|
||||
return
|
||||
|
||||
from ..services.telegram_send import send_telegram_media
|
||||
@@ -452,7 +474,13 @@ async def send_media_group(
|
||||
chat_action=None,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
|
||||
# User-visible failure: media promised by the text reply never arrived.
|
||||
_LOGGER.error(
|
||||
"Telegram media group failed (chat=%s items=%d reply_to=%s): code=%s error=%r failed_at_chunk=%s",
|
||||
chat_id, len(media_items), reply_to_message_id,
|
||||
result.get("error_code"), result.get("error"),
|
||||
result.get("failed_at_chunk"),
|
||||
)
|
||||
|
||||
|
||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
|
||||
@@ -144,6 +144,7 @@ def _format_assets(
|
||||
# other's cached file_ids (which is what made the cache look empty
|
||||
# from the WebUI after running /random).
|
||||
media_items: list[dict[str, Any]] = []
|
||||
dropped = 0
|
||||
for asset in assets:
|
||||
asset_id = asset.get("id", "")
|
||||
asset_type = (asset.get("type") or "").upper()
|
||||
@@ -156,6 +157,20 @@ def _format_assets(
|
||||
)
|
||||
if entry is not None:
|
||||
media_items.append(entry)
|
||||
else:
|
||||
dropped += 1
|
||||
_LOGGER.warning(
|
||||
"Dropped asset from /%s media payload: id=%s type=%s (empty preview URL)",
|
||||
cmd, asset_id, asset_type,
|
||||
)
|
||||
if not media_items and assets:
|
||||
# All assets were filtered out before reaching Telegram. The user
|
||||
# will see the text reply but no media — surface it here so the
|
||||
# log shows WHY the media group ended up empty.
|
||||
_LOGGER.warning(
|
||||
"/%s media payload empty: %d asset(s) in, 0 out (all dropped)",
|
||||
cmd, len(assets),
|
||||
)
|
||||
# Return text message + media items — text is sent first, media as reply
|
||||
return {"text": text, "media": media_items}
|
||||
|
||||
|
||||
@@ -143,7 +143,16 @@ async def _cmd_immich(
|
||||
# chat). ``None`` = no filter (rare); empty set = show nothing (common
|
||||
# when the chat has no tracker routing).
|
||||
if allowed_album_ids is not None:
|
||||
before = len(all_album_ids)
|
||||
all_album_ids = [aid for aid in all_album_ids if aid in allowed_album_ids]
|
||||
if not all_album_ids:
|
||||
# A command that sees zero albums is a routing/tracker config issue
|
||||
# the operator needs to notice — otherwise the user gets
|
||||
# "no results" with no hint at why.
|
||||
_LOGGER.info(
|
||||
"Command /%s has empty album scope for provider=%d (had %d trackers, chat scope allowed %d)",
|
||||
cmd, provider.id, before, len(allowed_album_ids),
|
||||
)
|
||||
|
||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||
|
||||
|
||||
@@ -4,12 +4,14 @@ from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..database.engine import get_session
|
||||
@@ -18,6 +20,7 @@ from ..services.telegram import save_chat_from_webhook
|
||||
from ..services.telegram_send import telegram_chat_action
|
||||
from .base import CommandResponse
|
||||
from .handler import classify_command_chat_action, handle_command, send_media_group, send_reply
|
||||
from .parser import parse_command
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -93,20 +96,62 @@ async def telegram_webhook(
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
_LOGGER.info(
|
||||
"Command ignored — commands disabled for bot=%s chat=%s text=%r",
|
||||
bot_id, chat_id, text[:64],
|
||||
)
|
||||
return {"ok": True, "skipped": "commands_disabled"}
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
|
||||
cmd_name, _, _ = parse_command(text)
|
||||
update_id = update.get("update_id")
|
||||
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||
|
||||
with bind_log_context(
|
||||
request_id=request_id,
|
||||
command=cmd_name or "-",
|
||||
chat_id=chat_id,
|
||||
bot_id=bot_id,
|
||||
):
|
||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
return {"ok": True}
|
||||
started = time.monotonic()
|
||||
_LOGGER.info("Command received: /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||
try:
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
):
|
||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if not responses:
|
||||
_LOGGER.info(
|
||||
"Command produced no response (cmd=%r) after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
return {"ok": True, "skipped": "no_response"}
|
||||
text_count = sum(1 for r in responses if r.text)
|
||||
media_count = sum(len(r.media or []) for r in responses)
|
||||
_LOGGER.info(
|
||||
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||
len(responses), text_count, media_count,
|
||||
)
|
||||
for idx, resp in enumerate(responses):
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
_LOGGER.info(
|
||||
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
len(responses), media_count,
|
||||
)
|
||||
return {"ok": True}
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Command /%s raised after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
# Return 200 so Telegram doesn't retry the same update — we
|
||||
# already logged the failure and can't usefully reprocess.
|
||||
return {"ok": True, "error": "handler_exception"}
|
||||
|
||||
return {"ok": True, "skipped": "not_a_command"}
|
||||
|
||||
|
||||
@@ -2,8 +2,20 @@
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
# Secret keys we will actively refuse. These cover the default template value
|
||||
# and dev-only literals that have appeared in scripts or documentation.
|
||||
_FORBIDDEN_SECRETS: frozenset[str] = frozenset(
|
||||
{
|
||||
"change-me-in-production",
|
||||
"test-secret-key-minimum-32-chars",
|
||||
"dev-secret-key-not-for-production",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
@@ -13,29 +25,25 @@ class Settings(BaseSettings):
|
||||
|
||||
secret_key: str = "change-me-in-production"
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.secret_key == "change-me-in-production":
|
||||
raise ValueError(
|
||||
"SECURITY: Refusing to start with the default secret_key. "
|
||||
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
||||
"before starting the server (debug mode included)."
|
||||
)
|
||||
if len(self.secret_key) < 32:
|
||||
raise ValueError(
|
||||
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
||||
)
|
||||
if "*" in self.cors_allowed_origins.split(","):
|
||||
raise ValueError(
|
||||
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
||||
)
|
||||
|
||||
access_token_expire_minutes: int = 60
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 30
|
||||
|
||||
jwt_issuer: str = "notify-bridge"
|
||||
jwt_audience: str = "notify-bridge-api"
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8420
|
||||
debug: bool = False
|
||||
|
||||
# Comma-separated list of trusted proxy IPs uvicorn will honor for
|
||||
# X-Forwarded-For / X-Forwarded-Proto. Use "*" ONLY when you trust the
|
||||
# network (never directly on the internet). Default matches uvicorn.
|
||||
forwarded_allow_ips: str = "127.0.0.1"
|
||||
|
||||
# How long to wait for in-flight requests / scheduler jobs before force
|
||||
# killing on SIGTERM.
|
||||
graceful_shutdown_seconds: int = 60
|
||||
|
||||
anthropic_api_key: str = ""
|
||||
ai_model: str = "claude-sonnet-4-20250514"
|
||||
ai_max_tokens: int = 1024
|
||||
@@ -48,8 +56,61 @@ class Settings(BaseSettings):
|
||||
static_dir: str = ""
|
||||
"""Path to frontend static files. Set to serve SvelteKit build via FastAPI (e.g. /app/static in Docker)."""
|
||||
|
||||
# --- Logging ---
|
||||
log_level: str = "INFO"
|
||||
"""Root log level for the app loggers (``DEBUG``/``INFO``/``WARNING``/``ERROR``)."""
|
||||
|
||||
log_format: str = "text"
|
||||
"""Log output format: ``text`` (human-readable) or ``json`` (one object per line)."""
|
||||
|
||||
log_levels: str = ""
|
||||
"""Comma-separated per-module overrides, e.g. ``notify_bridge_core.notifications.telegram.client=DEBUG,sqlalchemy.engine=INFO``."""
|
||||
|
||||
# --- Retention ---
|
||||
event_log_retention_days: int = 30
|
||||
"""Days of event_log history to retain. 0 disables the retention job."""
|
||||
|
||||
pre_migrate_snapshot_keep: int = 5
|
||||
"""Number of pre-migration DB snapshots to keep in ``data_dir/backups/``.
|
||||
0 disables snapshotting entirely. Each snapshot is produced at boot
|
||||
before migrations run using SQLite's ``VACUUM INTO`` (atomic, consistent).
|
||||
"""
|
||||
|
||||
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if self.secret_key in _FORBIDDEN_SECRETS:
|
||||
raise ValueError(
|
||||
"SECURITY: Refusing to start with a known/default secret_key. "
|
||||
"Set NOTIFY_BRIDGE_SECRET_KEY to a random value (>=32 bytes) "
|
||||
"before starting the server."
|
||||
)
|
||||
if len(self.secret_key) < 32:
|
||||
raise ValueError(
|
||||
"SECURITY: NOTIFY_BRIDGE_SECRET_KEY must be at least 32 characters."
|
||||
)
|
||||
origins = [o.strip() for o in self.cors_allowed_origins.split(",") if o.strip()]
|
||||
if "*" in origins:
|
||||
raise ValueError(
|
||||
"SECURITY: wildcard '*' is not allowed in CORS origins when credentials are enabled."
|
||||
)
|
||||
for origin in origins:
|
||||
parsed = urlparse(origin)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
|
||||
raise ValueError(
|
||||
f"CORS origin {origin!r} is invalid — must include scheme (http/https) and host."
|
||||
)
|
||||
if self.access_token_expire_minutes <= 0:
|
||||
raise ValueError("access_token_expire_minutes must be > 0")
|
||||
if self.refresh_token_expire_days <= 0:
|
||||
raise ValueError("refresh_token_expire_days must be > 0")
|
||||
if not (1 <= self.port <= 65535):
|
||||
raise ValueError("port must be in range 1..65535")
|
||||
if self.event_log_retention_days < 0:
|
||||
raise ValueError("event_log_retention_days must be >= 0")
|
||||
if self.pre_migrate_snapshot_keep < 0:
|
||||
raise ValueError("pre_migrate_snapshot_keep must be >= 0")
|
||||
|
||||
@property
|
||||
def effective_database_url(self) -> str:
|
||||
if self.database_url:
|
||||
|
||||
@@ -1,23 +1,59 @@
|
||||
"""Database engine and session management."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import logging
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
|
||||
from ..config import settings
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
|
||||
def _install_sqlite_pragmas(engine: AsyncEngine) -> None:
|
||||
"""Apply production-grade SQLite PRAGMAs on every new connection.
|
||||
|
||||
WAL mode lets readers and writers work concurrently without blocking;
|
||||
``busy_timeout`` gives contending writers a chance instead of instant
|
||||
SQLITE_BUSY; ``foreign_keys`` enforces the FK constraints declared in the
|
||||
models (SQLite disables them by default); ``synchronous=NORMAL`` is a
|
||||
safe-by-default durability trade-off that is standard in WAL mode.
|
||||
"""
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _pragmas(dbapi_conn, _record): # pragma: no cover — driver hook
|
||||
cur = dbapi_conn.cursor()
|
||||
try:
|
||||
cur.execute("PRAGMA journal_mode=WAL")
|
||||
cur.execute("PRAGMA synchronous=NORMAL")
|
||||
cur.execute("PRAGMA foreign_keys=ON")
|
||||
cur.execute("PRAGMA busy_timeout=10000")
|
||||
cur.execute("PRAGMA temp_store=MEMORY")
|
||||
finally:
|
||||
cur.close()
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine:
|
||||
global _engine
|
||||
if _engine is None:
|
||||
url = settings.effective_database_url
|
||||
connect_args: dict = {}
|
||||
if url.startswith("sqlite"):
|
||||
connect_args["timeout"] = 30
|
||||
_engine = create_async_engine(
|
||||
settings.effective_database_url,
|
||||
url,
|
||||
echo=settings.debug,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
if url.startswith("sqlite"):
|
||||
_install_sqlite_pragmas(_engine)
|
||||
_LOGGER.info("Database engine initialized: %s", url.split("://", 1)[0])
|
||||
return _engine
|
||||
|
||||
|
||||
@@ -31,3 +67,11 @@ async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def dispose_engine() -> None:
|
||||
"""Close the engine's connection pool. Call during graceful shutdown."""
|
||||
global _engine
|
||||
if _engine is not None:
|
||||
await _engine.dispose()
|
||||
_engine = None
|
||||
|
||||
@@ -1282,3 +1282,141 @@ async def migrate_user_token_version(engine: AsyncEngine) -> None:
|
||||
text("ALTER TABLE user ADD COLUMN token_version INTEGER NOT NULL DEFAULT 1")
|
||||
)
|
||||
logger.info("Added token_version column to user table")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Performance indexes — covers every FK / owner column the list endpoints
|
||||
# and the webhook hot-path filter on. All use CREATE INDEX IF NOT EXISTS so
|
||||
# they are safe to re-run on every boot.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INDEXES: list[tuple[str, str, str]] = [
|
||||
# (index_name, table, columns)
|
||||
("ix_service_provider_user_id", "service_provider", "user_id"),
|
||||
("ix_telegram_bot_user_id", "telegram_bot", "user_id"),
|
||||
("ix_matrix_bot_user_id", "matrix_bot", "user_id"),
|
||||
("ix_email_bot_user_id", "email_bot", "user_id"),
|
||||
("ix_telegram_chat_bot_id", "telegram_chat", "bot_id"),
|
||||
("ix_tracking_config_user_id", "tracking_config", "user_id"),
|
||||
("ix_tracking_config_provider_type", "tracking_config", "provider_type"),
|
||||
("ix_notification_target_user_id", "notification_target", "user_id"),
|
||||
("ix_notification_target_type", "notification_target", "type"),
|
||||
("ix_notification_tracker_user_id", "notification_tracker", "user_id"),
|
||||
("ix_notification_tracker_provider_id", "notification_tracker", "provider_id"),
|
||||
# Composite for the webhook hot path: WHERE provider_id = ? AND enabled = true
|
||||
(
|
||||
"ix_notification_tracker_provider_enabled",
|
||||
"notification_tracker",
|
||||
"provider_id, enabled",
|
||||
),
|
||||
("ix_command_config_user_id", "command_config", "user_id"),
|
||||
("ix_command_template_config_user_id", "command_template_config", "user_id"),
|
||||
("ix_command_tracker_user_id", "command_tracker", "user_id"),
|
||||
("ix_command_tracker_provider_id", "command_tracker", "provider_id"),
|
||||
("ix_action_user_id", "action", "user_id"),
|
||||
("ix_action_provider_id", "action", "provider_id"),
|
||||
# Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
|
||||
("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
|
||||
("ix_event_log_provider_id", "event_log", "provider_id"),
|
||||
("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
|
||||
("ix_event_log_action_id", "event_log", "action_id"),
|
||||
# Webhook log hot path: WHERE provider_id = ? ORDER BY created_at DESC
|
||||
(
|
||||
"ix_webhook_payload_log_provider_created",
|
||||
"webhook_payload_log",
|
||||
"provider_id, created_at DESC",
|
||||
),
|
||||
# Notification tracker join tables
|
||||
(
|
||||
"ix_notification_tracker_target_notification_tracker_id",
|
||||
"notification_tracker_target",
|
||||
"notification_tracker_id",
|
||||
),
|
||||
(
|
||||
"ix_notification_tracker_target_target_id",
|
||||
"notification_tracker_target",
|
||||
"target_id",
|
||||
),
|
||||
("ix_target_receiver_target_id", "target_receiver", "target_id"),
|
||||
("ix_template_slot_config_id", "template_slot", "config_id"),
|
||||
("ix_command_template_slot_config_id", "command_template_slot", "config_id"),
|
||||
("ix_action_rule_action_id", "action_rule", "action_id"),
|
||||
("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"),
|
||||
]
|
||||
|
||||
|
||||
async def migrate_performance_indexes(engine: AsyncEngine) -> None:
|
||||
"""Create missing performance indexes on hot query paths.
|
||||
|
||||
Every index is created with IF NOT EXISTS so the migration is safe to
|
||||
replay on every boot. We only create the index when the table exists —
|
||||
early boots before other migrations land would otherwise raise.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
for name, table, columns in _INDEXES:
|
||||
_assert_ident(name, "index")
|
||||
_assert_ident(table, "table")
|
||||
# Columns list is a trusted literal constructed above — never user input.
|
||||
if not await _has_table(conn, table):
|
||||
continue
|
||||
try:
|
||||
await conn.execute(
|
||||
text(f"CREATE INDEX IF NOT EXISTS {name} ON {table} ({columns})")
|
||||
)
|
||||
except Exception: # pragma: no cover — log and continue
|
||||
logger.warning(
|
||||
"Failed to create index %s on %s(%s)",
|
||||
name, table, columns, exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema version tracking — lightweight alternative to Alembic while the
|
||||
# hand-rolled idempotent migrations remain the source of truth. Gives
|
||||
# operators a single-row answer to "what schema is this DB at" and lets
|
||||
# future upgrades short-circuit migrations that already ran.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CURRENT_SCHEMA_VERSION = 1
|
||||
|
||||
|
||||
async def migrate_schema_version(engine: AsyncEngine) -> None:
|
||||
"""Create schema_version table and bump it to CURRENT_SCHEMA_VERSION."""
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE TABLE IF NOT EXISTS schema_version ("
|
||||
" id INTEGER PRIMARY KEY CHECK (id = 1),"
|
||||
" version INTEGER NOT NULL,"
|
||||
" applied_at TEXT NOT NULL"
|
||||
")"
|
||||
)
|
||||
)
|
||||
row = await conn.run_sync(
|
||||
lambda sc: sc.execute(
|
||||
text("SELECT version FROM schema_version WHERE id = 1")
|
||||
).fetchone()
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
if row is None:
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO schema_version (id, version, applied_at) "
|
||||
"VALUES (1, :v, :t)"
|
||||
),
|
||||
{"v": CURRENT_SCHEMA_VERSION, "t": now},
|
||||
)
|
||||
logger.info("Initialized schema_version at %d", CURRENT_SCHEMA_VERSION)
|
||||
elif int(row[0]) < CURRENT_SCHEMA_VERSION:
|
||||
await conn.execute(
|
||||
text(
|
||||
"UPDATE schema_version SET version = :v, applied_at = :t "
|
||||
"WHERE id = 1"
|
||||
),
|
||||
{"v": CURRENT_SCHEMA_VERSION, "t": now},
|
||||
)
|
||||
logger.info(
|
||||
"Bumped schema_version from %s to %d",
|
||||
row[0], CURRENT_SCHEMA_VERSION,
|
||||
)
|
||||
|
||||
@@ -394,8 +394,37 @@ async def _seed_default_command_configs() -> None:
|
||||
# Public entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _ensure_system_user() -> None:
|
||||
"""Ensure a User row with id=0 exists.
|
||||
|
||||
Historically the app used ``user_id=0`` as a sentinel for "system-owned"
|
||||
defaults (tracking configs, templates, etc.). Now that we enable
|
||||
``PRAGMA foreign_keys=ON`` at connect time, those inserts would fail
|
||||
with ``FOREIGN KEY constraint failed`` unless a placeholder user row
|
||||
with the matching id exists.
|
||||
"""
|
||||
engine = get_engine()
|
||||
async with engine.begin() as conn:
|
||||
# INSERT OR IGNORE so re-running seeds is cheap and idempotent.
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT OR IGNORE INTO user "
|
||||
"(id, username, hashed_password, role, token_version, created_at) "
|
||||
"VALUES (0, :u, :p, :r, 1, :t)"
|
||||
),
|
||||
{
|
||||
"u": "__system__",
|
||||
# Invalid bcrypt hash — nobody can ever log in as this user.
|
||||
"p": "!disabled!",
|
||||
"r": "system",
|
||||
"t": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def seed_all() -> None:
|
||||
"""Run all seed functions in order."""
|
||||
await _ensure_system_user()
|
||||
await _seed_default_templates()
|
||||
await _seed_default_command_templates()
|
||||
await _seed_default_tracking_configs()
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Pre-migration database snapshots.
|
||||
|
||||
Runs at lifespan startup BEFORE migrations execute. Produces a consistent
|
||||
point-in-time copy of the SQLite database using ``VACUUM INTO`` (atomic,
|
||||
cannot tear against concurrent activity, works with WAL).
|
||||
|
||||
The snapshot is the operator's fallback if a future migration corrupts the
|
||||
schema — restore is a single ``mv`` / ``docker cp``. We keep the N most
|
||||
recent files (default 5) and never fail startup if the snapshot itself
|
||||
fails: a snapshot is best-effort safety net, not a gate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_SNAPSHOT_GLOB = "pre-migrate-*.db"
|
||||
_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9._+\-:]+$")
|
||||
|
||||
|
||||
def _sqlite_path_from_url(url: str) -> Path | None:
|
||||
"""Extract the filesystem path from a ``sqlite+aiosqlite:///...`` URL."""
|
||||
if not url.startswith("sqlite"):
|
||||
return None
|
||||
# e.g. "sqlite+aiosqlite:///C:/data/notify_bridge.db"
|
||||
prefix, _, rest = url.partition(":///")
|
||||
if not rest:
|
||||
return None
|
||||
return Path(rest)
|
||||
|
||||
|
||||
async def snapshot_database(
|
||||
engine: AsyncEngine,
|
||||
target_dir: Path,
|
||||
*,
|
||||
label: str = "pre-migrate",
|
||||
) -> Path | None:
|
||||
"""Write a consistent copy of the SQLite DB to ``target_dir``.
|
||||
|
||||
Uses ``VACUUM INTO`` which SQLite executes atomically against a read
|
||||
snapshot — safe under WAL, cannot produce a torn copy. Returns the
|
||||
snapshot path on success, ``None`` when skipped or on non-fatal
|
||||
failure. Never raises: callers treat a missing snapshot as acceptable
|
||||
(the main DB remains the source of truth).
|
||||
"""
|
||||
if not _SNAPSHOT_NAME_RE.match(label):
|
||||
_LOGGER.warning("Snapshot label %r contains unsafe characters; skipping", label)
|
||||
return None
|
||||
|
||||
url = str(engine.url)
|
||||
src = _sqlite_path_from_url(url)
|
||||
if src is None:
|
||||
_LOGGER.debug("Non-SQLite engine; skipping snapshot")
|
||||
return None
|
||||
if not src.exists():
|
||||
_LOGGER.debug("DB file %s does not exist yet (fresh install); skipping snapshot", src)
|
||||
return None
|
||||
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
||||
dest = target_dir / f"{label}-{ts}.db"
|
||||
|
||||
# VACUUM INTO accepts a string literal, not a bind parameter. The dest
|
||||
# path is built from our own label + timestamp (never user input), so
|
||||
# escaping is straightforward — still, reject any dest containing a
|
||||
# single quote as a belt-and-braces check.
|
||||
dest_str = str(dest)
|
||||
if "'" in dest_str:
|
||||
_LOGGER.warning("Refusing to snapshot to path containing a single quote: %s", dest_str)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with engine.connect() as conn:
|
||||
# VACUUM cannot run inside an explicit transaction; use the
|
||||
# plain connection without begin().
|
||||
await conn.execute(text(f"VACUUM INTO '{dest_str}'"))
|
||||
_LOGGER.info("Database snapshot written: %s (%.1f KiB)", dest, dest.stat().st_size / 1024)
|
||||
return dest
|
||||
except Exception:
|
||||
_LOGGER.warning(
|
||||
"Pre-migration snapshot failed — continuing with startup. "
|
||||
"Check disk space in %s.",
|
||||
target_dir,
|
||||
exc_info=True,
|
||||
)
|
||||
# Partial file can linger if VACUUM INTO aborted mid-write; clean up.
|
||||
try:
|
||||
if dest.exists():
|
||||
dest.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def prune_old_snapshots(target_dir: Path, keep: int) -> list[Path]:
|
||||
"""Keep the ``keep`` most recent pre-migrate snapshots, delete the rest.
|
||||
|
||||
Returns the list of paths that were deleted. Safe to call with
|
||||
``keep=0`` (deletes everything) or when the directory does not exist.
|
||||
"""
|
||||
if keep < 0:
|
||||
raise ValueError("keep must be >= 0")
|
||||
if not target_dir.is_dir():
|
||||
return []
|
||||
|
||||
try:
|
||||
snapshots = sorted(
|
||||
target_dir.glob(_SNAPSHOT_GLOB),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
except OSError:
|
||||
return []
|
||||
|
||||
deleted: list[Path] = []
|
||||
for old in snapshots[keep:]:
|
||||
try:
|
||||
old.unlink()
|
||||
deleted.append(old)
|
||||
except OSError:
|
||||
_LOGGER.debug("Could not delete old snapshot %s", old, exc_info=True)
|
||||
if deleted:
|
||||
_LOGGER.info(
|
||||
"Pruned %d old pre-migrate snapshot(s); kept %d most recent",
|
||||
len(deleted), min(keep, len(snapshots)),
|
||||
)
|
||||
return deleted
|
||||
|
||||
|
||||
async def snapshot_and_prune(
|
||||
engine: AsyncEngine,
|
||||
target_dir: Path,
|
||||
*,
|
||||
keep: int,
|
||||
) -> Path | None:
|
||||
"""Take a snapshot and prune old ones. Used by the lifespan startup path.
|
||||
|
||||
``keep=0`` disables snapshotting entirely.
|
||||
"""
|
||||
if keep <= 0:
|
||||
return None
|
||||
snapshot_path = await snapshot_database(engine, target_dir)
|
||||
# Always prune even if this run's snapshot failed — old files still
|
||||
# cost disk and may have been written by prior successful boots.
|
||||
await asyncio.to_thread(prune_old_snapshots, target_dir, keep)
|
||||
return snapshot_path
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Production-grade logging configuration.
|
||||
|
||||
Installs one ``dictConfig`` layout with:
|
||||
|
||||
* A ``LogRecordFactory`` that pulls request-scoped identifiers from
|
||||
``notify_bridge_core.log_context`` onto every record, so logs can be
|
||||
filtered/correlated by ``request_id``, ``command``, ``chat_id``,
|
||||
``bot_id``, ``dispatch_id`` without each call site passing them.
|
||||
* A ``SecretMaskingFilter`` that redacts Telegram bot tokens and common
|
||||
``Authorization`` / ``x-api-key`` headers so an accidental ``repr`` or
|
||||
dumped request doesn't leak credentials into the log aggregator.
|
||||
* A text formatter (default) or a JSON formatter (one line per record)
|
||||
selectable via ``NOTIFY_BRIDGE_LOG_FORMAT`` / app setting.
|
||||
|
||||
Levels are configurable three ways (later wins):
|
||||
|
||||
1. ``NOTIFY_BRIDGE_LOG_LEVEL`` env var (root) plus
|
||||
``NOTIFY_BRIDGE_LOG_LEVELS`` (``mod=LEVEL,mod2=LEVEL``).
|
||||
2. DB ``AppSetting`` rows ``log_level`` / ``log_levels`` / ``log_format``,
|
||||
applied after migrations during startup.
|
||||
3. Live edits via the settings API — ``apply_log_levels()`` updates
|
||||
existing loggers in place without a server restart.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import logging.config
|
||||
import re
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from notify_bridge_core.log_context import (
|
||||
bot_id_var,
|
||||
chat_id_var,
|
||||
command_var,
|
||||
dispatch_id_var,
|
||||
request_id_var,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Secret masking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Telegram bot tokens: /bot<digits>:<alnum with dashes/underscores>
|
||||
_TELEGRAM_TOKEN_RE = re.compile(r"/bot\d+:[A-Za-z0-9_-]{20,}")
|
||||
|
||||
# Header-style secrets: Authorization: Bearer xxx, x-api-key=xxx, etc.
|
||||
# Only matches reasonably long tokens so short legitimate values don't trip.
|
||||
_HEADER_SECRET_RE = re.compile(
|
||||
r"(?i)(authorization|x-api-key|api[_-]?key|password|secret|access[_-]?token|refresh[_-]?token)"
|
||||
r"([\"']?\s*[:=]\s*[\"']?)"
|
||||
r"([A-Za-z0-9._+/=\-]{12,})"
|
||||
)
|
||||
|
||||
|
||||
def _mask(text: str) -> str:
|
||||
redacted = _TELEGRAM_TOKEN_RE.sub("/bot***", text)
|
||||
redacted = _HEADER_SECRET_RE.sub(r"\1\2***", redacted)
|
||||
return redacted
|
||||
|
||||
|
||||
class SecretMaskingFilter(logging.Filter):
|
||||
"""Redact likely secrets from every log message before it's emitted.
|
||||
|
||||
Covers three surfaces where a leaked token can end up in the log:
|
||||
the formatted message, a cached exception traceback (``exc_text``),
|
||||
and a cached stack frame dump (``stack_info``). The formatter still
|
||||
expands ``exc_info`` for us when ``exc_text`` is None, so we also
|
||||
pre-render + mask on first emission.
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
try:
|
||||
msg = record.getMessage()
|
||||
except Exception:
|
||||
return True
|
||||
redacted = _mask(msg)
|
||||
if redacted != msg:
|
||||
# Replace the formatted message and drop args so the handler
|
||||
# doesn't re-format with the original values.
|
||||
record.msg = redacted
|
||||
record.args = ()
|
||||
|
||||
if record.exc_info and not record.exc_text:
|
||||
# Pre-render so we can mask before the formatter caches it.
|
||||
fmt = logging.Formatter()
|
||||
record.exc_text = fmt.formatException(record.exc_info)
|
||||
if record.exc_text:
|
||||
record.exc_text = _mask(record.exc_text)
|
||||
if record.stack_info:
|
||||
record.stack_info = _mask(record.stack_info)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Record factory — injects context identifiers onto every record
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CONTEXT_FIELDS = ("request_id", "command", "chat_id", "bot_id", "dispatch_id")
|
||||
_PLACEHOLDER = "-"
|
||||
|
||||
_original_factory = logging.getLogRecordFactory()
|
||||
|
||||
|
||||
def _context_record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
|
||||
record = _original_factory(*args, **kwargs)
|
||||
record.request_id = request_id_var.get() or _PLACEHOLDER
|
||||
record.command = command_var.get() or _PLACEHOLDER
|
||||
record.chat_id = chat_id_var.get() or _PLACEHOLDER
|
||||
bid = bot_id_var.get()
|
||||
record.bot_id = str(bid) if bid is not None else _PLACEHOLDER
|
||||
record.dispatch_id = dispatch_id_var.get() or _PLACEHOLDER
|
||||
return record
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""Emit one JSON object per log record."""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
payload: dict[str, Any] = {
|
||||
"ts": self.formatTime(record, "%Y-%m-%dT%H:%M:%S") + f".{int(record.msecs):03d}",
|
||||
"level": record.levelname,
|
||||
"logger": record.name,
|
||||
"module": record.module,
|
||||
"line": record.lineno,
|
||||
"msg": record.getMessage(),
|
||||
}
|
||||
for field in _CONTEXT_FIELDS:
|
||||
val = getattr(record, field, None)
|
||||
if val and val != _PLACEHOLDER:
|
||||
payload[field] = val
|
||||
# Prefer the pre-masked exc_text cached by SecretMaskingFilter over
|
||||
# re-formatting from exc_info, which would bypass the mask.
|
||||
if record.exc_text:
|
||||
payload["exc"] = record.exc_text
|
||||
elif record.exc_info:
|
||||
payload["exc"] = self.formatException(record.exc_info)
|
||||
if record.stack_info:
|
||||
payload["stack"] = record.stack_info
|
||||
return json.dumps(payload, ensure_ascii=False, default=str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text formatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Keeps all context fields on one line so grep-by-field works. Empty values
|
||||
# are rendered as "-" by the record factory to avoid KeyError if a record
|
||||
# arrives without the filter.
|
||||
_TEXT_FORMAT = (
|
||||
"%(asctime)s %(levelname)-7s %(name)s:%(lineno)d "
|
||||
"[req=%(request_id)s cmd=%(command)s bot=%(bot_id)s chat=%(chat_id)s disp=%(dispatch_id)s] "
|
||||
"%(message)s"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Level override parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "NOTSET"}
|
||||
|
||||
|
||||
def parse_level_overrides(raw: str) -> dict[str, str]:
|
||||
"""Parse ``module=LEVEL,module2=LEVEL`` into a mapping of validated levels.
|
||||
|
||||
Invalid entries (bad format, unknown level) are silently dropped —
|
||||
a malformed env var or DB setting must not crash boot.
|
||||
"""
|
||||
result: dict[str, str] = {}
|
||||
for chunk in (raw or "").split(","):
|
||||
chunk = chunk.strip()
|
||||
if not chunk or "=" not in chunk:
|
||||
continue
|
||||
mod, _, lvl = chunk.partition("=")
|
||||
mod = mod.strip()
|
||||
lvl = lvl.strip().upper()
|
||||
if not mod or lvl not in _VALID_LEVELS:
|
||||
continue
|
||||
result[mod] = lvl
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_level(level: str | None, default: str = "INFO") -> str:
|
||||
if not level:
|
||||
return default
|
||||
up = level.strip().upper()
|
||||
return up if up in _VALID_LEVELS else default
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Setup + live apply
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Libraries we quiet by default — noisy at DEBUG and almost always irrelevant
|
||||
# to a service issue. Override via LOG_LEVELS=sqlalchemy.engine=DEBUG when
|
||||
# actually debugging.
|
||||
_NOISY_LIBRARY_DEFAULTS: dict[str, str] = {
|
||||
"sqlalchemy": "WARNING",
|
||||
"sqlalchemy.engine": "WARNING",
|
||||
"sqlalchemy.pool": "WARNING",
|
||||
"aiohttp": "WARNING",
|
||||
"aiohttp.access": "WARNING",
|
||||
"aiohttp.client": "WARNING",
|
||||
"aiohttp.server": "WARNING",
|
||||
"apscheduler": "WARNING",
|
||||
"apscheduler.scheduler": "WARNING",
|
||||
"apscheduler.executors.default": "WARNING",
|
||||
"urllib3": "WARNING",
|
||||
"asyncio": "WARNING",
|
||||
"httpx": "WARNING",
|
||||
"httpcore": "WARNING",
|
||||
"PIL": "WARNING",
|
||||
"uvicorn.access": "WARNING",
|
||||
}
|
||||
|
||||
|
||||
def setup_logging(
|
||||
*,
|
||||
level: str = "INFO",
|
||||
fmt: str = "text",
|
||||
per_module_levels: str = "",
|
||||
) -> None:
|
||||
"""Install the logging configuration. Safe to call more than once.
|
||||
|
||||
Args:
|
||||
level: Root log level (applied to ``notify_bridge_*`` loggers).
|
||||
fmt: ``"text"`` (default) or ``"json"``.
|
||||
per_module_levels: ``mod=LEVEL,mod2=LEVEL`` overrides. Wins over the
|
||||
root level for the listed loggers.
|
||||
"""
|
||||
root_level = _normalize_level(level, "INFO")
|
||||
overrides = parse_level_overrides(per_module_levels)
|
||||
|
||||
# Install the context-aware record factory (idempotent — setting the same
|
||||
# factory twice is fine because ``_original_factory`` is captured at
|
||||
# import time).
|
||||
logging.setLogRecordFactory(_context_record_factory)
|
||||
|
||||
if fmt == "json":
|
||||
formatters = {"default": {"()": f"{__name__}.JsonFormatter"}}
|
||||
else:
|
||||
formatters = {
|
||||
"default": {
|
||||
"format": _TEXT_FORMAT,
|
||||
"datefmt": "%Y-%m-%d %H:%M:%S",
|
||||
}
|
||||
}
|
||||
|
||||
# Start with noisy-library defaults, then layer user overrides on top so
|
||||
# the user can raise them to DEBUG when actually debugging.
|
||||
loggers: dict[str, dict[str, Any]] = {}
|
||||
for mod, lvl in _NOISY_LIBRARY_DEFAULTS.items():
|
||||
loggers[mod] = {"level": lvl, "propagate": True}
|
||||
# App loggers follow the root level unless overridden.
|
||||
loggers["notify_bridge_server"] = {"level": root_level, "propagate": True}
|
||||
loggers["notify_bridge_core"] = {"level": root_level, "propagate": True}
|
||||
# User overrides win.
|
||||
for mod, lvl in overrides.items():
|
||||
loggers[mod] = {"level": lvl, "propagate": True}
|
||||
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"filters": {
|
||||
"mask_secrets": {"()": f"{__name__}.SecretMaskingFilter"},
|
||||
},
|
||||
"formatters": formatters,
|
||||
"handlers": {
|
||||
"stderr": {
|
||||
"class": "logging.StreamHandler",
|
||||
"stream": sys.stderr,
|
||||
"formatter": "default",
|
||||
"filters": ["mask_secrets"],
|
||||
},
|
||||
},
|
||||
"root": {
|
||||
"level": root_level,
|
||||
"handlers": ["stderr"],
|
||||
},
|
||||
"loggers": loggers,
|
||||
}
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
|
||||
def apply_log_levels(
|
||||
*,
|
||||
level: str | None,
|
||||
per_module_levels: str | None,
|
||||
) -> None:
|
||||
"""Update existing logger levels in-place without rebuilding handlers.
|
||||
|
||||
Called when an admin changes the log settings at runtime. Setting
|
||||
``level`` to None leaves the root untouched; setting it to a valid
|
||||
level applies to ``notify_bridge_server`` / ``notify_bridge_core``.
|
||||
|
||||
``per_module_levels`` is treated as an exclusive set — loggers that
|
||||
previously had an override but aren't in the new string are reset
|
||||
*toward* the root level so a removed override actually takes effect.
|
||||
"""
|
||||
if level:
|
||||
lvl = _normalize_level(level, "INFO")
|
||||
logging.getLogger("notify_bridge_server").setLevel(lvl)
|
||||
logging.getLogger("notify_bridge_core").setLevel(lvl)
|
||||
# NOTSET on root is almost never what you want — keep root where it is
|
||||
# unless the caller explicitly set something.
|
||||
logging.getLogger().setLevel(lvl)
|
||||
|
||||
if per_module_levels is not None:
|
||||
overrides = parse_level_overrides(per_module_levels)
|
||||
# Apply new overrides
|
||||
for mod, lvl in overrides.items():
|
||||
logging.getLogger(mod).setLevel(lvl)
|
||||
# Reset noisy libs that aren't in the new overrides back to defaults
|
||||
for mod, default_lvl in _NOISY_LIBRARY_DEFAULTS.items():
|
||||
if mod not in overrides:
|
||||
logging.getLogger(mod).setLevel(default_lvl)
|
||||
@@ -9,13 +9,18 @@ from slowapi import _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
# Ensure app-level loggers are visible
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
from .config import settings as _log_cfg
|
||||
_log_level = logging.DEBUG if _log_cfg.debug else logging.INFO
|
||||
logging.getLogger("notify_bridge_server").setLevel(_log_level)
|
||||
logging.getLogger("notify_bridge_core").setLevel(_log_level)
|
||||
from .logging_setup import setup_logging
|
||||
|
||||
# Boot logging from env-based config. DB-backed AppSetting rows (``log_level`` /
|
||||
# ``log_levels`` / ``log_format``) override this after migrations — see the
|
||||
# lifespan block below.
|
||||
setup_logging(
|
||||
level="DEBUG" if _log_cfg.debug else _log_cfg.log_level,
|
||||
fmt=_log_cfg.log_format,
|
||||
per_module_levels=_log_cfg.log_levels,
|
||||
)
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
from .database.engine import init_db
|
||||
from .database.models import * # noqa: F401,F403 — ensure all models registered
|
||||
@@ -47,13 +52,41 @@ from .api.webhook_logs import router as webhook_logs_router
|
||||
from .api.backup import router as backup_router
|
||||
|
||||
|
||||
# Readiness flag — flipped to True once the scheduler has started and the
|
||||
# app is fully initialized. Exposed via /api/ready for orchestrators.
|
||||
_READY: bool = False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global _READY
|
||||
await init_db()
|
||||
# Run data migrations (idempotent)
|
||||
from .database.engine import get_engine
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor, migrate_template_slots, migrate_target_receivers, migrate_template_locale, migrate_receivers_from_config, migrate_command_slot_locale, migrate_notification_slot_locale, migrate_user_token_version
|
||||
from .database.migrations import (
|
||||
migrate_schema,
|
||||
migrate_tracker_targets,
|
||||
migrate_entity_refactor,
|
||||
migrate_template_slots,
|
||||
migrate_target_receivers,
|
||||
migrate_template_locale,
|
||||
migrate_receivers_from_config,
|
||||
migrate_command_slot_locale,
|
||||
migrate_notification_slot_locale,
|
||||
migrate_user_token_version,
|
||||
migrate_performance_indexes,
|
||||
migrate_schema_version,
|
||||
)
|
||||
from .database.snapshot import snapshot_and_prune
|
||||
engine = get_engine()
|
||||
# Take a consistent DB snapshot BEFORE migrations run, so operators can
|
||||
# roll back a bad upgrade by restoring one file. Best-effort — failures
|
||||
# are logged, not raised.
|
||||
await snapshot_and_prune(
|
||||
engine,
|
||||
_log_cfg.data_dir / "backups",
|
||||
keep=_log_cfg.pre_migrate_snapshot_keep,
|
||||
)
|
||||
await migrate_schema(engine)
|
||||
await migrate_tracker_targets(engine)
|
||||
await migrate_entity_refactor(engine)
|
||||
@@ -64,8 +97,28 @@ async def lifespan(app: FastAPI):
|
||||
await migrate_command_slot_locale(engine)
|
||||
await migrate_notification_slot_locale(engine)
|
||||
await migrate_user_token_version(engine)
|
||||
await migrate_performance_indexes(engine)
|
||||
await migrate_schema_version(engine)
|
||||
from .database.seeds import seed_all
|
||||
await seed_all()
|
||||
# Apply DB-backed logging settings (override env-based boot config).
|
||||
# log_format still needs a restart — changing it means swapping the
|
||||
# handler formatter entirely.
|
||||
try:
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession as _AS_log
|
||||
from .api.app_settings import get_setting as _get_log_setting
|
||||
from .logging_setup import apply_log_levels
|
||||
async with _AS_log(engine) as _log_session:
|
||||
db_level = await _get_log_setting(_log_session, "log_level")
|
||||
db_levels = await _get_log_setting(_log_session, "log_levels")
|
||||
apply_log_levels(level=db_level or None, per_module_levels=db_levels)
|
||||
_LOGGER.info(
|
||||
"Logging initialized: level=%s overrides=%r format=%s",
|
||||
db_level or _log_cfg.log_level, db_levels or _log_cfg.log_levels,
|
||||
_log_cfg.log_format,
|
||||
)
|
||||
except Exception: # pragma: no cover — never let logging setup abort boot
|
||||
_LOGGER.exception("Failed to apply DB-backed log settings; keeping env-based levels")
|
||||
# Apply any pending restore staged via /api/backup/prepare-restore
|
||||
from .services.pending_restore import apply_pending_restore_if_any
|
||||
await apply_pending_restore_if_any()
|
||||
@@ -77,16 +130,28 @@ async def lifespan(app: FastAPI):
|
||||
set_webhook_secret(_secret or None)
|
||||
from .services.scheduler import start_scheduler, get_scheduler
|
||||
await start_scheduler()
|
||||
_READY = True
|
||||
yield
|
||||
# Graceful shutdown
|
||||
from .services.http_session import close_http_session
|
||||
await close_http_session()
|
||||
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
|
||||
# before we close their HTTP session. Then close the shared session and
|
||||
# dispose the DB engine.
|
||||
_READY = False
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.running:
|
||||
scheduler.shutdown()
|
||||
scheduler.shutdown(wait=True)
|
||||
from .services.http_session import close_http_session
|
||||
await close_http_session()
|
||||
from .database.engine import dispose_engine
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
|
||||
try:
|
||||
from importlib.metadata import version as _pkg_version
|
||||
_APP_VERSION = _pkg_version("notify-bridge-server")
|
||||
except Exception: # pragma: no cover — editable install edge cases
|
||||
_APP_VERSION = "0.0.0+unknown"
|
||||
|
||||
app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan)
|
||||
|
||||
# --- Security headers ---
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -94,6 +159,24 @@ from starlette.requests import Request as StarletteRequest
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
|
||||
_CSP = (
|
||||
"default-src 'self'; "
|
||||
"img-src 'self' data: blob: https:; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
# SvelteKit's static adapter emits an inline bootstrap <script> with the
|
||||
# hydration payload, so 'self' alone blocks the SPA from starting.
|
||||
# 'unsafe-inline' re-enables it; the app's primary XSS protection still
|
||||
# comes from Svelte's template auto-escaping and frontend/sanitize.ts
|
||||
# for the few {@html} paths that render user-controlled content.
|
||||
"script-src 'self' 'unsafe-inline'; "
|
||||
"connect-src 'self'; "
|
||||
"font-src 'self' data:; "
|
||||
"base-uri 'self'; "
|
||||
"form-action 'self'; "
|
||||
"frame-ancestors 'none'"
|
||||
)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: StarletteRequest, call_next):
|
||||
response: StarletteResponse = await call_next(request)
|
||||
@@ -101,6 +184,14 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers.setdefault("Content-Security-Policy", _CSP)
|
||||
# HSTS only makes sense over HTTPS; set when the edge terminates TLS
|
||||
# and forwards X-Forwarded-Proto=https.
|
||||
if request.headers.get("x-forwarded-proto") == "https":
|
||||
response.headers.setdefault(
|
||||
"Strict-Transport-Security",
|
||||
"max-age=31536000; includeSubDomains",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@@ -153,7 +244,22 @@ app.include_router(backup_router)
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
"""Liveness: process is up and responding. Always returns 200 once the
|
||||
ASGI app has started. Keep this endpoint anonymous and trivially cheap."""
|
||||
return {"status": "ok", "version": _APP_VERSION}
|
||||
|
||||
|
||||
@app.get("/api/ready")
|
||||
async def ready():
|
||||
"""Readiness: migrations and scheduler have started, app can serve traffic.
|
||||
|
||||
Returns 503 until the lifespan startup sequence has completed. Use this
|
||||
for orchestrator readiness probes (Docker, Kubernetes).
|
||||
"""
|
||||
if not _READY:
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse({"status": "starting"}, status_code=503)
|
||||
return {"status": "ready", "version": _APP_VERSION}
|
||||
|
||||
|
||||
# --- Serve frontend static files (production) ---
|
||||
@@ -186,4 +292,12 @@ if _cfg.static_dir and Path(_cfg.static_dir).is_dir():
|
||||
|
||||
def run():
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=_cfg.host, port=_cfg.port)
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=_cfg.host,
|
||||
port=_cfg.port,
|
||||
proxy_headers=True,
|
||||
forwarded_allow_ips=_cfg.forwarded_allow_ips or "127.0.0.1",
|
||||
timeout_graceful_shutdown=_cfg.graceful_shutdown_seconds,
|
||||
access_log=not _cfg.debug,
|
||||
)
|
||||
|
||||
@@ -387,10 +387,9 @@ async def export_backup_to_file(
|
||||
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
|
||||
filename = f"backup-{ts}.json"
|
||||
filepath = backup_dir / filename
|
||||
filepath.write_text(
|
||||
json.dumps(backup.model_dump(), indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
import asyncio as _asyncio
|
||||
payload = json.dumps(backup.model_dump(), indent=2, ensure_ascii=False)
|
||||
await _asyncio.to_thread(filepath.write_text, payload, encoding="utf-8")
|
||||
_LOGGER.info("Scheduled backup saved: %s", filepath)
|
||||
return filepath
|
||||
|
||||
@@ -399,7 +398,13 @@ def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
|
||||
"""Delete oldest backup files exceeding `keep` count. Returns deleted filenames."""
|
||||
if not backup_dir.is_dir():
|
||||
return []
|
||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
||||
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||
# timestamp format, which could change later without updating this code.
|
||||
files = sorted(
|
||||
backup_dir.glob("backup-*.json"),
|
||||
key=lambda f: f.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
deleted = []
|
||||
for old in files[keep:]:
|
||||
old.unlink()
|
||||
@@ -413,7 +418,13 @@ def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
|
||||
"""List backup files in the directory with metadata."""
|
||||
if not backup_dir.is_dir():
|
||||
return []
|
||||
files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
|
||||
# Sort by mtime (newest first) so behavior doesn't depend on the filename
|
||||
# timestamp format, which could change later without updating this code.
|
||||
files = sorted(
|
||||
backup_dir.glob("backup-*.json"),
|
||||
key=lambda f: f.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for f in files:
|
||||
stat = f.stat()
|
||||
|
||||
@@ -11,23 +11,36 @@ Call ``close_http_session()`` once during application shutdown.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
|
||||
|
||||
_session: aiohttp.ClientSession | None = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_http_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared HTTP session."""
|
||||
"""Get or create the shared HTTP session.
|
||||
|
||||
Concurrent first-callers are serialized through ``_lock`` so we never
|
||||
leak a second ClientSession / connector pair. Once established, hot
|
||||
callers skip the lock via the fast-path check.
|
||||
"""
|
||||
global _session
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
if _session is not None and not _session.closed:
|
||||
return _session
|
||||
async with _lock:
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_http_session() -> None:
|
||||
"""Close the shared HTTP session (call on app shutdown)."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
async with _lock:
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
|
||||
@@ -188,6 +188,7 @@ _SAMPLE_CONTEXT = {
|
||||
"current_time": "09:00",
|
||||
"current_datetime": "22.03.2026, 09:00 UTC",
|
||||
"weekday": "Monday",
|
||||
"timezone": "UTC",
|
||||
"custom_vars": {"team": "Engineering", "message": "Time for standup!"},
|
||||
"team": "Engineering",
|
||||
"message": "Time for standup!",
|
||||
|
||||
@@ -3,11 +3,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
|
||||
"""Resolve an IANA tz string to a ZoneInfo, falling back to UTC on any error.
|
||||
|
||||
Kept local to avoid importing from api/dispatch layers inside the scheduler
|
||||
module (which is loaded at startup, before the API routers).
|
||||
"""
|
||||
if not tz_name:
|
||||
return ZoneInfo("UTC")
|
||||
try:
|
||||
return ZoneInfo(tz_name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
_LOGGER.warning("Unknown timezone %r; falling back to UTC", tz_name)
|
||||
return ZoneInfo("UTC")
|
||||
|
||||
|
||||
async def _load_app_timezone() -> ZoneInfo:
|
||||
"""Load the admin-configured app timezone from AppSetting (falls back to UTC)."""
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..api.app_settings import get_setting
|
||||
from ..database.engine import get_engine
|
||||
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
tz_name = await get_setting(session, "timezone")
|
||||
return _resolve_zoneinfo(tz_name)
|
||||
|
||||
_scheduler: AsyncIOScheduler | None = None
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -57,7 +85,21 @@ def _compute_jitter(interval_seconds: int) -> int:
|
||||
def get_scheduler() -> AsyncIOScheduler:
|
||||
global _scheduler
|
||||
if _scheduler is None:
|
||||
_scheduler = AsyncIOScheduler()
|
||||
# Sensible production defaults applied to every job unless overridden:
|
||||
# * coalesce — collapse a queue of missed runs into one firing after
|
||||
# a restart / pause, instead of bursting to catch up.
|
||||
# * misfire_grace_time — accept firings up to 5 min late without
|
||||
# dropping them silently.
|
||||
# * max_instances=1 — never run two copies of the same tracker tick
|
||||
# concurrently; the scheduler already enforces this on add_job,
|
||||
# but we also set it as the default for safety.
|
||||
_scheduler = AsyncIOScheduler(
|
||||
job_defaults={
|
||||
"coalesce": True,
|
||||
"misfire_grace_time": 300,
|
||||
"max_instances": 1,
|
||||
},
|
||||
)
|
||||
return _scheduler
|
||||
|
||||
|
||||
@@ -251,21 +293,38 @@ async def _refresh_telegram_chat_titles() -> None:
|
||||
|
||||
|
||||
async def _cleanup_old_events() -> None:
|
||||
"""Delete EventLog entries older than 90 days."""
|
||||
"""Delete EventLog / WebhookPayloadLog / ActionExecution rows older than the
|
||||
configured retention window. A retention of 0 disables the job.
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlmodel import delete
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..config import settings
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import EventLog
|
||||
from ..database.models import ActionExecution, EventLog, WebhookPayloadLog
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
days = settings.event_log_retention_days
|
||||
if days <= 0:
|
||||
_LOGGER.debug("Event log retention disabled (days=0); skipping cleanup")
|
||||
return
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=days)
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
|
||||
await session.exec(
|
||||
delete(WebhookPayloadLog).where(WebhookPayloadLog.created_at < cutoff)
|
||||
)
|
||||
await session.exec(
|
||||
delete(ActionExecution).where(ActionExecution.started_at < cutoff)
|
||||
)
|
||||
await session.commit()
|
||||
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
|
||||
_LOGGER.info(
|
||||
"Cleaned event_log / webhook_payload_log / action_execution older than %s",
|
||||
cutoff.date(),
|
||||
)
|
||||
|
||||
|
||||
async def _load_tracker_jobs() -> None:
|
||||
@@ -293,6 +352,8 @@ async def _load_tracker_jobs() -> None:
|
||||
)
|
||||
provider_types = {p.id: p.type for p in provider_result.all()}
|
||||
|
||||
tz = await _load_app_timezone()
|
||||
|
||||
for tracker in trackers:
|
||||
job_id = f"tracker_{tracker.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -306,7 +367,7 @@ async def _load_tracker_jobs() -> None:
|
||||
cron_expr = filters.get("cron_expression", "")
|
||||
if cron_expr:
|
||||
try:
|
||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name)
|
||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
|
||||
continue
|
||||
except Exception as e:
|
||||
_LOGGER.error(
|
||||
@@ -337,10 +398,18 @@ def _add_cron_job(
|
||||
tracker_id: int,
|
||||
cron_expression: str,
|
||||
tracker_name: str,
|
||||
tz: ZoneInfo,
|
||||
) -> None:
|
||||
"""Add a cron-triggered job for a scheduler-type tracker."""
|
||||
"""Add a cron-triggered job for a scheduler-type tracker.
|
||||
|
||||
``tz`` is the user-configured app timezone; without it APScheduler
|
||||
interprets the crontab in the host's local timezone, which surfaces as
|
||||
events firing at the "wrong" wall-clock time for operators in a non-UTC
|
||||
zone (see the companion fix in ``update_settings`` which reschedules on
|
||||
timezone changes).
|
||||
"""
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(cron_expression)
|
||||
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_poll_tracker,
|
||||
trigger,
|
||||
@@ -349,7 +418,10 @@ def _add_cron_job(
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
|
||||
_LOGGER.info(
|
||||
"Scheduled tracker %d (%s) with cron: %s [tz=%s]",
|
||||
tracker_id, tracker_name, cron_expression, tz.key,
|
||||
)
|
||||
|
||||
|
||||
async def schedule_tracker(
|
||||
@@ -371,7 +443,8 @@ async def schedule_tracker(
|
||||
|
||||
if cron_expression:
|
||||
try:
|
||||
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}")
|
||||
tz = await _load_app_timezone()
|
||||
_add_cron_job(scheduler, job_id, tracker_id, cron_expression, f"tracker-{tracker_id}", tz)
|
||||
return
|
||||
except Exception as e:
|
||||
_LOGGER.error("Invalid cron for tracker %d: %s — using interval", tracker_id, e)
|
||||
@@ -506,6 +579,8 @@ async def _load_action_jobs() -> None:
|
||||
)
|
||||
actions = result.all()
|
||||
|
||||
tz = await _load_app_timezone()
|
||||
|
||||
for action in actions:
|
||||
job_id = f"action_{action.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -514,7 +589,7 @@ async def _load_action_jobs() -> None:
|
||||
if action.schedule_type == "cron" and action.schedule_cron:
|
||||
try:
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(action.schedule_cron)
|
||||
trigger = CronTrigger.from_crontab(action.schedule_cron, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_run_action,
|
||||
trigger,
|
||||
@@ -523,8 +598,8 @@ async def _load_action_jobs() -> None:
|
||||
replace_existing=True,
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Scheduled action %d (%s) with cron: %s",
|
||||
action.id, action.name, action.schedule_cron,
|
||||
"Scheduled action %d (%s) with cron: %s [tz=%s]",
|
||||
action.id, action.name, action.schedule_cron, tz.key,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
@@ -563,7 +638,8 @@ async def schedule_action(
|
||||
if schedule_type == "cron" and cron_expression:
|
||||
try:
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
trigger = CronTrigger.from_crontab(cron_expression)
|
||||
tz = await _load_app_timezone()
|
||||
trigger = CronTrigger.from_crontab(cron_expression, timezone=tz)
|
||||
scheduler.add_job(
|
||||
_run_action,
|
||||
trigger,
|
||||
@@ -571,7 +647,10 @@ async def schedule_action(
|
||||
args=[action_id],
|
||||
replace_existing=True,
|
||||
)
|
||||
_LOGGER.info("Scheduled action %d with cron: %s", action_id, cron_expression)
|
||||
_LOGGER.info(
|
||||
"Scheduled action %d with cron: %s [tz=%s]",
|
||||
action_id, cron_expression, tz.key,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
_LOGGER.error("Invalid cron for action %d: %s — using interval", action_id, e)
|
||||
@@ -596,6 +675,92 @@ async def unschedule_action(action_id: int) -> None:
|
||||
_LOGGER.info("Unscheduled action %d", action_id)
|
||||
|
||||
|
||||
async def reschedule_cron_jobs_for_timezone_change() -> None:
|
||||
"""Re-add every cron-triggered tracker/action job under the new app timezone.
|
||||
|
||||
Called by the admin settings endpoint after the ``timezone`` AppSetting is
|
||||
updated. APScheduler's ``CronTrigger`` freezes its timezone at construction
|
||||
time, so a timezone change has no effect on jobs already in the scheduler
|
||||
— we have to rebuild those jobs. Interval-triggered jobs are tz-agnostic
|
||||
and are left alone.
|
||||
"""
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import Action, NotificationTracker, ServiceProvider as ServiceProviderModel
|
||||
|
||||
engine = get_engine()
|
||||
scheduler = get_scheduler()
|
||||
tz = await _load_app_timezone()
|
||||
rescheduled = 0
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
# Trackers with cron scheduling (scheduler provider + schedule_type=cron).
|
||||
trackers = (await session.exec(
|
||||
select(NotificationTracker).where(NotificationTracker.enabled == True) # noqa: E712
|
||||
)).all()
|
||||
provider_ids = list({t.provider_id for t in trackers})
|
||||
provider_types: dict[int, str] = {}
|
||||
if provider_ids:
|
||||
rows = await session.exec(
|
||||
select(ServiceProviderModel).where(ServiceProviderModel.id.in_(provider_ids))
|
||||
)
|
||||
provider_types = {p.id: p.type for p in rows.all()}
|
||||
|
||||
for tracker in trackers:
|
||||
if provider_types.get(tracker.provider_id) != "scheduler":
|
||||
continue
|
||||
filters = tracker.filters or {}
|
||||
if filters.get("schedule_type") != "cron":
|
||||
continue
|
||||
cron_expr = filters.get("cron_expression", "")
|
||||
if not cron_expr:
|
||||
continue
|
||||
job_id = f"tracker_{tracker.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
scheduler.remove_job(job_id)
|
||||
try:
|
||||
_add_cron_job(scheduler, job_id, tracker.id, cron_expr, tracker.name, tz)
|
||||
rescheduled += 1
|
||||
except Exception as e: # noqa: BLE001
|
||||
_LOGGER.error(
|
||||
"Failed to re-apply cron for tracker %d on tz change: %s",
|
||||
tracker.id, e,
|
||||
)
|
||||
|
||||
# Actions with cron schedules.
|
||||
actions = (await session.exec(
|
||||
select(Action).where(Action.enabled == True) # noqa: E712
|
||||
)).all()
|
||||
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
for action in actions:
|
||||
if action.schedule_type != "cron" or not action.schedule_cron:
|
||||
continue
|
||||
job_id = f"action_{action.id}"
|
||||
if scheduler.get_job(job_id):
|
||||
scheduler.remove_job(job_id)
|
||||
try:
|
||||
scheduler.add_job(
|
||||
_run_action,
|
||||
CronTrigger.from_crontab(action.schedule_cron, timezone=tz),
|
||||
id=job_id,
|
||||
args=[action.id],
|
||||
replace_existing=True,
|
||||
)
|
||||
rescheduled += 1
|
||||
except Exception as e: # noqa: BLE001
|
||||
_LOGGER.error(
|
||||
"Failed to re-apply cron for action %d on tz change: %s",
|
||||
action.id, e,
|
||||
)
|
||||
|
||||
_LOGGER.info(
|
||||
"Rescheduled %d cron job(s) for new app timezone %s", rescheduled, tz.key,
|
||||
)
|
||||
|
||||
|
||||
async def _run_action(action_id: int) -> None:
|
||||
"""Run an action (called by APScheduler)."""
|
||||
from .action_runner import run_action
|
||||
|
||||
@@ -11,11 +11,13 @@ CommandTrackerListeners with enabled CommandTrackers.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.log_context import bind_log_context
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..database.engine import get_engine
|
||||
@@ -125,7 +127,14 @@ async def stop_bot_if_unused(bot_id: int) -> None:
|
||||
|
||||
|
||||
def schedule_bot_polling(bot_id: int) -> None:
|
||||
"""Add a polling job for a bot (idempotent)."""
|
||||
"""Add a polling job for a bot (idempotent).
|
||||
|
||||
We schedule at a 30 s interval, but each tick calls ``getUpdates`` with
|
||||
``timeout=25`` — Telegram holds the connection open until either an
|
||||
update arrives or the timeout elapses, so in practice the bot streams
|
||||
updates with sub-second latency while consuming ~2 API calls / minute
|
||||
per bot (down from 20 under the old 3 s short-poll).
|
||||
"""
|
||||
scheduler = get_scheduler()
|
||||
job_id = f"telegram_poll_{bot_id}"
|
||||
if scheduler.get_job(job_id):
|
||||
@@ -133,13 +142,13 @@ def schedule_bot_polling(bot_id: int) -> None:
|
||||
scheduler.add_job(
|
||||
_poll_bot,
|
||||
"interval",
|
||||
seconds=3,
|
||||
seconds=30,
|
||||
id=job_id,
|
||||
args=[bot_id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Started polling for bot %d", bot_id)
|
||||
_LOGGER.info("Started polling for bot %d (long-poll, 25s timeout)", bot_id)
|
||||
|
||||
|
||||
def unschedule_bot_polling(bot_id: int) -> None:
|
||||
@@ -231,8 +240,10 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
# Long-poll: hold connection open until an update arrives or 25 s
|
||||
# elapse. Drastically cuts API calls vs. 3 s short-poll.
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
offset=offset + 1 if offset else None, limit=50, timeout=25,
|
||||
)
|
||||
if not result.get("success"):
|
||||
err_text = str(result.get("error") or "")
|
||||
@@ -289,29 +300,64 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
|
||||
# Dispatch commands (only if chat has commands enabled)
|
||||
if text and text.startswith("/"):
|
||||
try:
|
||||
async with AsyncSession(engine) as cmd_session:
|
||||
chat_row = (await cmd_session.exec(
|
||||
select(TelegramChat).where(
|
||||
TelegramChat.bot_id == bot_obj.id,
|
||||
TelegramChat.chat_id == chat_id,
|
||||
from ..commands.parser import parse_command
|
||||
cmd_name, _, _ = parse_command(text)
|
||||
update_id = update.get("update_id")
|
||||
message_id = message.get("message_id")
|
||||
request_id = f"tg:{update_id}" if update_id is not None else f"tg:msg{message_id}"
|
||||
with bind_log_context(
|
||||
request_id=request_id,
|
||||
command=cmd_name or "-",
|
||||
chat_id=chat_id,
|
||||
bot_id=bot_obj.id,
|
||||
):
|
||||
started = time.monotonic()
|
||||
try:
|
||||
async with AsyncSession(engine) as cmd_session:
|
||||
chat_row = (await cmd_session.exec(
|
||||
select(TelegramChat).where(
|
||||
TelegramChat.bot_id == bot_obj.id,
|
||||
TelegramChat.chat_id == chat_id,
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
_LOGGER.info(
|
||||
"Command ignored — commands disabled (poll) for bot=%s chat=%s",
|
||||
bot_obj.id, chat_id,
|
||||
)
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
_LOGGER.info("Command received (poll): /%s args=%r lang=%s", cmd_name, text[:200], effective_lang)
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
):
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if not responses:
|
||||
_LOGGER.info(
|
||||
"Command produced no response (cmd=%r, poll) after %.0f ms",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
continue
|
||||
text_count = sum(1 for r in responses if r.text)
|
||||
media_count = sum(len(r.media or []) for r in responses)
|
||||
_LOGGER.info(
|
||||
"Command dispatching %d response(s): text=%d media_items=%d",
|
||||
len(responses), text_count, media_count,
|
||||
)
|
||||
)).first()
|
||||
if not chat_row or not chat_row.commands_enabled:
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
async with telegram_chat_action(
|
||||
bot_token, chat_id, classify_command_chat_action(text),
|
||||
):
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
except Exception:
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
_LOGGER.info(
|
||||
"Command /%s completed in %.0f ms (responses=%d media=%d)",
|
||||
cmd_name, (time.monotonic() - started) * 1000,
|
||||
len(responses), media_count,
|
||||
)
|
||||
except Exception:
|
||||
_LOGGER.exception(
|
||||
"Error handling command /%s from bot %d after %.0f ms",
|
||||
cmd_name, bot_id, (time.monotonic() - started) * 1000,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -246,6 +246,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
name=provider_name,
|
||||
tracker_name=tracker_name,
|
||||
custom_variables=custom_vars,
|
||||
timezone_name=app_tz,
|
||||
)
|
||||
events, new_state = await sched.poll(collection_ids, state_dict)
|
||||
elif provider_type == "nut":
|
||||
@@ -317,6 +318,26 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
for event in events:
|
||||
assets_count = event.added_count or event.removed_count or 0
|
||||
details: dict[str, Any] = {
|
||||
"added_count": event.added_count,
|
||||
"removed_count": event.removed_count,
|
||||
"provider_type": event.provider_type.value,
|
||||
}
|
||||
# Scheduler/periodic events carry the schedule context in ``extra``
|
||||
# (cron expression, interval, timezone, fire count). Surface that
|
||||
# in the event log so the dashboard and audit queries can show
|
||||
# *why* the event fired, not just that it did.
|
||||
if event.event_type.value == "scheduled_message":
|
||||
sched_type = tracker_filters.get("schedule_type", "interval")
|
||||
details["schedule_type"] = sched_type
|
||||
if sched_type == "cron":
|
||||
details["cron_expression"] = tracker_filters.get("cron_expression", "")
|
||||
else:
|
||||
details["interval_seconds"] = tracker.scan_interval
|
||||
details["timezone"] = app_tz
|
||||
fire_count = event.extra.get("fire_count") if event.extra else None
|
||||
if fire_count is not None:
|
||||
details["fire_count"] = fire_count
|
||||
log = EventLog(
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker_id,
|
||||
@@ -327,11 +348,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
collection_id=event.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=assets_count,
|
||||
details={
|
||||
"added_count": event.added_count,
|
||||
"removed_count": event.removed_count,
|
||||
"provider_type": event.provider_type.value,
|
||||
},
|
||||
details=details,
|
||||
)
|
||||
session.add(log)
|
||||
|
||||
@@ -352,7 +369,13 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
if events and link_data:
|
||||
url_cache, asset_cache = await _get_telegram_caches()
|
||||
dispatcher = NotificationDispatcher(url_cache=url_cache, asset_cache=asset_cache)
|
||||
from .http_session import get_http_session
|
||||
shared_session = await get_http_session()
|
||||
dispatcher = NotificationDispatcher(
|
||||
url_cache=url_cache,
|
||||
asset_cache=asset_cache,
|
||||
session=shared_session,
|
||||
)
|
||||
for event in events:
|
||||
_LOGGER.info(
|
||||
"Dispatching event %s for %s (added=%d removed=%d)",
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Shared pytest fixtures.
|
||||
|
||||
We set the required env vars before any ``notify_bridge_server`` module is
|
||||
imported so ``Settings()`` passes its startup validation and opens the DB
|
||||
in a writable temp directory instead of the production ``/data`` default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Provision a writable temp data dir BEFORE the server package is imported —
|
||||
# Settings() materializes at import time, so env-var overrides have to land
|
||||
# here (conftest.py) to be effective.
|
||||
_TMP = Path(tempfile.mkdtemp(prefix="notify-bridge-tests-"))
|
||||
os.environ["NOTIFY_BRIDGE_DATA_DIR"] = str(_TMP)
|
||||
|
||||
os.environ.setdefault(
|
||||
"NOTIFY_BRIDGE_SECRET_KEY",
|
||||
"pytest-secret-key-" + "x" * 40,
|
||||
)
|
||||
os.environ.setdefault("NOTIFY_BRIDGE_DEBUG", "false")
|
||||
os.environ.setdefault("NOTIFY_BRIDGE_CORS_ALLOWED_ORIGINS", "http://localhost:8420")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tmp_data_dir() -> Path:
|
||||
"""Expose the already-provisioned temp dir to tests that want the path."""
|
||||
return _TMP
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Unit tests for the Settings validator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_server.config import Settings, _FORBIDDEN_SECRETS
|
||||
|
||||
|
||||
_GOOD = "a" * 40
|
||||
|
||||
|
||||
def _make(**kwargs) -> Settings:
|
||||
defaults = dict(secret_key=_GOOD, cors_allowed_origins="http://localhost:8420")
|
||||
defaults.update(kwargs)
|
||||
return Settings(**defaults)
|
||||
|
||||
|
||||
class TestSecretKey:
|
||||
def test_accepts_long_random_key(self) -> None:
|
||||
_make()
|
||||
|
||||
def test_rejects_default(self) -> None:
|
||||
with pytest.raises(ValueError, match="SECURITY"):
|
||||
_make(secret_key="change-me-in-production")
|
||||
|
||||
def test_rejects_known_dev_keys(self) -> None:
|
||||
for bad in _FORBIDDEN_SECRETS:
|
||||
with pytest.raises(ValueError):
|
||||
_make(secret_key=bad)
|
||||
|
||||
def test_rejects_short_key(self) -> None:
|
||||
with pytest.raises(ValueError, match="32 characters"):
|
||||
_make(secret_key="short")
|
||||
|
||||
|
||||
class TestCors:
|
||||
def test_rejects_wildcard(self) -> None:
|
||||
with pytest.raises(ValueError, match="wildcard"):
|
||||
_make(cors_allowed_origins="*")
|
||||
|
||||
def test_rejects_missing_scheme(self) -> None:
|
||||
with pytest.raises(ValueError, match="scheme"):
|
||||
_make(cors_allowed_origins="example.com")
|
||||
|
||||
def test_accepts_multiple(self) -> None:
|
||||
cfg = _make(cors_allowed_origins="http://localhost:8420,https://example.com")
|
||||
assert "http://localhost:8420" in cfg.cors_allowed_origins
|
||||
|
||||
|
||||
class TestNumericValidation:
|
||||
def test_rejects_zero_access_token_expiry(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_make(access_token_expire_minutes=0)
|
||||
|
||||
def test_rejects_invalid_port(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_make(port=0)
|
||||
with pytest.raises(ValueError):
|
||||
_make(port=70000)
|
||||
|
||||
def test_rejects_negative_retention(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
_make(event_log_retention_days=-1)
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Discord client 429-retry bounding."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from aioresponses import aioresponses
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
|
||||
|
||||
WEBHOOK = "https://discord.com/api/webhooks/123/abc"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bounded_retries_on_persistent_429() -> None:
|
||||
"""If every response is 429, the client gives up after _MAX_RETRIES."""
|
||||
with aioresponses() as mocked:
|
||||
mocked.post(WEBHOOK, status=429, headers={"Retry-After": "0.001"}, repeat=True)
|
||||
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
client = DiscordClient(sess)
|
||||
result = await client.send(WEBHOOK, "hello")
|
||||
|
||||
assert result["success"] is False
|
||||
# Either the custom "Rate limited" message or the bare HTTP 429 from the
|
||||
# final attempt — both indicate bounded retries without infinite recursion.
|
||||
assert "429" in result["error"] or "Rate limited" in result["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_caps_retry_after() -> None:
|
||||
"""A malicious Retry-After: 99999 must not pin the task for hours."""
|
||||
with aioresponses() as mocked:
|
||||
# First call: absurd Retry-After. Second call: success.
|
||||
mocked.post(WEBHOOK, status=429, headers={"Retry-After": "99999"})
|
||||
mocked.post(WEBHOOK, status=204)
|
||||
|
||||
async with aiohttp.ClientSession() as sess:
|
||||
client = DiscordClient(sess)
|
||||
# Override the cap to something trivial so the test completes fast.
|
||||
client._MAX_RETRY_AFTER = 0.001 # type: ignore[attr-defined]
|
||||
result = await client.send(WEBHOOK, "hello")
|
||||
|
||||
assert result["success"] is True
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Smoke test: app imports, /api/health returns 200, version string present."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_health_endpoint(tmp_data_dir) -> None: # noqa: ARG001 — fixture applies env
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
# TestClient runs the lifespan on enter/exit, so migrations run once
|
||||
# against the temp data dir — a genuine integration smoke check.
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "ok"
|
||||
assert body["version"] != "0.0.0+unknown"
|
||||
assert "." in body["version"] # looks like a real version
|
||||
|
||||
|
||||
def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/ready")
|
||||
# By the time TestClient yields, lifespan startup has completed.
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ready"
|
||||
|
||||
|
||||
def test_health_is_anonymous(tmp_data_dir) -> None: # noqa: ARG001
|
||||
"""/api/health must not require auth — the Docker healthcheck depends on it."""
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
@@ -0,0 +1,74 @@
|
||||
"""JWT encode/decode round-trips."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
|
||||
from notify_bridge_server.auth.jwt import (
|
||||
ALGORITHM,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
)
|
||||
from notify_bridge_server.config import settings
|
||||
|
||||
|
||||
def test_access_token_round_trip() -> None:
|
||||
token = create_access_token(user_id=1, role="admin", token_version=3)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == "1"
|
||||
assert payload["type"] == "access"
|
||||
assert payload["role"] == "admin"
|
||||
assert payload["ver"] == 3
|
||||
assert payload["iss"] == settings.jwt_issuer
|
||||
assert payload["aud"] == settings.jwt_audience
|
||||
|
||||
|
||||
def test_refresh_token_round_trip() -> None:
|
||||
token = create_refresh_token(user_id=7, token_version=2)
|
||||
payload = decode_token(token)
|
||||
assert payload["type"] == "refresh"
|
||||
assert payload["sub"] == "7"
|
||||
|
||||
|
||||
def test_decode_rejects_wrong_audience() -> None:
|
||||
"""A token signed with our key but for a different audience is rejected."""
|
||||
now = datetime.now(timezone.utc)
|
||||
forged = pyjwt.encode(
|
||||
{
|
||||
"iss": settings.jwt_issuer,
|
||||
"aud": "other-service",
|
||||
"sub": "1",
|
||||
"type": "access",
|
||||
"ver": 1,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=5),
|
||||
},
|
||||
settings.secret_key,
|
||||
algorithm=ALGORITHM,
|
||||
)
|
||||
with pytest.raises(pyjwt.InvalidAudienceError):
|
||||
decode_token(forged)
|
||||
|
||||
|
||||
def test_decode_rejects_none_alg() -> None:
|
||||
"""An ``alg: none`` token must never be accepted."""
|
||||
now = datetime.now(timezone.utc)
|
||||
forged = pyjwt.encode(
|
||||
{
|
||||
"iss": settings.jwt_issuer,
|
||||
"aud": settings.jwt_audience,
|
||||
"sub": "1",
|
||||
"type": "access",
|
||||
"ver": 1,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=5),
|
||||
},
|
||||
"",
|
||||
algorithm="none",
|
||||
)
|
||||
with pytest.raises(pyjwt.InvalidAlgorithmError):
|
||||
decode_token(forged)
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Pre-migration snapshot: atomic copy + retention pruning."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from notify_bridge_server.database.snapshot import (
|
||||
prune_old_snapshots,
|
||||
snapshot_and_prune,
|
||||
snapshot_database,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def sqlite_engine(tmp_path: Path):
|
||||
"""Tiny SQLite DB with one table + one row, closed cleanly after the test."""
|
||||
db_path = tmp_path / "app.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
||||
async with engine.begin() as conn:
|
||||
await conn.execute(text("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT)"))
|
||||
await conn.execute(text("INSERT INTO t (v) VALUES ('seed')"))
|
||||
yield engine, db_path, tmp_path / "backups"
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
class TestSnapshot:
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_consistent_copy(self, sqlite_engine) -> None:
|
||||
engine, _db, backups = sqlite_engine
|
||||
dest = await snapshot_database(engine, backups)
|
||||
assert dest is not None
|
||||
assert dest.exists()
|
||||
# Can open the snapshot and see the seed row — proves it's a real DB copy.
|
||||
copy = create_async_engine(f"sqlite+aiosqlite:///{dest}")
|
||||
async with copy.connect() as c:
|
||||
result = await c.execute(text("SELECT v FROM t"))
|
||||
rows = result.all()
|
||||
await copy.dispose()
|
||||
assert rows == [("seed",)]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_db_missing(self, tmp_path: Path) -> None:
|
||||
# Engine pointing at a path that doesn't exist yet.
|
||||
engine = create_async_engine(
|
||||
f"sqlite+aiosqlite:///{tmp_path / 'does-not-exist.db'}"
|
||||
)
|
||||
try:
|
||||
dest = await snapshot_database(engine, tmp_path / "backups")
|
||||
finally:
|
||||
await engine.dispose()
|
||||
assert dest is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_unsafe_label(self, sqlite_engine) -> None:
|
||||
engine, _db, backups = sqlite_engine
|
||||
dest = await snapshot_database(engine, backups, label="bad'; DROP TABLE t;--")
|
||||
assert dest is None
|
||||
|
||||
|
||||
class TestPrune:
|
||||
def _make_snapshot(self, backups: Path, age_seconds: int) -> Path:
|
||||
backups.mkdir(parents=True, exist_ok=True)
|
||||
ts = datetime.now(timezone.utc) - timedelta(seconds=age_seconds)
|
||||
name = f"pre-migrate-{ts.strftime('%Y-%m-%dT%H-%M-%S')}.db"
|
||||
p = backups / name
|
||||
p.write_bytes(b"x")
|
||||
mtime = ts.timestamp()
|
||||
import os
|
||||
os.utime(p, (mtime, mtime))
|
||||
return p
|
||||
|
||||
def test_keeps_n_newest(self, tmp_path: Path) -> None:
|
||||
backups = tmp_path / "backups"
|
||||
for age in (100, 80, 60, 40, 20, 0):
|
||||
self._make_snapshot(backups, age)
|
||||
|
||||
deleted = prune_old_snapshots(backups, keep=3)
|
||||
remaining = sorted(backups.glob("pre-migrate-*.db"))
|
||||
assert len(deleted) == 3
|
||||
assert len(remaining) == 3
|
||||
|
||||
def test_keep_zero_deletes_all(self, tmp_path: Path) -> None:
|
||||
backups = tmp_path / "backups"
|
||||
for age in (30, 20, 10):
|
||||
self._make_snapshot(backups, age)
|
||||
prune_old_snapshots(backups, keep=0)
|
||||
assert list(backups.glob("pre-migrate-*.db")) == []
|
||||
|
||||
def test_missing_dir_is_noop(self, tmp_path: Path) -> None:
|
||||
assert prune_old_snapshots(tmp_path / "never-created", keep=5) == []
|
||||
|
||||
|
||||
class TestSnapshotAndPrune:
|
||||
@pytest.mark.asyncio
|
||||
async def test_keep_zero_disables(self, sqlite_engine) -> None:
|
||||
engine, _db, backups = sqlite_engine
|
||||
result = await snapshot_and_prune(engine, backups, keep=0)
|
||||
assert result is None
|
||||
assert not backups.exists() or list(backups.glob("*.db")) == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end(self, sqlite_engine) -> None:
|
||||
engine, _db, backups = sqlite_engine
|
||||
# Run twice — second run should keep both snapshots (keep=5).
|
||||
a = await snapshot_and_prune(engine, backups, keep=5)
|
||||
# Guarantee distinct filenames (timestamp has second resolution).
|
||||
await asyncio.sleep(1.05)
|
||||
b = await snapshot_and_prune(engine, backups, keep=5)
|
||||
assert a and b and a != b
|
||||
assert a.exists() and b.exists()
|
||||
@@ -0,0 +1,59 @@
|
||||
"""SSRF guard regression tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import (
|
||||
UnsafeURLError,
|
||||
avalidate_outbound_url,
|
||||
validate_outbound_url,
|
||||
)
|
||||
|
||||
|
||||
class TestScheme:
|
||||
def test_rejects_file_scheme(self) -> None:
|
||||
with pytest.raises(UnsafeURLError):
|
||||
validate_outbound_url("file:///etc/passwd")
|
||||
|
||||
def test_rejects_gopher(self) -> None:
|
||||
with pytest.raises(UnsafeURLError):
|
||||
validate_outbound_url("gopher://example.com/")
|
||||
|
||||
def test_accepts_https(self) -> None:
|
||||
# A well-known public host — validated via real DNS so this test is
|
||||
# skipped when offline.
|
||||
try:
|
||||
validate_outbound_url("https://example.com/")
|
||||
except UnsafeURLError as err:
|
||||
if "DNS" in str(err):
|
||||
pytest.skip("No DNS in test environment")
|
||||
raise
|
||||
|
||||
|
||||
class TestBlockedRanges:
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"http://127.0.0.1/",
|
||||
"http://10.0.0.1/",
|
||||
"http://192.168.1.1/",
|
||||
"http://169.254.169.254/latest/meta-data/",
|
||||
"http://[::1]/",
|
||||
],
|
||||
)
|
||||
def test_rejects_literal_private(self, url: str) -> None:
|
||||
with pytest.raises(UnsafeURLError):
|
||||
validate_outbound_url(url)
|
||||
|
||||
|
||||
class TestAsyncValidator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_rejects_loopback(self) -> None:
|
||||
with pytest.raises(UnsafeURLError):
|
||||
await avalidate_outbound_url("http://127.0.0.1/")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_rejects_bad_scheme(self) -> None:
|
||||
with pytest.raises(UnsafeURLError):
|
||||
await avalidate_outbound_url("file:///etc/passwd")
|
||||
@@ -24,7 +24,7 @@ fi
|
||||
|
||||
# Start backend
|
||||
export NOTIFY_BRIDGE_DATA_DIR=./test-data
|
||||
export NOTIFY_BRIDGE_SECRET_KEY=test-secret-key-minimum-32-chars
|
||||
export NOTIFY_BRIDGE_SECRET_KEY=dev-only-pwIOUsKmfn4CYWQ9hCRs5lmI3GgrVlXSu2nqFzGW
|
||||
# Dev targets (homelab Immich / Gitea / etc.) live on RFC1918 ranges; the SSRF
|
||||
# guard rejects private addresses by default, which would make trackers fail.
|
||||
export NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS=1
|
||||
|
||||
Reference in New Issue
Block a user