feat: production readiness — security, perf, bug fixes, bridge self-monitoring

Comprehensive multi-area pass driven by a parallel 8-agent production
review. Frontend, backend, database, security, performance, operational,
plus a new self-monitoring feature.

## Critical fixes
- Planka webhook: reads bounded raw body (was NameError on every call)
- HA quiet hours: ha_state_changed/automation_triggered/service_called/
  event_fired added to deferrable set (were silently dropped)
- DNS-rebinding SSRF: PinnedResolver wired into shared aiohttp session
- Telegram inbound webhook: secret now mandatory (401 without)
- Generic webhook: auth_mode="none" requires explicit
  acknowledge_unauthenticated=true; per-IP rate limit 60/min
- svelte-check: 5 null-narrowing errors in EventDetailModal fixed
- Provider hardcoding: Immich-only block extracted to descriptor
  featureDiscoveryHint
- command_sync: snapshot+expunge bot before exiting AsyncSession

## Bug fixes
- notifier asyncio.gather(return_exceptions=True) — one bad chat no longer
  cancels peer sends
- NotificationDispatcher hoisted out of per-tracker loop
- Provider credential resolution unified across all 5 dispatch sites
- HA asyncio.shield now drains inner task on cancellation
- Provider construction switched from if/elif ladder to factory registry
- NUT first poll seeds silently (no spurious ups_on_battery)
- Quiet-hours gate: event-type-disabled now wins over deferral
- APScheduler drain job ID resolution upgraded to seconds
- HA on_status_change wired through to EventLog
- Webhook payload rollback failures now logged (not swallowed)
- Batched receivers/chats/bots in load_link_data (was per-target N+1)
- flag_modified on JSON column reassignments in deferred_dispatch

## Database
- UNIQUE indexes on service_provider.webhook_token,
  telegram_bot.webhook_path_id, partial UNIQUE on telegram_bot.bot_id,
  telegram_chat(bot_id, chat_id), notification_tracker_target unique link,
  partial UNIQUE on bridge_self provider per user
- Composite ix_event_log_user_event_type_created index
- save_chat_from_webhook switched to ON CONFLICT DO UPDATE
- ondelete=CASCADE on user-id FKs (model annotation; app-side cascade
  delete added for existing data)
- delete_notification_tracker converted from N+1 to bulk DELETE/UPDATE
- Module-level asyncio.Lock replaced with lazy _get_lock() pattern
- VACUUM INTO snapshot now PRAGMA integrity_check verified

## Performance
- Jinja2 template compilation LRU cached (lru_cache maxsize=512)
- Per-locale render cache in NotificationDispatcher (skips re-rendering
  identical content for receivers sharing a locale)
- Tracker list cached per provider_id with 5s TTL + explicit invalidation
  on tracker CRUD (relieves HA chat-bus rate query pressure)
- Nav-counts collapsed from 16 round-trips to single UNION ALL
- HA event_log: skip persisting empty assets_added/removed events

## Security hardening
- Mass-assignment guard on Action create/update; cron sub-minute reject
- Backup JSON depth/node-count cap (depth ≤ 10, nodes ≤ 100k)
- _sanitize_config extended to all JSON-typed fields on backup import
- Telegram _safe_get walks redirects manually with SSRF revalidation
- Bcrypt 72-byte password length cap with clear 422
- Webhook payload body redaction; sensitive substring set extended with
  oauth/client_secret/webhook_secret/csrf in both header filter and
  template extras filter

## Frontend
- 76 catch (err: any) sites converted to errMsg(err) helper
- globalProviderFilter: pure getter; reconciliation moved to one-time
  $effect in +layout
- Provider-filter binding: removed paired $effects + _syncingFilter flag,
  now one-way derived
- entity-cache: separate _refreshing flag for background re-fetches
- api.ts 401 handling: AuthRedirectError class + dedup _redirecting flag,
  goto() instead of window.location.href
- a11y: aria-expanded on mobile More, role=switch + aria-checked on
  Telegram bot toggles

## Tests & operations
- CI pytest gate added to .gitea/workflows/build.yml + release.yml
  (wheel-built install to dodge editable-install slowness)
- /api/ready upgraded to deep healthcheck (db SELECT 1, scheduler.running,
  HA supervisor presence) returning {ready, checks, errors, version}
- /api/metrics endpoint with prometheus_client (deferred_pending,
  event_log_total, dispatch_duration, poll_failures, send_failures)
- New OPERATIONS.md covering deploy, healthchecks, metrics, backup/restore
  procedures, log handling, common scenarios, upgrade flow
- New tests: test_bridge_self (11), test_gitea_parser (9),
  test_planka_parser (6), test_immich_change_detector (6),
  test_backup_roundtrip (1)

## New feature: bridge self-monitoring
- New bridge_self provider type — internal sink for bridge health events
- Three event types: bridge_self_poll_failures (consecutive tracker poll
  failures), bridge_self_deferred_backlog (pending count crosses
  threshold), bridge_self_target_failures (consecutive 5xx/network
  failures per target)
- Per-user thresholds (defaults: 3 / 100 / 5) configurable via the
  provider config form
- Auto-seeded on user create + /setup + boot backfill for existing users
- Anti-spam: counters reset after emission; backlog uses transition latch
- Self-loop guard: bridge_self failures don't count toward target-failure
  thresholds (logged only) — wire to your own Telegram/Email/Matrix to
  get notified when polls/dispatches/sends fail
- 6 default templates (3 events × 2 locales), tracking config columns
  with backfill migration, frontend descriptor (excluded from "create
  provider" wizard since auto-managed)

Operator-visible behavior changes (call out in release notes):
- NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET now REQUIRED for webhook mode
- Existing webhook providers with auth_mode="none" need explicit opt-in
- Generic webhook endpoint rate-limited 60/min per source IP
- HA disconnect/reconnect writes ha_status_* EventLog rows
- Every user gets a bridge_self provider — wire it to a target to
  receive failure alerts

Pre-existing test failures (test_ssrf, test_release_provider) on
Python 3.13 are unrelated; CI runs on 3.12.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-16 02:16:49 +03:00
parent 22127e2a59
commit 10d30fc956
97 changed files with 5423 additions and 821 deletions
@@ -71,6 +71,12 @@ class EventType(str, Enum):
HA_SERVICE_CALLED = "ha_service_called"
HA_EVENT_FIRED = "ha_event_fired"
# Bridge self-monitoring events — emitted by the bridge itself when
# internal failures cross configured thresholds.
BRIDGE_SELF_POLL_FAILURES = "bridge_self_poll_failures"
BRIDGE_SELF_DEFERRED_BACKLOG = "bridge_self_deferred_backlog"
BRIDGE_SELF_TARGET_FAILURES = "bridge_self_target_failures"
@dataclass
class ServiceEvent:
@@ -107,6 +107,12 @@ class NotificationDispatcher:
# Optional shared session owned by the caller; when supplied we reuse
# its connection pool instead of opening a fresh per-dispatch session.
self._shared_session = session
# Per-dispatch render cache, keyed by locale. Populated by
# ``_send_to_target`` and consumed inside ``_message_for_receiver``
# so a 100-receiver fan-out renders each unique locale once.
# Initialized to empty so handlers called outside the normal
# dispatch path (tests) still see a valid dict.
self._render_cache: dict[str, str] = {}
@contextlib.asynccontextmanager
async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
@@ -198,20 +204,49 @@ class NotificationDispatcher:
def _message_for_receiver(
self, receiver: Receiver, default_message: str,
event: ServiceEvent, target: TargetConfig,
cache: dict[str, str] | None = None,
) -> str:
if receiver.locale and receiver.locale != target.locale:
return self._render_message(event, target, receiver.locale)
return default_message
"""Render message respecting receiver locale, with optional cache.
The ``cache`` dict (typically created in ``_send_to_target`` and
threaded through the per-channel ``_send_*`` handlers) memoizes
per-locale renders so a 100-receiver fan-out with two locales
renders twice instead of one hundred times.
"""
loc = receiver.locale or target.locale
if loc == target.locale:
return default_message
if cache is not None:
cached = cache.get(loc)
if cached is not None:
return cached
rendered = self._render_message(event, target, loc)
cache[loc] = rendered
return rendered
return self._render_message(event, target, loc)
async def _send_to_target(
self, event: ServiceEvent, target: TargetConfig
) -> dict[str, Any]:
"""Dispatch to a single target via the registered handler."""
"""Dispatch to a single target via the registered handler.
Builds a per-locale render cache once and threads it through the
send handler. The cache is keyed by receiver locale; the default
locale's render lives in ``default_message`` and is short-circuited
before any cache lookup.
"""
default_message = self._render_message(event, target, target.locale)
send_method = _PROVIDER_HANDLERS.get(target.type)
if send_method is None:
return {"success": False, "error": f"Unknown target type: {target.type}"}
return await send_method(self, target, default_message, event)
# Stash the cache on the dispatcher instance for the duration of
# this dispatch — handlers pick it up via _message_for_receiver.
# Avoids changing every _send_* signature.
self._render_cache: dict[str, str] = {}
try:
return await send_method(self, target, default_message, event)
finally:
self._render_cache = {}
# ------------------------------------------------------------------
# Asset preload (Telegram-specific)
@@ -352,7 +387,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
return {"success": False, "error": "Invalid telegram receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
text_result = await client.send_message(
chat_id=receiver.chat_id,
text=message,
@@ -407,7 +442,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, WebhookReceiver) or not receiver.url:
return {"success": False, "error": "Invalid webhook receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
payload = {
"message": message,
"event_type": event.event_type.value,
@@ -450,7 +485,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, EmailReceiver) or not receiver.email:
return {"success": False, "error": "Invalid email receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
# body_html=None lets EmailClient build a safely-escaped HTML
# alternative from body_text instead of trusting user content.
return await email_client.send(
@@ -479,7 +514,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
return {"success": False, "error": "Invalid discord receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
@@ -501,7 +536,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
return {"success": False, "error": "Invalid slack receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
return await client.send(receiver.webhook_url, message, username=username)
results = await self._fan_out(target.receivers, send_one)
@@ -530,7 +565,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
return {"success": False, "error": "Invalid ntfy receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
return await client.send(
server_url, receiver.topic, message,
title=title, priority=receiver.priority, auth_token=auth_token,
@@ -563,7 +598,7 @@ class NotificationDispatcher:
async def send_one(receiver: Receiver) -> dict[str, Any]:
if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
return {"success": False, "error": "Invalid matrix receiver"}
message = self._message_for_receiver(receiver, default_message, event, target)
message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
# body_html is the same plain text — Matrix accepts the
# raw message as both ``body`` and ``formatted_body``.
# If templates emit HTML in the future, generate a
@@ -222,21 +222,48 @@ class TelegramClient:
"""SSRF-guarded GET that returns ``(data, error)``.
Validates the URL via ``avalidate_outbound_url`` before any HTTP
traffic. Errors are returned (not raised) and stripped of any
embedded secrets before they propagate to the operator-visible
result dict.
traffic. Redirects are walked manually so each ``Location`` is
re-validated — without this an attacker-controlled origin could
302 to a private-IP target after the initial guard passed.
Errors are returned (not raised) and stripped of any embedded
secrets before they propagate to the operator-visible result
dict.
"""
max_redirects = 3
current_url = url
try:
await avalidate_outbound_url(url)
await avalidate_outbound_url(current_url)
except UnsafeURLError as err:
return None, f"Unsafe URL: {redact_exc(err)}"
try:
async with self._session.get(
url, headers=headers or {}, timeout=_DOWNLOAD_TIMEOUT,
) as resp:
if resp.status != 200:
return None, f"HTTP {resp.status}"
return await resp.read(), None
for _ in range(max_redirects + 1):
async with self._session.get(
current_url,
headers=headers or {},
timeout=_DOWNLOAD_TIMEOUT,
allow_redirects=False,
) as resp:
if resp.status in (301, 302, 303, 307, 308):
loc = resp.headers.get("Location")
if not loc:
return None, f"HTTP {resp.status} without Location header"
# ``resp.url`` is a yarl.URL; ``.join`` resolves
# relative redirects (``/foo/bar``) against it.
from yarl import URL as _URL
try:
next_url = str(resp.url.join(_URL(loc)))
except (ValueError, TypeError):
return None, "Malformed redirect Location"
try:
await avalidate_outbound_url(next_url)
except UnsafeURLError as err:
return None, f"Unsafe redirect: {redact_exc(err)}"
current_url = next_url
continue
if resp.status != 200:
return None, f"HTTP {resp.status}"
return await resp.read(), None
return None, f"Too many redirects (>{max_redirects})"
except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
return None, redact_exc(err)
@@ -22,6 +22,7 @@ class ServiceProviderType(str, Enum):
GOOGLE_PHOTOS = "google_photos"
WEBHOOK = "webhook"
HOME_ASSISTANT = "home_assistant"
BRIDGE_SELF = "bridge_self"
# Callback signature for push-style providers: a coroutine that accepts a
@@ -0,0 +1,39 @@
"""Bridge self-monitoring service provider.
Unlike external providers (Immich, Gitea, NUT, ...), the ``bridge_self``
provider does not connect to any remote service. Its sole purpose is to
give operators a configurable surface (thresholds + notification slots
+ trackers + targets) for events that the bridge itself emits when its
internal subsystems fail.
Three failure conditions are surfaced as :class:`ServiceEvent` instances
through the same dispatch pipeline that all other providers use:
* ``bridge_self_poll_failures`` — N consecutive poll failures for
any tracker exceed the configured threshold.
* ``bridge_self_deferred_backlog`` — pending ``deferred_dispatch`` row
count crosses the configured threshold.
* ``bridge_self_target_failures`` — N consecutive 5xx / network failures
for a single notification target.
Events are constructed by ``services/bridge_self.py`` on the server side
(it owns DB access for looking up the bridge_self provider per user)
and then fed into ``dispatch_provider_event`` like any other event.
"""
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.templates.variables import registry
from .event_parser import build_event
from .provider import BRIDGE_SELF_VARIABLES, BridgeSelfServiceProvider
# Register variables so the validator and template-vars API see them.
registry.register_provider_variables(
ServiceProviderType.BRIDGE_SELF, BRIDGE_SELF_VARIABLES,
)
__all__ = [
"BRIDGE_SELF_VARIABLES",
"BridgeSelfServiceProvider",
"build_event",
]
@@ -0,0 +1,89 @@
"""Bridge self-monitoring event parser.
The bridge generates these events from internal subsystems (watcher,
scheduler, dispatcher) — the parser turns a flat payload dict into the
generic :class:`ServiceEvent` shape that the rest of the dispatch
pipeline expects.
Payload shape::
{
"failure_type": "poll_failures" | "deferred_backlog" | "target_failures",
"subject_id": int, # tracker_id, target_id, or 0
"subject_name": str,
"count": int, # consecutive failures or pending count
"threshold": int,
"last_error": str, # may be empty
"details": dict[str, Any], # extra context
}
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from notify_bridge_core.models.events import EventType, ServiceEvent
from notify_bridge_core.providers.base import ServiceProviderType
# Defensive cap on the persisted error message; very long tracebacks would
# bloat the EventLog details JSON column otherwise.
_MAX_ERROR_LEN = 1000
_FAILURE_TYPE_TO_EVENT: dict[str, EventType] = {
"poll_failures": EventType.BRIDGE_SELF_POLL_FAILURES,
"deferred_backlog": EventType.BRIDGE_SELF_DEFERRED_BACKLOG,
"target_failures": EventType.BRIDGE_SELF_TARGET_FAILURES,
}
def build_event(
payload: dict[str, Any],
*,
provider_name: str = "Bridge Self-Monitoring",
timestamp: datetime | None = None,
) -> ServiceEvent | None:
"""Convert a self-monitoring payload dict into a ServiceEvent.
Returns None for malformed payloads (unknown failure_type or missing
keys) — the caller drops without raising so a misbehaving emitter
can never tip over the dispatch pipeline.
"""
if not isinstance(payload, dict):
return None
failure_type = payload.get("failure_type")
event_type = _FAILURE_TYPE_TO_EVENT.get(str(failure_type) if failure_type else "")
if event_type is None:
return None
subject_id = int(payload.get("subject_id") or 0)
subject_name = str(payload.get("subject_name") or "")
count = int(payload.get("count") or 0)
threshold = int(payload.get("threshold") or 0)
last_error = str(payload.get("last_error") or "")[:_MAX_ERROR_LEN]
details = payload.get("details") if isinstance(payload.get("details"), dict) else {}
when = timestamp or datetime.now(timezone.utc)
return ServiceEvent(
event_type=event_type,
provider_type=ServiceProviderType.BRIDGE_SELF,
provider_name=provider_name,
# ``collection_id`` / ``collection_name`` are required fields on
# ServiceEvent; we use the subject so quiet-hours / dedupe logic
# treats different subjects as distinct streams.
collection_id=str(subject_id),
collection_name=subject_name or str(failure_type),
timestamp=when,
extra={
"failure_type": str(failure_type),
"subject_id": subject_id,
"subject_name": subject_name,
"count": count,
"threshold": threshold,
"last_error": last_error,
"details": dict(details),
},
)
@@ -0,0 +1,148 @@
"""Bridge self-monitoring service provider — emits internal-failure events.
This is a passive provider: it does not connect to anything, never polls,
and never subscribes. It exists so the rest of the bridge's CRUD / config /
template / target plumbing has a single ``ServiceProvider`` to attach
self-monitoring trackers and notification slots to.
Events are constructed by the server-side helper
``services/bridge_self.emit_bridge_self_event`` and pushed into
``dispatch_provider_event`` directly — the provider itself is not asked
to produce events.
"""
from __future__ import annotations
from typing import Any
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.base import (
ServiceProvider,
ServiceProviderType,
)
from notify_bridge_core.templates.variables import TemplateVariableDefinition
# Configuration keys recognised on the bridge_self provider's ``config`` JSON.
DEFAULT_POLL_FAILURE_THRESHOLD = 3
DEFAULT_DEFERRED_BACKLOG_THRESHOLD = 100
DEFAULT_TARGET_FAILURE_THRESHOLD = 5
# Template variables exposed to bridge_self templates.
BRIDGE_SELF_VARIABLES: list[TemplateVariableDefinition] = [
TemplateVariableDefinition(
name="failure_type",
type="string",
description="Which self-monitoring condition fired",
example="poll_failures",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="subject_id",
type="int",
description="ID of the affected entity (tracker_id, target_id, or 0)",
example="42",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="subject_name",
type="string",
description="Human-readable name of the affected entity",
example="My Immich Tracker",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="count",
type="int",
description="Consecutive failure count or current backlog size",
example="3",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="threshold",
type="int",
description="Configured threshold that was crossed",
example="3",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="last_error",
type="string",
description="Last underlying error message (truncated)",
example="Connection refused",
provider_type=ServiceProviderType.BRIDGE_SELF,
),
TemplateVariableDefinition(
name="details",
type="dict",
description="Extra structured context for the event",
example='{"provider_id": 7}',
provider_type=ServiceProviderType.BRIDGE_SELF,
),
]
class BridgeSelfServiceProvider(ServiceProvider):
"""Passive provider — exposes nothing remote, holds only thresholds.
Polling is a no-op and ``connect`` always succeeds; the bridge itself
is what generates events for this provider.
"""
provider_type = ServiceProviderType.BRIDGE_SELF
supports_subscription = False
def __init__(self, name: str = "Bridge Self-Monitoring") -> None:
self._name = name
async def connect(self) -> bool:
return True
async def disconnect(self) -> None:
return None
async def poll(
self,
collection_ids: list[str],
tracker_state: dict[str, Any],
) -> tuple[list[ServiceEvent], dict[str, Any]]:
# No external service to poll. Returning empty keeps the contract
# so accidental scheduling no-ops cleanly.
return [], tracker_state
def get_available_variables(self) -> list[TemplateVariableDefinition]:
return list(BRIDGE_SELF_VARIABLES)
def get_provider_config_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"poll_failure_threshold": {
"type": "integer",
"minimum": 1,
"default": DEFAULT_POLL_FAILURE_THRESHOLD,
"description": "Consecutive tracker poll failures before alerting",
},
"deferred_backlog_threshold": {
"type": "integer",
"minimum": 1,
"default": DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
"description": "Pending deferred_dispatch rows before alerting",
},
"target_failure_threshold": {
"type": "integer",
"minimum": 1,
"default": DEFAULT_TARGET_FAILURE_THRESHOLD,
"description": "Consecutive target send failures before alerting",
},
},
"required": [],
}
async def list_collections(self) -> list[dict[str, Any]]:
# No collection concept — operators don't pick anything for this provider.
return []
async def test_connection(self) -> dict[str, Any]:
return {"ok": True, "message": "Bridge self-monitoring is always available"}
@@ -514,6 +514,39 @@ HOME_ASSISTANT_CAPABILITIES = ProviderCapabilities(
)
# ---------------------------------------------------------------------------
# Bridge self-monitoring capabilities
# ---------------------------------------------------------------------------
BRIDGE_SELF_CAPABILITIES = ProviderCapabilities(
provider_type="bridge_self",
display_name="Bridge Self-Monitoring",
webhook_based=False,
supported_filters=[],
notification_slots=[
{
"name": "message_bridge_self_poll_failures",
"description": "Tracker poll failures crossed threshold",
},
{
"name": "message_bridge_self_deferred_backlog",
"description": "Deferred dispatch backlog crossed threshold",
},
{
"name": "message_bridge_self_target_failures",
"description": "Target send failures crossed threshold",
},
],
events=[
{"name": "bridge_self_poll_failures", "description": "Tracker poll failures"},
{"name": "bridge_self_deferred_backlog", "description": "Deferred backlog high"},
{"name": "bridge_self_target_failures", "description": "Target send failures"},
],
command_slots=[],
commands=[],
)
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
@@ -527,6 +560,7 @@ _REGISTRY: dict[str, ProviderCapabilities] = {
"google_photos": GOOGLE_PHOTOS_CAPABILITIES,
"webhook": WEBHOOK_CAPABILITIES,
"home_assistant": HOME_ASSISTANT_CAPABILITIES,
"bridge_self": BRIDGE_SELF_CAPABILITIES,
}
@@ -10,7 +10,7 @@ arrive. The lifecycle is owned by the server-side subscription manager
from __future__ import annotations
import logging
from typing import Any
from typing import Any, Callable
import aiohttp
@@ -25,6 +25,12 @@ from notify_bridge_core.templates.variables import TemplateVariableDefinition
from .client import HomeAssistantWSClient
from .event_parser import parse_event
# Status callback signature: ``(state, detail)`` where ``state`` is one of
# ``"connected"`` / ``"disconnected"`` and ``detail`` is an optional already-
# redacted reason string (or None on connect).
StatusChangeCallback = Callable[[str, str | None], None]
_LOGGER = logging.getLogger(__name__)
@@ -229,7 +235,11 @@ class HomeAssistantServiceProvider(ServiceProvider):
# — the subscription manager owns this provider's lifecycle instead.
return [], tracker_state
async def subscribe(self, emit: EventEmitCallback) -> None:
async def subscribe(
self,
emit: EventEmitCallback,
on_status_change: StatusChangeCallback | None = None,
) -> None:
async def _on_event(ha_event: dict[str, Any]) -> None:
event = parse_event(
ha_event,
@@ -252,6 +262,7 @@ class HomeAssistantServiceProvider(ServiceProvider):
on_event=_on_event,
event_types=self._event_types,
refresh_areas=_refresh_areas,
on_status_change=on_status_change,
)
def get_available_variables(self) -> list[TemplateVariableDefinition]:
@@ -29,10 +29,21 @@ _LOGGER = logging.getLogger(__name__)
# calls per poll cycle. TTL is conservative (1h) and a hashed key keeps the
# raw api_key out of dict keys in case of a memory dump.
_USERS_CACHE_TTL_SECONDS = 3600
_users_cache_lock = asyncio.Lock()
# Lazy init: ``asyncio.Lock()`` at module import binds to whichever event
# loop is current at import time (often none, or the wrong one when tests
# spin up dedicated loops). Defer creation to first use.
_users_cache_lock: asyncio.Lock | None = None
_users_cache: dict[str, tuple[float, dict[str, str]]] = {}
def _get_users_cache_lock() -> asyncio.Lock:
"""Return the module users-cache lock, creating it on first call."""
global _users_cache_lock
if _users_cache_lock is None:
_users_cache_lock = asyncio.Lock()
return _users_cache_lock
def _users_cache_key(url: str, api_key: str) -> str:
digest = hashlib.sha256(f"{url}|{api_key}".encode("utf-8")).hexdigest()
return digest[:32]
@@ -51,7 +62,7 @@ async def _get_cached_users(
if entry is not None and (now - entry[0]) < _USERS_CACHE_TTL_SECONDS:
return entry[1]
async with _users_cache_lock:
async with _get_users_cache_lock():
# Re-check after acquiring the lock — another coroutine may have
# refreshed the entry while we waited.
entry = _users_cache.get(key)
@@ -200,10 +200,28 @@ class NutServiceProvider(ServiceProvider):
try:
for ups_name in collection_ids:
prev = tracker_state.get(ups_name, {})
# First-ever observation has no baseline — emitting transition
# events for whatever flags the device happens to carry would
# spam the user with "OB"/"LB"/"REPLBATT" alerts on every fresh
# tracker even when nothing changed. Seed state silently and
# skip event emission until the next poll provides a baseline.
is_first_observation = ups_name not in tracker_state
try:
variables = await client.list_var(ups_name)
data = NutUpsData.from_variables(ups_name, variables)
if is_first_observation:
new_state[ups_name] = {
"name": data.description or ups_name,
"status": data.status,
"battery_charge": data.battery_charge,
"comms_ok": True,
"asset_ids": [],
"pending_asset_ids": [],
"shared": False,
}
continue
# Check for comms restored
if not prev.get("comms_ok", True):
events.append(self._make_event(
@@ -35,6 +35,10 @@ _SENSITIVE_EXTRA_TOKENS: tuple[str, ...] = (
"bearer",
"private_key",
"access_key",
"oauth",
"client_secret",
"webhook_secret",
"csrf",
)
@@ -0,0 +1,6 @@
⚠️ <b>Deferred dispatch backlog high</b>
Pending notifications: <b>{{ count }}</b>
Threshold: <b>{{ threshold }}</b>
{%- if last_error %}
<i>Note:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -0,0 +1,6 @@
🚨 <b>Tracker poll failures</b>
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
<b>{{ count }}</b> consecutive failures (threshold {{ threshold }})
{%- if last_error %}
<i>Last error:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -0,0 +1,6 @@
📡 <b>Target send failures</b>
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
<b>{{ count }}</b> consecutive failures (threshold {{ threshold }})
{%- if last_error %}
<i>Last error:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -79,6 +79,11 @@ PROVIDER_SLOT_FILE_MAP: dict[str, dict[str, str]] = {
"message_ha_service_called": "ha_service_called.jinja2",
"message_ha_event_fired": "ha_event_fired.jinja2",
},
"bridge_self": {
"message_bridge_self_poll_failures": "bridge_self_poll_failures.jinja2",
"message_bridge_self_deferred_backlog": "bridge_self_deferred_backlog.jinja2",
"message_bridge_self_target_failures": "bridge_self_target_failures.jinja2",
},
}
# Backward-compatible alias
@@ -0,0 +1,6 @@
⚠️ <b>Очередь отложенной отправки растёт</b>
Ожидают отправки: <b>{{ count }}</b>
Порог: <b>{{ threshold }}</b>
{%- if last_error %}
<i>Примечание:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -0,0 +1,6 @@
🚨 <b>Сбои опроса трекера</b>
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
Подряд сбоев: <b>{{ count }}</b> (порог {{ threshold }})
{%- if last_error %}
<i>Последняя ошибка:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -0,0 +1,6 @@
📡 <b>Сбои отправки в адресат</b>
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
Подряд сбоев: <b>{{ count }}</b> (порог {{ threshold }})
{%- if last_error %}
<i>Последняя ошибка:</i> <code>{{ last_error }}</code>
{%- endif %}
@@ -13,6 +13,7 @@ from __future__ import annotations
import logging
import threading
from functools import lru_cache
from typing import Any
import jinja2
@@ -27,6 +28,19 @@ RENDER_TIMEOUT_SECONDS = 2.0
_env = SandboxedEnvironment(autoescape=True)
@lru_cache(maxsize=512)
def _compile_cached(template_str: str) -> jinja2.Template:
"""Compile + cache Jinja2 templates by source text.
Hot paths (NotificationDispatcher fan-out, periodic dispatch) re-render
the same template string for every event; ``_env.from_string`` parses
the source from scratch each time (~ms each). The 512-entry cache is
large enough to hold every template across a busy install while
keeping memory bounded.
"""
return _env.from_string(template_str)
class TemplateRenderTimeout(jinja2.TemplateError):
"""Raised when a template exceeds the configured render budget."""
@@ -74,7 +88,7 @@ def render_template(template_str: str, context: dict[str, Any]) -> str:
)
return "[Template too large]"
try:
compiled = _env.from_string(template_str)
compiled = _compile_cached(template_str)
output = _render_with_timeout(compiled, context)
except TemplateRenderTimeout as e:
_LOGGER.error("Template render timeout: %s", e)
@@ -27,6 +27,9 @@ def validate_template(
"has_oversized_videos", "max_video_size", "max_video_size_mb",
"added_assets", "assets", "albums",
"raw_payload", "event_type_raw", "source_ip",
# bridge_self self-monitoring variables.
"failure_type", "subject_id", "subject_name", "count",
"threshold", "last_error", "details",
}
allowed = available | runtime_vars
+1
View File
@@ -21,6 +21,7 @@ dependencies = [
"slowapi>=0.1.9",
"cachetools>=5.3",
"python-multipart>=0.0.9",
"prometheus_client>=0.20",
]
[project.optional-dependencies]
@@ -1,6 +1,7 @@
"""Action management API routes — CRUD, execute, dry-run, executions."""
import logging
import re
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
@@ -54,6 +55,58 @@ class ActionUpdate(BaseModel):
# ---------------------------------------------------------------------------
# Allowlist of fields a CRUD client may set on Action. Mirrors ActionCreate /
# ActionUpdate but enforced server-side so a tampered request body cannot
# overwrite ``user_id``, ``last_run_at``, ``created_at``, etc. via ``**dump``.
_ALLOWED_ACTION_CREATE_FIELDS = frozenset({
"provider_id", "name", "icon", "action_type", "config",
"schedule_type", "schedule_interval", "schedule_cron", "enabled",
})
_ALLOWED_ACTION_UPDATE_FIELDS = frozenset({
"name", "icon", "config",
"schedule_type", "schedule_interval", "schedule_cron", "enabled",
})
# 6 fields = standard cron, 7 fields = with seconds (Quartz-style). Reject
# the 7-field form whose first column allows fires more often than once per
# minute. Also reject ``*/N`` minute patterns where N<1 (so ``*/0``) and the
# bare ``*`` minute used together with ``*`` second.
_DISALLOWED_CRON_PATTERNS = (
re.compile(r"^\s*\*/0\s+"), # */0 in any leading position
)
def _validate_cron(expr: str) -> None:
"""Reject schedule_cron strings that fire more often than once per minute.
Without croniter as a hard dep we apply a conservative regex check: a
valid 5-field cron's first column is the minute, so anything other than
``*``/digits/comma/dash/slash there is bogus, and a sub-minute cadence
requires a 6+ field expression with seconds. Reject both shapes.
"""
if not expr or not expr.strip():
return
parts = expr.split()
if len(parts) >= 6:
# Seconds field present (Quartz-style or 6-field). Forbid
# second-level fires entirely; minute-cadence is the floor.
seconds_field = parts[0]
if seconds_field != "0":
raise HTTPException(
status_code=400,
detail=(
"schedule_cron with a sub-minute cadence is not allowed; "
"set the seconds field to 0 or use a standard 5-field cron"
),
)
for pattern in _DISALLOWED_CRON_PATTERNS:
if pattern.search(expr):
raise HTTPException(
status_code=400,
detail="schedule_cron contains a disallowed pattern",
)
async def _action_response(session: AsyncSession, action: Action) -> dict:
"""Build response dict with rules inlined."""
result = await session.exec(
@@ -127,7 +180,15 @@ async def create_action(
detail=f"Invalid action type '{body.action_type}' for provider type '{provider.type}'",
)
action = Action(user_id=user.id, **body.model_dump())
_validate_cron(body.schedule_cron)
# Project only allowlisted fields so a tampered body can't write
# ``user_id``, ``id``, ``last_run_at``, etc. via ``**dump``.
payload = {
k: v for k, v in body.model_dump().items()
if k in _ALLOWED_ACTION_CREATE_FIELDS
}
action = Action(user_id=user.id, **payload)
session.add(action)
await session.commit()
await session.refresh(action)
@@ -168,7 +229,13 @@ async def update_action(
raise HTTPException(status_code=404, detail="Action not found")
updates = body.model_dump(exclude_unset=True)
if "schedule_cron" in updates:
_validate_cron(updates["schedule_cron"] or "")
# Drop any field outside the update allowlist so a tampered request
# can't mutate ``user_id`` / ``provider_id`` / ``action_type`` etc.
for key, value in updates.items():
if key not in _ALLOWED_ACTION_UPDATE_FIELDS:
continue
setattr(action, key, value)
session.add(action)
await session.commit()
@@ -48,6 +48,40 @@ _LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/backup", tags=["backup"])
# Hard caps on uploaded backup file shape — defend against parser DoS
# (deeply nested or pathologically wide JSON) before we hand the
# structure to the import pipeline.
_MAX_BACKUP_DEPTH = 10
_MAX_BACKUP_NODES = 100_000
def _validate_backup_shape(value: object, depth: int = 0, count: list[int] | None = None) -> None:
"""Walk ``value`` and reject anything beyond the depth/node caps.
Raises HTTPException(400) on overflow. Cheap O(n) walk; runs once
per upload.
"""
if count is None:
count = [0]
if depth > _MAX_BACKUP_DEPTH:
raise HTTPException(
status_code=400,
detail=f"Backup file too deeply nested (max depth {_MAX_BACKUP_DEPTH})",
)
count[0] += 1
if count[0] > _MAX_BACKUP_NODES:
raise HTTPException(
status_code=400,
detail=f"Backup file has too many nodes (max {_MAX_BACKUP_NODES})",
)
if isinstance(value, dict):
for v in value.values():
_validate_backup_shape(v, depth + 1, count)
elif isinstance(value, list):
for v in value:
_validate_backup_shape(v, depth + 1, count)
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB
@@ -181,6 +215,8 @@ async def validate_config(
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
_validate_backup_shape(raw)
result = validate_backup(raw)
return result.model_dump()
@@ -204,6 +240,8 @@ async def import_config(
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
_validate_backup_shape(raw)
# Validate first
validation = validate_backup(raw)
if not validation.valid:
@@ -259,6 +297,8 @@ async def prepare_restore(
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
_validate_backup_shape(raw)
validation = validate_backup(raw)
if not validation.valid:
raise HTTPException(
@@ -504,11 +504,14 @@ async def delete_config(
if config.user_id == 0 and user.role != "admin":
raise HTTPException(status_code=403, detail="Cannot delete system default configs")
raise_if_used(await check_command_template_config(session, config.id), config.name)
slot_result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config.id)
# Bulk delete slot rows so the round-trip count stays O(1) regardless
# of how many locale/slot combinations the config carries.
from sqlalchemy import delete as sa_delete
await session.execute(
sa_delete(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config.id
)
)
for slot in slot_result.all():
await session.delete(slot)
await session.delete(config)
await session.commit()
@@ -162,17 +162,26 @@ async def delete_command_tracker(
from ..services.command_sync import mark_dirty_for_tracker
await mark_dirty_for_tracker(tracker.id)
# Delete associated listeners, collecting bot IDs for polling cleanup
# First read the listeners we're about to delete so we can collect the
# set of telegram_bot IDs whose polling state may need to be re-checked.
# Then issue a single bulk DELETE instead of N per-row deletes.
from sqlalchemy import delete as sa_delete
result = await session.exec(
select(CommandTrackerListener).where(
CommandTrackerListener.command_tracker_id == tracker_id
)
)
bot_ids_to_check: set[int] = set()
for listener in result.all():
if listener.listener_type == "telegram_bot":
bot_ids_to_check.add(listener.listener_id)
await session.delete(listener)
bot_ids_to_check: set[int] = {
listener.listener_id
for listener in result.all()
if listener.listener_type == "telegram_bot"
}
await session.execute(
sa_delete(CommandTrackerListener).where(
CommandTrackerListener.command_tracker_id == tracker_id
)
)
await session.delete(tracker)
await session.commit()
@@ -0,0 +1,161 @@
"""Prometheus metrics endpoint and central registry.
Exposes operational metrics via ``GET /api/metrics`` in the standard
Prometheus text format. Unauthenticated by design — Prometheus scrapers do
not authenticate. If the API port crosses a trust boundary, disable via
``NOTIFY_BRIDGE_METRICS_ENABLED=false``.
Metrics are defined as module-level singletons so the rest of the codebase
can ``from notify_bridge_server.api.metrics import metrics`` and call
``metrics.dispatch_duration.labels(channel="telegram").observe(0.42)``
without re-creating the underlying objects.
Other modules MUST NOT ``import prometheus_client`` directly. Route every
metric through :data:`metrics` (a :class:`MetricsRegistry`) so we have one
place to swap implementations or add labels.
"""
from __future__ import annotations
import logging
from typing import Final
from fastapi import APIRouter, HTTPException
from starlette.responses import Response
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
Counter,
Gauge,
Histogram,
generate_latest,
)
from ..config import settings as _settings
_LOGGER = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Metric definitions
# ---------------------------------------------------------------------------
# Use a dedicated CollectorRegistry instead of the global default registry so
# tests can construct the module repeatedly without ``Duplicated timeseries``
# errors and so we never accidentally export Python GC / process metrics that
# aren't part of the documented surface in OPERATIONS.md.
_REGISTRY: Final[CollectorRegistry] = CollectorRegistry()
class MetricsRegistry:
"""Singleton holder for module-level Prometheus collectors.
Instantiated once at import time as :data:`metrics`. Keep collectors as
instance attributes so call sites get IDE autocomplete and so swapping
the collector type (e.g. Counter -> Summary) is a one-line change here.
"""
def __init__(self, registry: CollectorRegistry) -> None:
self.registry = registry
# Gauge: populated on every scrape via the collector hook below.
self.deferred_pending = Gauge(
"notify_bridge_deferred_pending",
"Count of deferred_dispatch rows awaiting drain.",
registry=registry,
)
# Counter: incremented after each event_log row is persisted.
self.event_log_total = Counter(
"notify_bridge_event_log_total",
"Total events written to event_log, partitioned by status and event_type.",
["status", "event_type"],
registry=registry,
)
# Histogram: observed wall-clock seconds per outbound dispatch attempt.
self.dispatch_duration = Histogram(
"notify_bridge_dispatch_duration_seconds",
"Wall-clock duration of one dispatch attempt to a notification channel.",
["channel"],
registry=registry,
buckets=(0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0),
)
# Counter: each polling provider that fails a tick increments by 1.
self.provider_poll_failures = Counter(
"notify_bridge_provider_poll_failures_total",
"Polling provider failures partitioned by provider type.",
["provider_type"],
registry=registry,
)
# Counter: each rejected delivery to a target increments by 1.
self.target_send_failures = Counter(
"notify_bridge_target_send_failures_total",
"Failed sends to a target partitioned by target type and HTTP status.",
["target_type", "status_code"],
registry=registry,
)
metrics: Final[MetricsRegistry] = MetricsRegistry(_REGISTRY)
# ---------------------------------------------------------------------------
# Scrape hook: refresh dynamic gauges on demand
# ---------------------------------------------------------------------------
async def _refresh_deferred_pending_gauge() -> None:
"""Populate ``deferred_pending`` by counting pending rows in the DB.
Called from the request handler before serializing — we don't poll the
DB on a fixed cadence to avoid a steady-state cost when nothing is
scraping. Kept tolerant: a DB error logs and leaves the previous value.
"""
try:
from sqlalchemy import text
from ..database.engine import get_engine
engine = get_engine()
async with engine.connect() as conn:
result = await conn.execute(
text("SELECT count(*) FROM deferred_dispatch WHERE status='pending'")
)
row = result.first()
count = int(row[0]) if row else 0
metrics.deferred_pending.set(count)
except Exception as exc: # noqa: BLE001 — never fail the scrape over this
_LOGGER.debug("deferred_pending refresh skipped: %s", exc)
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
router = APIRouter(tags=["metrics"])
@router.get("/api/metrics")
async def metrics_endpoint() -> Response:
"""Expose collected metrics in Prometheus text format.
No auth by design — Prometheus scrapers don't authenticate. Gate the
endpoint via ``NOTIFY_BRIDGE_METRICS_ENABLED=false`` when the API port
is reachable from outside the trust boundary.
"""
if not _settings.metrics_enabled:
raise HTTPException(status_code=404, detail="Metrics disabled")
await _refresh_deferred_pending_gauge()
# Stub increments so the endpoint reports non-empty data even before
# callers wire instrumentation. Removed once code-paths are instrumented.
# The labels here intentionally use a sentinel value so dashboards can
# filter the noise out: ``status="bootstrap"``.
metrics.event_log_total.labels(status="bootstrap", event_type="metrics_scrape").inc(0)
payload = generate_latest(_REGISTRY)
return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
@@ -152,6 +152,10 @@ async def create_notification_tracker(
session.add(tracker)
await session.commit()
await session.refresh(tracker)
# Drop the cached enabled-trackers list so the next inbound event
# (HA / webhook) sees the new tracker without waiting out the TTL.
from ..services.event_dispatch import invalidate_tracker_cache
invalidate_tracker_cache(tracker.provider_id)
if tracker.enabled:
await schedule_tracker(
tracker.id, tracker.scan_interval,
@@ -184,6 +188,8 @@ async def update_notification_tracker(
session.add(tracker)
await session.commit()
await session.refresh(tracker)
from ..services.event_dispatch import invalidate_tracker_cache
invalidate_tracker_cache(tracker.provider_id)
if tracker.enabled:
await schedule_tracker(
tracker.id, tracker.scan_interval,
@@ -201,28 +207,39 @@ async def delete_notification_tracker(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
"""Delete a tracker and its child rows in three bulk statements.
The previous implementation issued one DELETE per child row plus one
UPDATE per event_log row, which scaled linearly with the tracker's
history (an old, busy tracker could hit thousands of round-trips).
Bulk DELETE/UPDATE collapses that to three SQL statements regardless
of size.
"""
from sqlalchemy import delete as sa_delete, update as sa_update
tracker = await _get_user_tracker(session, tracker_id, user.id)
# Delete associated tracker-target links
result = await session.exec(
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
# Junction rows — direct dependents of the tracker.
await session.execute(
sa_delete(NotificationTrackerTarget).where(
NotificationTrackerTarget.tracker_id == tracker_id
)
)
for tt in result.all():
await session.delete(tt)
# Delete associated tracker state
state_result = await session.exec(
select(NotificationTrackerState).where(NotificationTrackerState.tracker_id == tracker_id)
# Persisted scan state for this tracker.
await session.execute(
sa_delete(NotificationTrackerState).where(
NotificationTrackerState.tracker_id == tracker_id
)
)
for ts in state_result.all():
await session.delete(ts)
# Nullify event log references
event_result = await session.exec(
select(EventLog).where(EventLog.tracker_id == tracker_id)
# Preserve the audit trail in event_log; just null the back-reference
# so the tracker row can be removed without an FK violation.
await session.execute(
sa_update(EventLog).where(EventLog.tracker_id == tracker_id).values(tracker_id=None)
)
for el in event_result.all():
el.tracker_id = None
session.add(el)
provider_id_for_cache = tracker.provider_id
await session.delete(tracker)
await session.commit()
from ..services.event_dispatch import invalidate_tracker_cache
invalidate_tracker_cache(provider_id_for_cache)
await unschedule_tracker(tracker_id)
await reschedule_immich_dispatch_jobs()
@@ -1,9 +1,10 @@
"""Service provider management API routes."""
import logging
import secrets
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator
from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator, model_validator
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any
@@ -94,14 +95,36 @@ class PayloadMapping(BaseModel):
class WebhookProviderConfig(BaseModel):
auth_mode: str = "none"
# Default to bearer to avoid silently creating an open relay. Operators
# who genuinely want an unauthenticated endpoint must set
# ``acknowledge_unauthenticated=True`` to opt in explicitly.
auth_mode: str = "bearer_token"
webhook_secret: str | None = None
# Explicit opt-in required for ``auth_mode="none"``. Without this flag
# an unauthenticated webhook is rejected at validation time so a
# mis-clicked dropdown can't expose the bridge to arbitrary internet
# traffic.
acknowledge_unauthenticated: bool = False
payload_mappings: list[PayloadMapping] = []
event_type_path: str | None = None
collection_path: str | None = None
store_payloads: bool = True
max_stored_payloads: int = 20 # 1-100
@model_validator(mode="after")
def _check_auth(self) -> "WebhookProviderConfig":
if self.auth_mode == "none" and not self.acknowledge_unauthenticated:
raise ValueError(
"auth_mode='none' creates an open webhook endpoint; set "
"acknowledge_unauthenticated=true to confirm this is intentional"
)
if self.auth_mode in ("bearer_token", "hmac_sha256") and not self.webhook_secret:
# Auto-generate a strong secret if the operator forgot to supply
# one — better than rejecting an otherwise-valid config and far
# better than silently leaving the endpoint open.
self.webhook_secret = secrets.token_urlsafe(32)
return self
class HomeAssistantProviderConfig(BaseModel):
url: str
@@ -291,15 +291,19 @@ async def get_nav_counts(
):
"""Return entity counts for sidebar navigation badges.
Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
for concurrent use within a single session (no asyncio.gather). We
minimise round-trips by combining user + system counts and per-type
target counts into single aggregate queries where possible.
Combines user-owned counts, system-owned shared counts, and per-type
target counts into a single round-trip via a UNION ALL of label + count
rows. SQLAlchemy AsyncSession is single-threaded so we cannot
asyncio.gather; collapsing 16 SELECTs into one is the optimisation.
"""
from sqlalchemy import literal, union_all
counts: dict[str, int] = {}
# --- 1) User-owned entity counts (one query per model) ---
for model, key in [
user_id = user.id
# User-owned counts: one (label, count) per model.
user_models = [
(ServiceProvider, "providers"),
(NotificationTracker, "notification_trackers"),
(TrackingConfig, "tracking_configs"),
@@ -311,40 +315,52 @@ async def get_nav_counts(
(CommandTracker, "command_trackers"),
(CommandConfig, "command_configs"),
(CommandTemplateConfig, "command_template_configs"),
]:
count = (await session.exec(
select(func.count()).select_from(model).where(model.user_id == user.id)
)).one()
counts[key] = count
# --- 2) Add system-owned counts (user_id=0) for shared entities ---
for model, key in [
]
# System-owned shared counts (user_id=0) folded back into the same key.
system_models = [
(TemplateConfig, "template_configs"),
(CommandTemplateConfig, "command_template_configs"),
(TrackingConfig, "tracking_configs"),
(CommandConfig, "command_configs"),
]:
system_count = (await session.exec(
select(func.count()).select_from(model).where(model.user_id == 0)
)).one()
counts[key] += system_count
# --- 3) Per-type target counts in a single query using conditional aggregation ---
]
target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
type_counts_result = (await session.exec(
select(
NotificationTarget.type,
func.count(),
# Initialise counts to 0 so missing UNION rows surface as zeroes
# instead of KeyErrors when a category has no rows.
for _model, key in user_models:
counts[key] = 0
for ttype in target_types:
counts[f"targets_{ttype}"] = 0
queries = []
for model, key in user_models:
queries.append(
select(literal(key).label("k"), func.count().label("c"))
.select_from(model).where(model.user_id == user_id)
)
.where(
NotificationTarget.user_id == user.id,
NotificationTarget.type.in_(target_types),
for model, key in system_models:
queries.append(
select(literal(f"__sys__:{key}").label("k"), func.count().label("c"))
.select_from(model).where(model.user_id == 0)
)
.group_by(NotificationTarget.type)
)).all()
type_counts_map = dict(type_counts_result)
for target_type in target_types:
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
for ttype in target_types:
queries.append(
select(literal(f"target:{ttype}").label("k"), func.count().label("c"))
.select_from(NotificationTarget).where(
NotificationTarget.user_id == user_id,
NotificationTarget.type == ttype,
)
)
union_q = union_all(*queries)
rows = (await session.execute(union_q)).all()
for label, value in rows:
if label.startswith("__sys__:"):
counts[label.removeprefix("__sys__:")] += int(value or 0)
elif label.startswith("target:"):
counts[f"targets_{label.removeprefix('target:')}"] = int(value or 0)
else:
counts[label] = int(value or 0)
return counts
@@ -287,6 +287,8 @@ async def get_template_variables(
**_nut_variables(),
# --- Home Assistant slots ---
**_home_assistant_variables(),
# --- Bridge self-monitoring slots ---
**_bridge_self_variables(),
# --- Scheduler slots ---
"message_scheduled_message": {
"description": "Notification for scheduled message events",
@@ -487,6 +489,32 @@ def _home_assistant_variables() -> dict:
}
def _bridge_self_variables() -> dict:
common = {
"failure_type": "Which condition fired (poll_failures, deferred_backlog, target_failures)",
"subject_id": "Affected entity ID (tracker_id, target_id, or 0 for backlog)",
"subject_name": "Human-readable name of the affected entity",
"count": "Consecutive failure count or current backlog size",
"threshold": "Configured threshold that was crossed",
"last_error": "Last underlying error message (truncated)",
"details": "Extra structured context dict (use {{ details | tojson }})",
}
return {
"message_bridge_self_poll_failures": {
"description": "Tracker poll failures crossed threshold",
"variables": common,
},
"message_bridge_self_deferred_backlog": {
"description": "Deferred dispatch backlog crossed threshold",
"variables": common,
},
"message_bridge_self_target_failures": {
"description": "Target send failures crossed threshold",
"variables": common,
},
}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_config(
body: TemplateConfigCreate,
@@ -64,9 +64,19 @@ async def create_user(
admin: User = Depends(require_admin),
session: AsyncSession = Depends(get_session),
):
"""Create a new user (admin only)."""
"""Create a new user (admin only).
Username is normalised to ``strip().lower()`` so "Admin" and "admin"
cannot coexist. We do not add a CHECK constraint at the DB level that
would require rebuilding the table on SQLite so the application is
the single source of truth for normalisation.
"""
# Normalise so case-only variants collide with existing accounts.
username = (body.username or "").strip().lower()
if not username:
raise HTTPException(status_code=400, detail="Username cannot be empty")
# Check for duplicate username
result = await session.exec(select(User).where(User.username == body.username))
result = await session.exec(select(User).where(User.username == username))
if result.first():
raise HTTPException(status_code=409, detail="Username already exists")
@@ -74,13 +84,25 @@ async def create_user(
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
user = User(
username=body.username,
username=username,
hashed_password=await _hash_password(body.password),
role=body.role if body.role in ("admin", "user") else "user",
)
session.add(user)
await session.commit()
await session.refresh(user)
# Auto-create the bridge_self provider so the new user immediately gets
# internal-failure notifications without manual setup. Best-effort —
# a seeding hiccup must not fail the user creation itself.
try:
from ..database.seeds import ensure_bridge_self_provider_for_user
await ensure_bridge_self_provider_for_user(session, user.id)
await session.commit()
except Exception: # noqa: BLE001
_LOGGER.exception("Failed to auto-seed bridge_self provider for user %s", user.id)
await session.rollback()
return {"id": user.id, "username": user.username, "role": user.role}
@@ -103,14 +125,19 @@ async def update_user(
identity_changed = False
if body.username is not None and body.username != user.username:
new_username = body.username.strip()
# Normalise to match the case-insensitive uniqueness rule applied
# at user creation. Comparing the normalised form against the
# stored username also avoids false-positive "no change" when a
# legacy mixed-case account is being renamed to its lower form.
new_username = (body.username or "").strip().lower()
if not new_username:
raise HTTPException(status_code=400, detail="Username cannot be empty")
dup = await session.exec(select(User).where(User.username == new_username))
if dup.first():
raise HTTPException(status_code=409, detail="Username already exists")
user.username = new_username
identity_changed = True
if new_username != user.username:
dup = await session.exec(select(User).where(User.username == new_username))
if dup.first():
raise HTTPException(status_code=409, detail="Username already exists")
user.username = new_username
identity_changed = True
if body.role is not None and body.role != user.role:
if body.role not in ("admin", "user"):
@@ -191,11 +218,139 @@ async def delete_user(
admin: User = Depends(require_admin),
session: AsyncSession = Depends(get_session),
):
"""Delete a user (admin only, cannot delete self)."""
"""Delete a user (admin only, cannot delete self).
Cascades through every user-owned table by hand. The model declares
``ondelete=CASCADE`` on each FK, but SQLite only enforces FK actions
on tables created *after* the ondelete clause was added existing
installs upgraded from older schemas need this Python-side cascade
instead of a multi-step table rebuild.
TODO: drop this manual cascade once we ship a real
rebuild-with-FK-actions migration for legacy SQLite installs (or
once Postgres becomes the default deployment target).
"""
from sqlalchemy import delete as sa_delete, update as sa_update
if user_id == admin.id:
raise HTTPException(status_code=400, detail="Cannot delete yourself")
user = await session.get(User, user_id)
if not user:
raise HTTPException(status_code=404, detail="User not found")
await session.delete(user)
await session.commit()
# Lazy import to avoid circulars.
from ..database.models import (
Action,
ActionExecution,
ActionRule,
CommandConfig,
CommandTracker,
CommandTrackerListener,
DeferredDispatch,
EventLog,
NotificationTarget,
NotificationTracker,
NotificationTrackerState,
NotificationTrackerTarget,
ServiceProvider,
TelegramBot,
TelegramChat,
TrackingConfig,
EmailBot,
MatrixBot,
)
# Wrap the entire cascade in one transaction so a failure mid-way
# cannot leave dangling child rows pointing at a missing user.
try:
# Order: leaves first, then their parents, finally the user. This
# matters even with FKs disabled — it's the natural dependency
# graph and avoids accidental constraint trips on engines that do
# enforce FKs (Postgres).
# Resolve tracker ids first (needed for state + link cleanup
# before the parent rows themselves are deleted further down).
from sqlmodel import select as _select
tracker_ids = list((await session.exec(
_select(NotificationTracker.id).where(NotificationTracker.user_id == user_id)
)).all())
if tracker_ids:
await session.execute(
sa_delete(NotificationTrackerState).where(
NotificationTrackerState.tracker_id.in_(tracker_ids)
)
)
await session.execute(
sa_delete(NotificationTrackerTarget).where(
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
)
)
await session.execute(
sa_delete(DeferredDispatch).where(
DeferredDispatch.tracker_id.in_(tracker_ids)
)
)
# Action children: rules and execution log.
action_ids = list((await session.exec(
_select(Action.id).where(Action.user_id == user_id)
)).all())
if action_ids:
await session.execute(
sa_delete(ActionRule).where(ActionRule.action_id.in_(action_ids))
)
await session.execute(
sa_delete(ActionExecution).where(
ActionExecution.action_id.in_(action_ids)
)
)
# Command tracker children: listeners.
cmd_tracker_ids = list((await session.exec(
_select(CommandTracker.id).where(CommandTracker.user_id == user_id)
)).all())
if cmd_tracker_ids:
await session.execute(
sa_delete(CommandTrackerListener).where(
CommandTrackerListener.command_tracker_id.in_(cmd_tracker_ids)
)
)
# Telegram bot children: chats.
bot_ids = list((await session.exec(
_select(TelegramBot.id).where(TelegramBot.user_id == user_id)
)).all())
if bot_ids:
await session.execute(
sa_delete(TelegramChat).where(TelegramChat.bot_id.in_(bot_ids))
)
# Owned top-level entities (user is a direct owner).
for model in (
NotificationTracker,
NotificationTarget,
CommandTracker,
CommandConfig,
TrackingConfig,
Action,
TelegramBot,
EmailBot,
MatrixBot,
ServiceProvider,
):
await session.execute(
sa_delete(model).where(model.user_id == user_id)
)
# EventLog: keep the audit trail but null the owner reference so
# the rows survive the user delete (matches the SET NULL semantic
# declared on the model).
await session.execute(
sa_update(EventLog).where(EventLog.user_id == user_id).values(user_id=None)
)
await session.delete(user)
await session.commit()
except Exception:
await session.rollback()
raise
@@ -12,6 +12,8 @@ from fastapi import APIRouter, HTTPException, Request
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.routes import limiter
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook
from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook
@@ -240,6 +242,10 @@ async def planka_webhook(token: str, request: Request):
if not _verify_planka_token(webhook_secret, request):
raise HTTPException(status_code=403, detail="Invalid token")
# Read body AFTER auth check so an attacker without the bearer token
# can't force an unbounded read. Token is in the header, not the body.
raw_body = await _read_bounded_body(request)
# Parse payload from the bounded raw_body we already read.
try:
payload = json.loads(raw_body.decode("utf-8"))
@@ -320,6 +326,8 @@ def _verify_generic_webhook_auth(
_SENSITIVE_HEADER_SUBSTR = (
"token", "auth", "key", "secret", "signature", "password", "credential",
"cookie", "x-api", "x-hub-signature",
# Extended for per-key body redaction; harmless extras for header check.
"oauth", "client_secret", "webhook_secret", "csrf",
)
@@ -328,6 +336,28 @@ def _is_sensitive_header(name: str) -> bool:
return any(s in n for s in _SENSITIVE_HEADER_SUBSTR)
_REDACTED_PLACEHOLDER = "[REDACTED]"
def _redact_sensitive_body(value: object) -> object:
"""Walk a parsed JSON body and redact values for sensitive-named keys.
Returns a defensively-copied structure so the caller's object is
never mutated (callers downstream still consume the original).
"""
if isinstance(value, dict):
cleaned: dict[str, object] = {}
for k, v in value.items():
if isinstance(k, str) and _is_sensitive_header(k):
cleaned[k] = _REDACTED_PLACEHOLDER
else:
cleaned[k] = _redact_sensitive_body(v)
return cleaned
if isinstance(value, list):
return [_redact_sensitive_body(v) for v in value]
return value
def _filter_headers(raw_headers: dict[str, str]) -> dict[str, str]:
"""Keep only safe headers for logging (strip Authorization, signatures, tokens).
@@ -358,11 +388,15 @@ async def _save_webhook_log(
"""Insert a webhook payload log entry and prune old ones."""
try:
body_json = body if isinstance(body, dict) else {}
# Strip sensitive values before persistence — webhook payloads
# routinely include OAuth tokens / secrets in the body, and the
# log is admin-readable but not need-to-know for the operator.
safe_body = _redact_sensitive_body(body_json) if body_json else {}
session.add(WebhookPayloadLog(
provider_id=provider_id,
method=method,
headers=headers,
body=body_json,
body=safe_body,
status=status,
extracted_fields=extracted_fields or {},
error_message=error_message,
@@ -386,13 +420,19 @@ async def _save_webhook_log(
_LOGGER.warning("Failed to save webhook payload log for provider %d", provider_id, exc_info=True)
try:
await session.rollback()
except Exception:
pass
except Exception: # noqa: BLE001
_LOGGER.exception("Rollback after payload-log save failed")
@router.post("/webhook/{token}")
@limiter.limit("60/minute")
async def generic_webhook(token: str, request: Request):
"""Receive a generic webhook, extract variables via JSONPath, and dispatch notifications."""
"""Receive a generic webhook, extract variables via JSONPath, and dispatch notifications.
Per-IP rate limit (60/min) caps blast radius from a single source
legitimate providers send well below this; anything higher is either
a misconfigured retry loop or abuse.
"""
engine = get_engine()
# --- Load provider and validate auth ---
@@ -50,7 +50,12 @@ class RefreshRequest(BaseModel):
async def _hash_password(password: str) -> str:
"""bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop."""
"""bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop.
Caller is responsible for length-validating ``password`` against the
72-byte bcrypt cap before calling bcrypt silently truncates beyond
that, which is a correctness footgun, not a security one.
"""
def _work() -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
@@ -58,6 +63,24 @@ async def _hash_password(password: str) -> str:
return await asyncio.to_thread(_work)
# bcrypt's algorithm cap — the underlying primitive truncates input
# beyond this so two distinct passwords sharing a 72-byte prefix would
# verify identically. We reject up-front with a clear 422 message.
_BCRYPT_MAX_PASSWORD_BYTES = 72
def _check_bcrypt_length(password: str) -> None:
if len(password.encode("utf-8")) > _BCRYPT_MAX_PASSWORD_BYTES:
raise HTTPException(
status_code=422,
detail=(
f"Password too long; bcrypt limit is "
f"{_BCRYPT_MAX_PASSWORD_BYTES} bytes (longer passwords would "
"be silently truncated)"
),
)
async def _verify_password(password: str, hashed: str) -> bool:
def _work() -> bool:
try:
@@ -74,6 +97,7 @@ async def _verify_password(password: str, hashed: str) -> bool:
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
if len(body.password) < 8:
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
_check_bcrypt_length(body.password)
# Compute hash BEFORE opening the transaction so we don't hold a writer lock
# during the CPU-bound bcrypt work.
hashed = await _hash_password(body.password)
@@ -97,6 +121,16 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
session.add(user)
await session.refresh(user)
# Auto-create the bridge_self provider for the new admin so internal-
# failure notifications work out of the box. Best-effort — a seeding
# failure should not abort setup.
try:
from ..database.seeds import ensure_bridge_self_provider_for_user
await ensure_bridge_self_provider_for_user(session, user.id)
await session.commit()
except Exception: # noqa: BLE001
await session.rollback()
return TokenResponse(
access_token=create_access_token(user.id, user.role, user.token_version),
refresh_token=create_refresh_token(user.id, user.token_version),
@@ -170,6 +204,7 @@ async def change_password(
raise HTTPException(status_code=400, detail="Current password is incorrect")
if len(body.new_password) < 8:
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
_check_bcrypt_length(body.new_password)
user.hashed_password = await _hash_password(body.new_password)
user.token_version = (user.token_version or 1) + 1
session.add(user)
@@ -43,10 +43,21 @@ async def telegram_webhook(
session: AsyncSession = Depends(get_session),
):
"""Handle incoming Telegram messages — route commands to handlers."""
# Validate webhook secret if configured
if _webhook_secret:
if not hmac.compare_digest(x_telegram_bot_api_secret_token or "", _webhook_secret):
raise HTTPException(status_code=403, detail="Invalid webhook secret")
# Telegram webhook secret is MANDATORY: without it any peer that knows
# the opaque webhook URL could inject arbitrary updates as if Telegram
# had sent them. Refuse to handle updates if no secret is configured.
if not _webhook_secret:
_LOGGER.error(
"Refusing Telegram webhook update for %s — webhook secret not configured "
"(set NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET)",
webhook_id,
)
raise HTTPException(
status_code=401,
detail="Telegram webhook secret not configured on this server",
)
if not hmac.compare_digest(x_telegram_bot_api_secret_token or "", _webhook_secret):
raise HTTPException(status_code=403, detail="Invalid webhook secret")
# Find bot by opaque webhook path ID (not by token — token must not appear in URLs)
bot_result = await session.exec(
@@ -161,7 +172,17 @@ async def telegram_webhook(
async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
"""Register webhook URL with Telegram Bot API via TelegramClient."""
"""Register webhook URL with Telegram Bot API via TelegramClient.
Refuses to register without a secret: a webhook without a secret
accepts any unauthenticated POST as a valid Telegram update, so we
never want one in production.
"""
if not secret:
raise ValueError(
"Telegram webhook registration requires a secret token "
"(set NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET)"
)
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
@@ -76,6 +76,13 @@ class Settings(BaseSettings):
before migrations run using SQLite's ``VACUUM INTO`` (atomic, consistent).
"""
metrics_enabled: bool = True
"""Expose the Prometheus ``/api/metrics`` endpoint. Disable on hardened
deployments where the API port is exposed beyond the trust boundary
metrics are unauthenticated and can leak operational information about
queue depth, dispatch rates, and provider failures.
"""
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
def model_post_init(self, __context: Any) -> None:
@@ -309,6 +309,22 @@ async def migrate_schema(engine: AsyncEngine) -> None:
)
logger.info("Added %s column to tracking_config table", col_name)
# Add Bridge self-monitoring tracking flags to tracking_config if missing.
# All three default ON — the bridge_self provider exists specifically
# to surface these conditions, so silencing one would defeat the point.
if await _has_table(conn, "tracking_config"):
bridge_self_flags = [
("track_bridge_self_poll_failures", "INTEGER DEFAULT 1"),
("track_bridge_self_deferred_backlog", "INTEGER DEFAULT 1"),
("track_bridge_self_target_failures", "INTEGER DEFAULT 1"),
]
for col_name, col_type in bridge_self_flags:
if not await _has_column(conn, "tracking_config", col_name):
await conn.execute(
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} {col_type}")
)
logger.info("Added %s column to tracking_config table", col_name)
# Add quiet hours to tracking_config if missing.
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
if await _has_table(conn, "tracking_config"):
@@ -1361,6 +1377,12 @@ _INDEXES: list[tuple[str, str, str]] = [
("ix_action_provider_id", "action", "provider_id"),
# Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
# Dashboard "events of type X for me, recent first" filter.
(
"ix_event_log_user_event_type_created",
"event_log",
"user_id, event_type, created_at DESC",
),
("ix_event_log_provider_id", "event_log", "provider_id"),
("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
("ix_event_log_action_id", "event_log", "action_id"),
@@ -1543,6 +1565,269 @@ async def migrate_chat_action_to_column(engine: AsyncEngine) -> None:
logger.info("Migrated chat_action from config JSON to column where present")
# ---------------------------------------------------------------------------
# Uniqueness + dedupe migrations for webhook hot paths.
#
# These backfill missing UNIQUE indexes on webhook tokens, webhook path IDs,
# bot_id (with sentinel guard), (bot_id, chat_id), and tracker-target links.
# Every CREATE UNIQUE INDEX is preceded by a dedupe pass that keeps the
# canonical row (lowest id, or oldest created_at where specified) and removes
# the rest, logging a WARNING with the dropped count so operators can audit.
# ---------------------------------------------------------------------------
async def _dedupe_by_columns(
conn,
table: str,
cols: list[str],
*,
keep: str = "min_id",
label: str = "",
) -> int:
"""Delete duplicate rows leaving one survivor per ``cols`` group.
``keep`` chooses the survivor:
- ``"min_id"`` keeps the row with the lowest ``id`` (default used
when there is no semantic "first" row to preserve).
- ``"min_created_at"`` keeps the row with the oldest ``created_at``,
falling back to the lowest id on ties preferred for tracker-target
links so the original link wins.
Returns the number of rows deleted. All identifiers flow through
``_assert_ident`` to neutralise SQL injection from any caller mistake.
"""
_assert_ident(table, "table")
for c in cols:
_assert_ident(c, "column")
group_by = ", ".join(cols)
where_cols = " AND ".join(f"{c} = g.{c}" for c in cols)
if keep == "min_created_at":
# Tie-break on id so the survivor is deterministic even if two rows
# share the same created_at (insert-batches commonly do).
survivor_sql = (
f"SELECT id FROM {table} "
f"WHERE {where_cols} "
f"ORDER BY created_at ASC, id ASC LIMIT 1"
)
elif keep == "min_id":
survivor_sql = f"SELECT MIN(id) FROM {table} WHERE {where_cols}"
else:
raise ValueError(f"Unknown keep strategy: {keep!r}")
delete_sql = (
f"DELETE FROM {table} WHERE id IN ("
f" SELECT t.id FROM {table} t "
f" JOIN ("
f" SELECT {group_by} FROM {table} "
f" GROUP BY {group_by} HAVING COUNT(*) > 1"
f" ) g ON {' AND '.join(f't.{c} = g.{c}' for c in cols)} "
f" WHERE t.id NOT IN ({survivor_sql})"
f")"
)
result = await conn.execute(text(delete_sql))
deleted = int(getattr(result, "rowcount", 0) or 0)
if deleted:
logger.warning(
"Removed %d duplicate row(s) from %s on (%s)%s",
deleted, table, ", ".join(cols),
f"{label}" if label else "",
)
return deleted
async def migrate_uniqueness_constraints(engine: AsyncEngine) -> None:
"""Backfill missing UNIQUE indexes on webhook hot paths.
SQLite cannot ALTER an existing column to add a UNIQUE constraint, but
a UNIQUE INDEX is functionally equivalent and can be created with
``IF NOT EXISTS`` on every boot. Each index is preceded by a dedupe
pass so the index creation does not fail on existing duplicates.
Indexes added:
- service_provider.webhook_token (full unique)
- telegram_bot.webhook_path_id (full unique)
- telegram_bot.bot_id (partial unique WHERE bot_id != 0; 0 is a
sentinel meaning "not yet validated")
- telegram_chat (bot_id, chat_id) (full unique composite)
- notification_tracker_target (notification_tracker_id, target_id)
(full unique composite)
"""
# Skip on non-SQLite engines — they enforce UNIQUE via the model
# metadata (create_all) and don't have sqlite_master introspection.
if not str(engine.url).startswith("sqlite"):
return
async with engine.begin() as conn:
# service_provider.webhook_token
if await _has_table(conn, "service_provider") and await _has_column(
conn, "service_provider", "webhook_token",
):
await _dedupe_by_columns(
conn, "service_provider", ["webhook_token"],
keep="min_id", label="webhook_token uniqueness",
)
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS "
"uq_service_provider_webhook_token "
"ON service_provider(webhook_token)"
))
# telegram_bot.webhook_path_id (full unique)
# telegram_bot.bot_id (partial unique excluding sentinel 0)
if await _has_table(conn, "telegram_bot"):
if await _has_column(conn, "telegram_bot", "webhook_path_id"):
await _dedupe_by_columns(
conn, "telegram_bot", ["webhook_path_id"],
keep="min_id", label="webhook_path_id uniqueness",
)
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS "
"uq_telegram_bot_webhook_path_id "
"ON telegram_bot(webhook_path_id)"
))
if await _has_column(conn, "telegram_bot", "bot_id"):
# Dedupe only non-sentinel rows. Two unverified bots both
# carrying bot_id=0 is legitimate — only collisions among
# validated bot_ids signal a real corruption to clean up.
deleted = await conn.execute(text(
"DELETE FROM telegram_bot WHERE id IN ("
" SELECT t.id FROM telegram_bot t "
" JOIN ("
" SELECT bot_id FROM telegram_bot "
" WHERE bot_id != 0 GROUP BY bot_id HAVING COUNT(*) > 1"
" ) g ON t.bot_id = g.bot_id "
" WHERE t.id NOT IN ("
" SELECT MIN(id) FROM telegram_bot WHERE bot_id = g.bot_id"
" )"
")"
))
rc = int(getattr(deleted, "rowcount", 0) or 0)
if rc:
logger.warning(
"Removed %d duplicate telegram_bot row(s) on bot_id "
"(non-sentinel collisions)", rc,
)
# Plain INDEX for the lookup-by-bot_id path.
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS ix_telegram_bot_bot_id "
"ON telegram_bot(bot_id)"
))
# Partial UNIQUE excluding the sentinel.
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS "
"uq_telegram_bot_bot_id_nonzero "
"ON telegram_bot(bot_id) WHERE bot_id != 0"
))
# telegram_chat (bot_id, chat_id) — keep the survivor with the oldest
# discovered_at so the original discovery row wins. _dedupe_by_columns
# only handles created_at; do this one inline.
if await _has_table(conn, "telegram_chat"):
res = await conn.execute(text(
"DELETE FROM telegram_chat WHERE id IN ("
" SELECT t.id FROM telegram_chat t "
" JOIN ("
" SELECT bot_id, chat_id FROM telegram_chat "
" GROUP BY bot_id, chat_id HAVING COUNT(*) > 1"
" ) g ON t.bot_id = g.bot_id AND t.chat_id = g.chat_id "
" WHERE t.id NOT IN ("
" SELECT id FROM telegram_chat "
" WHERE bot_id = g.bot_id AND chat_id = g.chat_id "
" ORDER BY discovered_at ASC, id ASC LIMIT 1"
" )"
")"
))
rc = int(getattr(res, "rowcount", 0) or 0)
if rc:
logger.warning(
"Removed %d duplicate telegram_chat row(s) on (bot_id, chat_id)",
rc,
)
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS uq_telegram_chat_bot_chat "
"ON telegram_chat(bot_id, chat_id)"
))
await conn.execute(text(
"CREATE INDEX IF NOT EXISTS ix_telegram_chat_bot_chat "
"ON telegram_chat(bot_id, chat_id)"
))
# notification_tracker_target (notification_tracker_id, target_id)
# — keep the oldest created_at link so the original wins.
if await _has_table(conn, "notification_tracker_target") and await _has_column(
conn, "notification_tracker_target", "notification_tracker_id",
):
await _dedupe_by_columns(
conn,
"notification_tracker_target",
["notification_tracker_id", "target_id"],
keep="min_created_at",
label="tracker-target link uniqueness",
)
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS uq_ntt_tracker_target "
"ON notification_tracker_target(notification_tracker_id, target_id)"
))
# service_provider partial unique on (user_id) WHERE type='bridge_self'.
# Bridge-self is special: exactly one row per user, auto-seeded at boot,
# at user-create, and on /setup. Without this guard, a concurrent boot
# backfill + POST /api/users could double-insert. Dedupe keeps the
# oldest row so any user-customised thresholds on it survive.
if await _has_table(conn, "service_provider"):
res = await conn.execute(text(
"DELETE FROM service_provider WHERE id IN ("
" SELECT t.id FROM service_provider t "
" JOIN ("
" SELECT user_id FROM service_provider "
" WHERE type='bridge_self' GROUP BY user_id HAVING COUNT(*) > 1"
" ) g ON t.user_id = g.user_id "
" WHERE t.type='bridge_self' AND t.id NOT IN ("
" SELECT MIN(id) FROM service_provider "
" WHERE type='bridge_self' AND user_id = g.user_id"
" )"
")"
))
rc = int(getattr(res, "rowcount", 0) or 0)
if rc:
logger.warning(
"Removed %d duplicate bridge_self service_provider row(s) "
"on user_id", rc,
)
await conn.execute(text(
"CREATE UNIQUE INDEX IF NOT EXISTS "
"uq_service_provider_bridge_self_per_user "
"ON service_provider(user_id) WHERE type='bridge_self'"
))
async def migrate_eventlog_provider_fk(engine: AsyncEngine) -> None:
"""Document the EventLog.provider_id FK situation.
SQLite cannot ALTER a column to add a foreign-key constraint without
rebuilding the table. The model annotation now declares
``ondelete=SET NULL`` which only takes effect on freshly created
tables (i.e. brand-new installs). For existing installs we rely on
application-side cleanup in ``api/providers.delete_provider`` to NULL
out ``event_log.provider_id`` rows before deleting the provider row.
This migration is intentionally a no-op aside from the log line it
exists so the migration order is explicit and operators see in the
logs that the FK strategy was reviewed on this boot.
"""
if not str(engine.url).startswith("sqlite"):
return
async with engine.begin() as conn:
if not await _has_table(conn, "event_log"):
return
# No DDL change. Application code in api/providers.delete_provider
# is the source of truth for the SET NULL semantic on existing tables.
logger.debug(
"event_log.provider_id FK enforcement deferred to application "
"code on existing SQLite tables (model declares ondelete=SET NULL "
"which applies to fresh schemas only)."
)
# ---------------------------------------------------------------------------
# Schema version tracking — lightweight alternative to Alembic while the
# hand-rolled idempotent migrations remain the source of truth. Gives
@@ -6,7 +6,7 @@ from datetime import datetime, timezone
from typing import Any
from uuid import uuid4
from sqlalchemy import ForeignKey, UniqueConstraint, Text
from sqlalchemy import ForeignKey, Index, UniqueConstraint, Text
from sqlmodel import JSON, Column, Field, SQLModel
@@ -29,12 +29,25 @@ class ServiceProvider(SQLModel, table=True):
__tablename__ = "service_provider"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
type: str # ServiceProviderType value ("immich")
name: str
icon: str = Field(default="")
config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
webhook_token: str = Field(default_factory=lambda: uuid4().hex)
# Webhook token is the shared secret embedded in inbound webhook URLs.
# Must be unique so a token uniquely identifies a provider; indexed so
# the webhook router does an O(log n) lookup on every inbound request.
webhook_token: str = Field(
default_factory=lambda: uuid4().hex,
unique=True,
index=True,
)
created_at: datetime = Field(default_factory=_utcnow)
@@ -42,13 +55,29 @@ class TelegramBot(SQLModel, table=True):
__tablename__ = "telegram_bot"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
name: str
token: str
icon: str = Field(default="")
bot_username: str = Field(default="")
bot_id: int = Field(default=0)
webhook_path_id: str = Field(default_factory=lambda: uuid4().hex)
# bot_id=0 is a sentinel meaning "Telegram has not yet returned a numeric
# ID for this bot" (i.e. token never validated). Multiple unverified bots
# may legitimately carry 0, so we only enforce uniqueness for non-sentinel
# values via a partial index added in migrate_uniqueness_constraints.
bot_id: int = Field(default=0, index=True)
# URL-path embedded in Telegram's setWebhook callback URL. Must be unique
# so the inbound dispatcher resolves a single bot per incoming request.
webhook_path_id: str = Field(
default_factory=lambda: uuid4().hex,
unique=True,
index=True,
)
update_mode: str = Field(default="none") # "none", "polling", or "webhook"
# NOTE: commands_config column remains in the DB for backward compat,
# but is no longer part of the SQLModel class. Data migrated to CommandConfig.
@@ -61,7 +90,13 @@ class MatrixBot(SQLModel, table=True):
__tablename__ = "matrix_bot"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
name: str
icon: str = Field(default="")
homeserver_url: str # e.g. https://matrix.org
@@ -76,7 +111,13 @@ class EmailBot(SQLModel, table=True):
__tablename__ = "email_bot"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
name: str
icon: str = Field(default="")
email: str # From address
@@ -90,6 +131,13 @@ class EmailBot(SQLModel, table=True):
class TelegramChat(SQLModel, table=True):
__tablename__ = "telegram_chat"
# (bot_id, chat_id) uniquely identifies a chat. The composite index is
# the access pattern for save_chat_from_webhook ON CONFLICT updates and
# for any "lookup by (bot, chat)" callers.
__table_args__ = (
UniqueConstraint("bot_id", "chat_id", name="uq_telegram_chat_bot_chat"),
Index("ix_telegram_chat_bot_chat", "bot_id", "chat_id"),
)
id: int | None = Field(default=None, primary_key=True)
bot_id: int = Field(foreign_key="telegram_bot.id")
@@ -109,7 +157,13 @@ class TrackingConfig(SQLModel, table=True):
__tablename__ = "tracking_config"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
provider_type: str # Must match provider's type
name: str
icon: str = Field(default="")
@@ -171,6 +225,12 @@ class TrackingConfig(SQLModel, table=True):
track_ha_service_called: bool = Field(default=False)
track_ha_event_fired: bool = Field(default=False)
# Bridge self-monitoring event tracking — defaults ON because the whole
# point of the provider is to alert on these conditions.
track_bridge_self_poll_failures: bool = Field(default=True)
track_bridge_self_deferred_backlog: bool = Field(default=True)
track_bridge_self_target_failures: bool = Field(default=True)
# Immich asset display
track_images: bool = Field(default=True)
track_videos: bool = Field(default=True)
@@ -276,7 +336,13 @@ class NotificationTarget(SQLModel, table=True):
__tablename__ = "notification_target"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
type: str # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
name: str
icon: str = Field(default="")
@@ -319,7 +385,13 @@ class NotificationTracker(SQLModel, table=True):
__tablename__ = "notification_tracker"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
provider_id: int = Field(foreign_key="service_provider.id")
name: str
icon: str = Field(default="")
@@ -342,6 +414,15 @@ class NotificationTrackerTarget(SQLModel, table=True):
"""Junction between NotificationTracker and NotificationTarget with per-link config."""
__tablename__ = "notification_tracker_target"
# A tracker should never link to the same target twice — duplicate links
# would deliver the same notification multiple times. Enforced at the DB
# level so concurrent inserts can't bypass an application-side check.
__table_args__ = (
UniqueConstraint(
"notification_tracker_id", "target_id",
name="uq_ntt_tracker_target",
),
)
id: int | None = Field(default=None, primary_key=True)
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
@@ -403,7 +484,13 @@ class CommandConfig(SQLModel, table=True):
__tablename__ = "command_config"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
provider_type: str
name: str
icon: str = Field(default="")
@@ -464,7 +551,13 @@ class CommandTracker(SQLModel, table=True):
__tablename__ = "command_tracker"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
provider_id: int = Field(foreign_key="service_provider.id")
command_config_id: int = Field(foreign_key="command_config.id")
name: str
@@ -517,7 +610,15 @@ class DeferredDispatch(SQLModel, table=True):
__tablename__ = "deferred_dispatch"
id: int | None = Field(default=None, primary_key=True)
user_id: int | None = Field(default=None, foreign_key="user.id", index=True)
user_id: int | None = Field(
default=None,
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
index=True,
),
)
tracker_id: int = Field(foreign_key="notification_tracker.id", index=True)
# The specific link this deferral targets. On drain we re-fetch by ID; if
# the link was disabled or removed in the meantime we drop with a
@@ -566,8 +667,17 @@ class EventLog(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
# Owner. Indexed for the dashboard events query. Nullable only because
# historical rows (pre-user_id column) may have no owner; new rows always
# set this directly.
user_id: int | None = Field(default=None, foreign_key="user.id", index=True)
# set this directly. SET NULL on user delete preserves the audit trail
# while letting the user record itself be removed.
user_id: int | None = Field(
default=None,
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="SET NULL"),
nullable=True,
index=True,
),
)
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
tracker_id: int | None = Field(
default=None,
@@ -594,7 +704,21 @@ class EventLog(SQLModel, table=True):
default=None, foreign_key="telegram_bot.id", index=True,
)
bot_name: str = Field(default="")
provider_id: int | None = Field(default=None, index=True)
# FK to service_provider with SET NULL so deleting a provider leaves
# historical event_log rows intact (provider_name preserves the label
# for display). The FK only takes effect on freshly created tables —
# SQLite cannot ALTER a constraint into an existing table without a
# rebuild, so application code in api/providers.delete_provider also
# nulls these explicitly. See migrate_eventlog_provider_fk.
provider_id: int | None = Field(
default=None,
sa_column=Column(
"provider_id",
ForeignKey("service_provider.id", ondelete="SET NULL"),
nullable=True,
index=True,
),
)
provider_name: str = Field(default="")
event_type: str = Field(index=True)
collection_id: str
@@ -610,7 +734,13 @@ class Action(SQLModel, table=True):
__tablename__ = "action"
id: int | None = Field(default=None, primary_key=True)
user_id: int = Field(foreign_key="user.id")
user_id: int = Field(
sa_column=Column(
"user_id",
ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
)
provider_id: int = Field(foreign_key="service_provider.id")
name: str
icon: str = Field(default="")
@@ -13,9 +13,11 @@ from .models import (
CommandConfig,
CommandTemplateConfig,
CommandTemplateSlot,
ServiceProvider,
TemplateConfig,
TemplateSlot,
TrackingConfig,
User,
)
_LOGGER = logging.getLogger(__name__)
@@ -159,6 +161,7 @@ async def _seed_default_templates() -> None:
await _seed_provider_template(session, "google_photos", "Google Photos")
await _seed_provider_template(session, "webhook", "Generic Webhook")
await _seed_provider_template(session, "home_assistant", "Home Assistant")
await _seed_provider_template(session, "bridge_self", "Bridge Self-Monitoring")
await session.commit()
@@ -285,6 +288,13 @@ async def _seed_default_tracking_configs() -> None:
"track_ha_service_called": False,
"track_ha_event_fired": False,
},
{
"provider_type": "bridge_self",
"name": "Default Bridge Self-Monitoring",
"track_bridge_self_poll_failures": True,
"track_bridge_self_deferred_backlog": True,
"track_bridge_self_target_failures": True,
},
]
for cfg in defaults:
@@ -403,6 +413,67 @@ async def _seed_default_command_configs() -> None:
await session.commit()
# ---------------------------------------------------------------------------
# Bridge self-monitoring per-user provider
# ---------------------------------------------------------------------------
# Default thresholds — duplicated here as constants instead of imported so the
# seed module stays self-contained and import-cycle-free during boot.
_BRIDGE_SELF_DEFAULT_CONFIG = {
"poll_failure_threshold": 3,
"deferred_backlog_threshold": 100,
"target_failure_threshold": 5,
}
async def ensure_bridge_self_provider_for_user(
session: AsyncSession, user_id: int,
) -> ServiceProvider | None:
"""Create the user's bridge_self provider if absent. Returns the provider.
The bridge_self provider is special exactly one per user, auto-created
so the operator never has to think about wiring it up. Idempotent.
Skips ``user_id <= 0`` (the ``__system__`` placeholder) which never
receives notifications.
"""
if user_id <= 0:
return None
result = await session.exec(
select(ServiceProvider).where(
ServiceProvider.user_id == user_id,
ServiceProvider.type == "bridge_self",
)
)
existing = result.first()
if existing is not None:
return existing
provider = ServiceProvider(
user_id=user_id,
type="bridge_self",
name="Bridge Self-Monitoring",
config=dict(_BRIDGE_SELF_DEFAULT_CONFIG),
)
session.add(provider)
await session.flush()
return provider
async def _seed_bridge_self_providers_for_existing_users() -> None:
"""Backfill bridge_self provider for every existing real user.
Runs once at boot so deployments upgrading from a pre-bridge_self
release pick up the auto-created provider without requiring user
action. Skips users that already have one.
"""
engine = get_engine()
async with AsyncSession(engine) as session:
users = (await session.exec(select(User).where(User.id != 0))).all()
for user in users:
await ensure_bridge_self_provider_for_user(session, user.id)
await session.commit()
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
@@ -442,3 +513,4 @@ async def seed_all() -> None:
await _seed_default_command_templates()
await _seed_default_tracking_configs()
await _seed_default_command_configs()
await _seed_bridge_self_providers_for_existing_users()
@@ -50,6 +50,7 @@ from .commands.webhook import router as webhook_router, set_webhook_secret
from .api.webhooks import router as webhooks_router
from .api.webhook_logs import router as webhook_logs_router
from .api.backup import router as backup_router
from .api.metrics import router as metrics_router
# Readiness flag — flipped to True once the scheduler has started and the
@@ -78,6 +79,8 @@ async def lifespan(app: FastAPI):
migrate_chat_action_to_column,
migrate_deferred_dispatch_event_log_fk,
migrate_deferred_dispatch_unique_pending,
migrate_uniqueness_constraints,
migrate_eventlog_provider_fk,
migrate_schema_version,
)
from .database.snapshot import snapshot_and_prune
@@ -107,6 +110,13 @@ async def lifespan(app: FastAPI):
# the partial unique index.
await migrate_deferred_dispatch_event_log_fk(engine)
await migrate_deferred_dispatch_unique_pending(engine)
# Backfill missing UNIQUE indexes on webhook hot paths (deduping any
# existing duplicates). Runs after performance_indexes so non-unique
# support indexes are already in place.
await migrate_uniqueness_constraints(engine)
# Document EventLog.provider_id FK strategy on existing tables (no-op
# on SQLite besides the log line; new tables get the FK from create_all).
await migrate_eventlog_provider_fk(engine)
await migrate_schema_version(engine)
from .database.seeds import seed_all
await seed_all()
@@ -254,6 +264,7 @@ app.include_router(webhook_router)
app.include_router(webhooks_router)
app.include_router(webhook_logs_router)
app.include_router(backup_router)
app.include_router(metrics_router)
@app.get("/api/health")
@@ -265,15 +276,107 @@ async def health():
@app.get("/api/ready")
async def ready():
"""Readiness: migrations and scheduler have started, app can serve traffic.
"""Readiness: deep dependency check.
Returns 503 until the lifespan startup sequence has completed. Use this
for orchestrator readiness probes (Docker, Kubernetes).
Verifies each critical dependency is actually reachable, not just that
the app finished its lifespan startup. Returns 503 if any *required*
check fails (db, scheduler). Home Assistant supervisor presence is
informational a degraded HA does not flip readiness off.
Response shape:
{
"ready": bool,
"checks": {"db": "ok|fail", "scheduler": "ok|fail", "ha": "ok|degraded|na"},
"errors": [str, ...]
}
"""
from starlette.responses import JSONResponse
import asyncio as _asyncio
from sqlalchemy import text as _text
checks: dict[str, str] = {}
errors: list[str] = []
if not _READY:
from starlette.responses import JSONResponse
return JSONResponse({"status": "starting"}, status_code=503)
return {"status": "ready", "version": _APP_VERSION}
# Lifespan still running — short-circuit so we don't poke a half-built engine.
return JSONResponse(
{
"ready": False,
"checks": {"db": "fail", "scheduler": "fail", "ha": "na"},
"errors": ["startup not complete"],
"version": _APP_VERSION,
},
status_code=503,
)
# --- DB: SELECT 1 with a 2s timeout ---
try:
from .database.engine import get_engine
engine = get_engine()
async def _ping_db() -> None:
async with engine.connect() as conn:
await conn.execute(_text("SELECT 1"))
await _asyncio.wait_for(_ping_db(), timeout=2.0)
checks["db"] = "ok"
except Exception as exc: # noqa: BLE001
checks["db"] = "fail"
errors.append(f"db: {exc!s}")
# --- Scheduler: APScheduler must be running ---
try:
from .services.scheduler import get_scheduler
scheduler = get_scheduler()
if scheduler.running:
checks["scheduler"] = "ok"
else:
checks["scheduler"] = "fail"
errors.append("scheduler: not running")
except Exception as exc: # noqa: BLE001
checks["scheduler"] = "fail"
errors.append(f"scheduler: {exc!s}")
# --- HA supervisor: informational only ---
# If no HA providers are configured, report "na" (not applicable). If any
# HA providers exist, ensure at least one supervisor task is alive — a
# task being not-yet-connected is fine, we just want it to exist.
try:
from sqlmodel import select as _select
from sqlmodel.ext.asyncio.session import AsyncSession as _AS
from .database.models import ServiceProvider
from .services.ha_subscription import _running_tasks as _ha_tasks
from .database.engine import get_engine as _get_engine_ha
async with _AS(_get_engine_ha()) as _session:
_result = await _session.exec(
_select(ServiceProvider).where(
ServiceProvider.type == "home_assistant",
)
)
ha_providers = _result.all()
if not ha_providers:
checks["ha"] = "na"
else:
alive = [
t for t in _ha_tasks.values() if t is not None and not t.done()
]
checks["ha"] = "ok" if alive else "degraded"
except Exception as exc: # noqa: BLE001
# Never let the HA probe fail readiness — it's informational.
checks["ha"] = "degraded"
errors.append(f"ha: {exc!s}")
required_ok = checks["db"] == "ok" and checks["scheduler"] == "ok"
body = {
"ready": required_ok,
"checks": checks,
"errors": errors,
"version": _APP_VERSION,
}
if not required_ok:
return JSONResponse(body, status_code=503)
return body
# --- Serve frontend static files (production) ---
@@ -667,13 +667,19 @@ async def import_backup(
if name is None:
continue
ctc_id = _map_id(id_map, "command_template_configs", cc.command_template_config_id)
try:
safe_enabled = _sanitize_config(cc.enabled_commands or {})
safe_limits = _sanitize_config(cc.rate_limits or {})
except ValueError as exc:
result.warnings.append(f"Skipped command config '{cc.name}': {exc}")
continue
new_cc = CommandConfig(
user_id=user_id, provider_type=cc.provider_type,
name=name, icon=cc.icon,
enabled_commands=cc.enabled_commands,
enabled_commands=safe_enabled,
response_mode=cc.response_mode,
default_count=cc.default_count,
rate_limits=cc.rate_limits,
rate_limits=safe_limits,
command_template_config_id=ctc_id,
)
session.add(new_cc)
@@ -728,10 +734,16 @@ async def import_backup(
)
if name is None:
continue
try:
safe_filters = _sanitize_config(nt.filters or {})
safe_collection_ids = _sanitize_config(nt.collection_ids or [])
except ValueError as exc:
result.warnings.append(f"Skipped tracker '{nt.name}': {exc}")
continue
new_nt = NotificationTracker(
user_id=user_id, provider_id=provider_id,
name=name, icon=nt.icon, collection_ids=nt.collection_ids,
filters=nt.filters, scan_interval=nt.scan_interval,
name=name, icon=nt.icon, collection_ids=safe_collection_ids,
filters=safe_filters, scan_interval=nt.scan_interval,
default_tracking_config_id=_map_id(id_map, "tracking_configs", nt.default_tracking_config_id),
default_template_config_id=_map_id(id_map, "template_configs", nt.default_template_config_id),
enabled=nt.enabled,
@@ -810,9 +822,14 @@ async def import_backup(
)
if name is None:
continue
try:
safe_a_cfg = _sanitize_config(a.config or {})
except ValueError as exc:
result.warnings.append(f"Skipped action '{a.name}': {exc}")
continue
new_a = Action(
user_id=user_id, provider_id=provider_id, name=name,
icon=a.icon, action_type=a.action_type, config=a.config,
icon=a.icon, action_type=a.action_type, config=safe_a_cfg,
schedule_type=a.schedule_type,
schedule_interval=a.schedule_interval,
schedule_cron=a.schedule_cron, enabled=False, # always import disabled
@@ -820,9 +837,16 @@ async def import_backup(
session.add(new_a)
await session.flush()
for r in a.rules:
try:
safe_r_cfg = _sanitize_config(r.rule_config or {})
except ValueError as exc:
result.warnings.append(
f"Skipped rule '{r.name}' in action '{a.name}': {exc}"
)
continue
session.add(ActionRule(
action_id=new_a.id, name=r.name,
rule_config=r.rule_config, enabled=r.enabled,
rule_config=safe_r_cfg, enabled=r.enabled,
order=r.order,
))
result.created += len(a.rules)
@@ -0,0 +1,432 @@
"""Bridge self-monitoring service helpers.
Three subsystems feed into ``emit_bridge_self_event``:
1. The watcher's poll loop, when consecutive provider polls fail.
2. A periodic scheduler job, when the deferred-dispatch backlog crosses
the configured threshold.
3. The notification dispatcher, when consecutive sends to a single target
fail with 5xx / network errors.
The helper looks up the user's ``bridge_self`` provider, builds a
synthetic :class:`ServiceEvent`, and pushes it through the same
``dispatch_provider_event`` pipeline that every other provider uses.
That keeps templates, quiet hours, deferral, target gating, and event
logging consistent with the rest of the system.
We intentionally avoid raising into the caller's flow — a
self-monitoring failure must never break the subsystem it's monitoring.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.bridge_self import build_event
from notify_bridge_core.providers.bridge_self.provider import (
DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
DEFAULT_POLL_FAILURE_THRESHOLD,
DEFAULT_TARGET_FAILURE_THRESHOLD,
)
from ..database.engine import get_engine
from ..database.models import ServiceProvider, User
_LOGGER = logging.getLogger(__name__)
# Detail keys carried into the EventLog.details JSON column. Mirrors the
# pattern used by the HA subscription and webhook routers for their
# ``dispatch_provider_event`` calls.
BRIDGE_SELF_DETAIL_KEYS: tuple[str, ...] = (
"failure_type", "subject_id", "subject_name",
"count", "threshold", "last_error", "details",
)
async def get_bridge_self_provider(
session: AsyncSession, user_id: int,
) -> ServiceProvider | None:
"""Return the user's bridge_self provider row (or None if absent)."""
result = await session.exec(
select(ServiceProvider).where(
ServiceProvider.user_id == user_id,
ServiceProvider.type == "bridge_self",
)
)
return result.first()
async def get_user_thresholds(user_id: int) -> dict[str, int]:
"""Return the user's bridge_self thresholds, falling back to defaults.
Reads in a short-lived session emission sites should NOT hold a
transaction across this call.
"""
engine = get_engine()
async with AsyncSession(engine) as session:
provider = await get_bridge_self_provider(session, user_id)
if provider is None:
return {
"poll_failure_threshold": DEFAULT_POLL_FAILURE_THRESHOLD,
"deferred_backlog_threshold": DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
"target_failure_threshold": DEFAULT_TARGET_FAILURE_THRESHOLD,
}
cfg = dict(provider.config or {})
def _int(key: str, fallback: int) -> int:
raw = cfg.get(key, fallback)
try:
value = int(raw)
except (TypeError, ValueError):
return fallback
return value if value >= 1 else fallback
return {
"poll_failure_threshold": _int("poll_failure_threshold", DEFAULT_POLL_FAILURE_THRESHOLD),
"deferred_backlog_threshold": _int(
"deferred_backlog_threshold", DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
),
"target_failure_threshold": _int(
"target_failure_threshold", DEFAULT_TARGET_FAILURE_THRESHOLD,
),
}
async def emit_bridge_self_event(
*,
user_id: int,
failure_type: str,
subject_id: int,
subject_name: str,
count: int,
threshold: int,
last_error: str = "",
details: dict[str, Any] | None = None,
timestamp: datetime | None = None,
) -> int:
"""Emit a self-monitoring event for ``user_id``.
Resolves the user's bridge_self provider and dispatches the event via
``dispatch_provider_event``. Returns the number of dispatched
notifications (0 when the user has no bridge_self provider, no
matching trackers, or the event was suppressed by quiet hours / event-
type gating).
Always swallows internal exceptions so the calling subsystem keeps
running self-monitoring must never crash the watcher / scheduler /
dispatcher.
"""
payload = {
"failure_type": failure_type,
"subject_id": subject_id,
"subject_name": subject_name,
"count": count,
"threshold": threshold,
"last_error": last_error,
"details": dict(details or {}),
}
event = build_event(payload, timestamp=timestamp or datetime.now(timezone.utc))
if event is None:
_LOGGER.debug("Skipping malformed bridge_self payload: %s", payload)
return 0
engine = get_engine()
try:
async with AsyncSession(engine) as session:
provider = await get_bridge_self_provider(session, user_id)
if provider is None:
_LOGGER.debug(
"User %s has no bridge_self provider; skipping %s emission",
user_id, failure_type,
)
return 0
provider_id = provider.id
provider_name = provider.name
provider_config = dict(provider.config or {})
# Imported here to avoid a top-level cycle: dispatch_helpers imports
# several models which transitively touch this module's siblings.
from .event_dispatch import dispatch_provider_event
return await dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=provider_config,
event=event,
detail_keys=BRIDGE_SELF_DETAIL_KEYS,
filter_fn=lambda _ev, _filters: True,
)
except Exception: # noqa: BLE001
_LOGGER.exception(
"bridge_self emission failed (user=%s, failure_type=%s)",
user_id, failure_type,
)
return 0
# ---------------------------------------------------------------------------
# Threshold-crossing trackers (in-memory, per-process).
#
# We track consecutive failure counts in module-level dicts keyed by the
# subject id (tracker_id, target_id). On threshold crossing we emit and
# reset the counter so we don't spam — the next emission only happens after
# another full streak of failures.
# ---------------------------------------------------------------------------
# Tracker poll failures (keyed by tracker_id).
_poll_failure_counts: dict[int, int] = {}
_poll_failure_last_error: dict[int, str] = {}
# Target send failures (keyed by target_id).
_target_failure_counts: dict[int, int] = {}
_target_failure_last_error: dict[int, str] = {}
# Last-known backlog state per user (True = above threshold, False = below).
# We only emit on the False -> True transition so a sustained backlog
# triggers exactly one notification per crossing.
_backlog_above_threshold: dict[int, bool] = {}
def record_poll_success(tracker_id: int) -> None:
"""Reset the failure counter for ``tracker_id`` after a successful poll."""
_poll_failure_counts.pop(tracker_id, None)
_poll_failure_last_error.pop(tracker_id, None)
def record_poll_failure(tracker_id: int, error: str = "") -> int:
"""Increment the failure counter for ``tracker_id``; return the new count."""
_poll_failure_counts[tracker_id] = _poll_failure_counts.get(tracker_id, 0) + 1
if error:
_poll_failure_last_error[tracker_id] = error
return _poll_failure_counts[tracker_id]
def reset_poll_counter(tracker_id: int) -> None:
"""Clear the failure counter for ``tracker_id`` without emitting."""
_poll_failure_counts.pop(tracker_id, None)
_poll_failure_last_error.pop(tracker_id, None)
def record_target_success(target_id: int) -> None:
"""Reset the failure counter for ``target_id`` after a successful send."""
_target_failure_counts.pop(target_id, None)
_target_failure_last_error.pop(target_id, None)
def record_target_failure(target_id: int, error: str = "") -> int:
"""Increment the failure counter for ``target_id``; return the new count."""
_target_failure_counts[target_id] = _target_failure_counts.get(target_id, 0) + 1
if error:
_target_failure_last_error[target_id] = error
return _target_failure_counts[target_id]
def reset_target_counter(target_id: int) -> None:
"""Clear the failure counter for ``target_id`` without emitting."""
_target_failure_counts.pop(target_id, None)
_target_failure_last_error.pop(target_id, None)
def record_backlog_state(user_id: int, above_threshold: bool) -> bool:
"""Record the new backlog state, returning True iff we just crossed up.
The first ever observation is treated as "below" so a process that
starts with a non-empty backlog still emits one notification.
"""
prior = _backlog_above_threshold.get(user_id, False)
_backlog_above_threshold[user_id] = above_threshold
return above_threshold and not prior
def get_poll_failure_count(tracker_id: int) -> int:
return _poll_failure_counts.get(tracker_id, 0)
def get_target_failure_count(target_id: int) -> int:
return _target_failure_counts.get(target_id, 0)
def get_poll_last_error(tracker_id: int) -> str:
return _poll_failure_last_error.get(tracker_id, "")
def get_target_last_error(target_id: int) -> str:
return _target_failure_last_error.get(target_id, "")
# ---------------------------------------------------------------------------
# User-level helpers
# ---------------------------------------------------------------------------
async def list_user_ids() -> list[int]:
"""Return all real user ids (excluding the __system__ placeholder)."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(select(User.id).where(User.id != 0))
return [int(uid) for uid in result.all() if uid is not None]
async def find_tracker_owner(tracker_id: int) -> int | None:
"""Return the user_id that owns ``tracker_id`` (or None)."""
from ..database.models import NotificationTracker
engine = get_engine()
async with AsyncSession(engine) as session:
tracker = await session.get(NotificationTracker, tracker_id)
if tracker is None:
return None
return int(tracker.user_id)
async def find_target_owner(target_id: int) -> int | None:
"""Return the user_id that owns ``target_id`` (or None)."""
from ..database.models import NotificationTarget
engine = get_engine()
async with AsyncSession(engine) as session:
target = await session.get(NotificationTarget, target_id)
if target is None:
return None
return int(target.user_id)
# ---------------------------------------------------------------------------
# Backlog scan
# ---------------------------------------------------------------------------
async def check_deferred_backlog() -> dict[str, Any]:
"""Scan the deferred_dispatch table and emit a backlog event if needed.
Counts pending rows per user, compares against each user's configured
threshold, and emits ``bridge_self_deferred_backlog`` for users that
just crossed up. Returns a small stats dict for logging.
"""
from sqlalchemy import func
from ..database.models import DeferredDispatch
engine = get_engine()
crossings = 0
async with AsyncSession(engine) as session:
# GROUP BY user_id so we don't have to scan once per user. Skip rows
# whose user_id is NULL — those are legacy / orphaned and have no
# bridge_self provider to alert anyway.
rows = (
await session.exec(
select(
DeferredDispatch.user_id,
func.count(DeferredDispatch.id),
)
.where(DeferredDispatch.status == "pending")
.where(DeferredDispatch.user_id.is_not(None))
.group_by(DeferredDispatch.user_id)
)
).all()
counts_by_user: dict[int, int] = {}
for row in rows:
if isinstance(row, tuple):
uid, count = row
else:
uid, count = row
if uid is None:
continue
counts_by_user[int(uid)] = int(count or 0)
for user_id, count in counts_by_user.items():
thresholds = await get_user_thresholds(user_id)
threshold = thresholds["deferred_backlog_threshold"]
above = count >= threshold
if record_backlog_state(user_id, above):
crossings += 1
await emit_bridge_self_event(
user_id=user_id,
failure_type="deferred_backlog",
subject_id=0,
subject_name="Deferred dispatch queue",
count=count,
threshold=threshold,
details={"pending": count},
)
# Reset latch for users that recovered (count < threshold or zero rows).
# Iterate all known users so a user whose backlog drained to 0 (no row in
# GROUP BY) still flips back to "below".
for user_id in list(_backlog_above_threshold.keys()):
if user_id in counts_by_user:
continue
# No pending rows for this user — clear the latch.
_backlog_above_threshold[user_id] = False
return {"users_scanned": len(counts_by_user), "crossings": crossings}
# ---------------------------------------------------------------------------
# Threshold-aware emission wrappers (used by watcher / dispatcher).
# ---------------------------------------------------------------------------
async def maybe_emit_poll_failure(
*, tracker_id: int, tracker_name: str, error: str = "",
) -> None:
"""Increment poll failure counter; emit + reset if threshold reached."""
count = record_poll_failure(tracker_id, error)
user_id = await find_tracker_owner(tracker_id)
if user_id is None:
return
thresholds = await get_user_thresholds(user_id)
threshold = thresholds["poll_failure_threshold"]
if count < threshold:
return
last_err = get_poll_last_error(tracker_id) or error
await emit_bridge_self_event(
user_id=user_id,
failure_type="poll_failures",
subject_id=tracker_id,
subject_name=tracker_name or f"tracker {tracker_id}",
count=count,
threshold=threshold,
last_error=last_err,
details={"tracker_id": tracker_id},
)
# Reset so the next emission requires another full streak. Without this
# the same tracker would fire on EVERY tick once it crosses the
# threshold, drowning the operator.
reset_poll_counter(tracker_id)
async def maybe_emit_target_failure(
*, target_id: int, target_name: str, target_type: str, error: str = "",
) -> None:
"""Increment target failure counter; emit + reset if threshold reached."""
count = record_target_failure(target_id, error)
user_id = await find_target_owner(target_id)
if user_id is None:
return
thresholds = await get_user_thresholds(user_id)
threshold = thresholds["target_failure_threshold"]
if count < threshold:
return
last_err = get_target_last_error(target_id) or error
await emit_bridge_self_event(
user_id=user_id,
failure_type="target_failures",
subject_id=target_id,
subject_name=target_name or f"target {target_id}",
count=count,
threshold=threshold,
last_error=last_err,
details={"target_id": target_id, "target_type": target_type},
)
reset_target_counter(target_id)
@@ -98,9 +98,16 @@ async def _flush_dirty_bots() -> None:
bot = await session.get(TelegramBot, bot_id)
if not bot:
continue
# Snapshot every attribute we touch after the session
# exits — once detached, lazy attribute access raises
# MissingGreenlet under SQLAlchemy async.
bot_username = bot.bot_username
# Expunge so the detached instance can still read snapshotted
# attrs but won't trigger a refresh / re-query downstream.
session.expunge(bot)
success = await register_commands_with_telegram(bot)
if success:
_LOGGER.info("Auto-synced commands for bot %d (@%s)", bot_id, bot.bot_username)
_LOGGER.info("Auto-synced commands for bot %d (@%s)", bot_id, bot_username)
else:
_LOGGER.warning("Auto-sync failed for bot %d", bot_id)
except Exception:
@@ -29,6 +29,7 @@ import logging
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.orm.attributes import flag_modified
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -53,6 +54,7 @@ from .dispatch_helpers import (
evaluate_event_gate,
get_app_timezone,
load_link_data,
resolve_provider_credential,
)
_LOGGER = logging.getLogger(__name__)
@@ -88,6 +90,10 @@ _DEFERRABLE_EVENT_TYPES: frozenset[str] = frozenset({
"ups_online", "ups_on_battery", "ups_low_battery",
"ups_battery_restored", "ups_comms_lost", "ups_comms_restored",
"ups_replace_battery", "ups_overload",
# Home Assistant — state changes & automations are change-driven; the
# underlying state remains relevant after the quiet window.
"ha_state_changed", "ha_automation_triggered",
"ha_service_called", "ha_event_fired",
})
# Per-tracker cap on the pending queue. A misconfigured short quiet window
@@ -206,6 +212,11 @@ def _coalesce_assets_added(
payload["removed_asset_ids"] = kept
payload["removed_count"] = len(kept)
existing_removed_row.event_payload = payload
# Belt-and-braces: SQLAlchemy's mutation tracker sometimes
# misses JSON-typed reassignments depending on dialect / column
# config. Explicit flag_modified guarantees the dirty bit is
# set for the upcoming flush.
flag_modified(existing_removed_row, "event_payload")
if not kept:
# All previously-removed IDs are being re-added → entire
# removal is cancelled. Mark for caller to delete.
@@ -235,6 +246,7 @@ def _coalesce_assets_added(
payload["added_assets"] = existing_assets
payload["added_count"] = len(existing_assets)
existing_added_row.event_payload = payload
flag_modified(existing_added_row, "event_payload")
return ("merge", existing_added_row, existing_removed_row)
@@ -257,6 +269,7 @@ def _coalesce_assets_removed(
payload["added_assets"] = kept_assets
payload["added_count"] = len(kept_assets)
existing_added_row.event_payload = payload
flag_modified(existing_added_row, "event_payload")
if not kept_assets:
existing_added_row.status = "cancelled"
# IDs that were just added during the window don't need to flow
@@ -282,6 +295,7 @@ def _coalesce_assets_removed(
payload["removed_asset_ids"] = existing_ids
payload["removed_count"] = len(existing_ids)
existing_removed_row.event_payload = payload
flag_modified(existing_removed_row, "event_payload")
return ("merge", existing_added_row, existing_removed_row)
@@ -695,7 +709,7 @@ async def _process_row(
template_slots=ld.get("template_slots"),
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=(tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y"),
provider_api_key=provider_config.get("api_key") or provider_config.get("api_token"),
provider_api_key=resolve_provider_credential(provider_config),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("external_domain", "") or provider_config.get("url", ""),
receivers=ld["receivers"],
@@ -210,6 +210,13 @@ def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
"ha_automation_triggered": getattr(tc, "track_ha_automation_triggered", False),
"ha_service_called": getattr(tc, "track_ha_service_called", False),
"ha_event_fired": getattr(tc, "track_ha_event_fired", False),
# Bridge self-monitoring — defaults True so a tracker created before the
# columns existed still surfaces the alerts. Legacy rows are extremely
# unlikely here (the columns ship in the same release as the provider),
# but the safer default matches the rest of this map.
"bridge_self_poll_failures": getattr(tc, "track_bridge_self_poll_failures", True),
"bridge_self_deferred_backlog": getattr(tc, "track_bridge_self_deferred_backlog", True),
"bridge_self_target_failures": getattr(tc, "track_bridge_self_target_failures", True),
}
return flag_map.get(event_type, True)
@@ -225,11 +232,15 @@ def evaluate_event_gate(
by quiet hours the UTC datetime at which the window ends so the caller
can schedule a deferred dispatch.
Order of checks: quiet hours first, then per-event-type flag. Quiet hours
is the "louder" gate (it applies to every type), so reporting it first
avoids the surprising case of "you disabled this event type" showing up
when the user really just opened the quiet window.
Order of checks: per-event-type flag FIRST, then quiet hours. Otherwise a
disabled event type would get deferred during quiet hours and then
silently dropped at drain time wasted work and a confusing "deferred
then dropped" trail in the dashboard. The user already said "don't tell
me about this kind of event"; honour that immediately.
"""
if not _event_type_enabled(event, tc):
return GateOutcome(reason=GateReason.EVENT_TYPE_DISABLED)
if tc.quiet_hours_enabled:
end_at = quiet_hours_status(
tc.quiet_hours_start, tc.quiet_hours_end, tz_name,
@@ -240,9 +251,6 @@ def evaluate_event_gate(
quiet_hours_end_at=end_at,
)
if not _event_type_enabled(event, tc):
return GateOutcome(reason=GateReason.EVENT_TYPE_DISABLED)
return GateOutcome(reason=GateReason.ALLOWED)
@@ -397,23 +405,49 @@ def apply_tracking_display_filters(
)
def resolve_provider_credential(cfg: dict[str, Any] | None) -> str | None:
"""Pick the first non-empty provider credential field.
Provider configs use different field names (Immich ``api_key``,
Gitea ``api_token``, HA ``access_token``). All four dispatch
sites used to pick one field by hand; centralising here keeps the
fallback order consistent so a config edit on one provider type can't
silently break dispatch for another.
"""
if not cfg:
return None
return cfg.get("api_key") or cfg.get("api_token") or cfg.get("access_token")
async def _resolve_target(
session: AsyncSession,
target: NotificationTarget,
*,
receivers_by_target: dict[int, list[TargetReceiver]] | None = None,
telegram_chats_by_bot: dict[int, dict[str, TelegramChat]] | None = None,
email_bots_by_id: dict[int, EmailBot] | None = None,
matrix_bots_by_id: dict[int, MatrixBot] | None = None,
) -> dict[str, Any]:
"""Resolve a single target into dispatch-ready data (config + receivers + credentials).
Returns a dict with target_type, target_config, and receivers.
Does NOT include tracking_config or template_slots those come from the tracker link.
Optional ``*_by_*`` maps short-circuit per-target DB queries when the
caller has batch-prefetched the data. When omitted, we fall back to the
original single-query path so direct callers (manual_dispatch) still work.
"""
# Load receivers as typed Receiver objects
recv_result = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.enabled == True,
# Receivers — prefer pre-fetched map.
if receivers_by_target is not None:
recv_rows = [r for r in receivers_by_target.get(target.id, []) if r.enabled]
else:
recv_result = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.enabled == True,
)
)
)
recv_rows = recv_result.all()
recv_rows = recv_result.all()
# For Telegram targets, resolve locale from TelegramChat
chat_locale_map: dict[str, str] = {}
@@ -422,13 +456,24 @@ async def _resolve_target(
if bot_id:
chat_ids = [str(r.config.get("chat_id", "")) for r in recv_rows if r.config.get("chat_id")]
if chat_ids:
chat_result = await session.exec(
select(TelegramChat).where(
TelegramChat.bot_id == bot_id,
TelegramChat.chat_id.in_(chat_ids),
)
chats_for_bot = (
telegram_chats_by_bot.get(bot_id, {})
if telegram_chats_by_bot is not None else None
)
for chat in chat_result.all():
if chats_for_bot is not None:
rows = [
chats_for_bot[cid] for cid in chat_ids
if cid in chats_for_bot
]
else:
chat_result = await session.exec(
select(TelegramChat).where(
TelegramChat.bot_id == bot_id,
TelegramChat.chat_id.in_(chat_ids),
)
)
rows = chat_result.all()
for chat in rows:
resolved = (
getattr(chat, 'language_override', '') or
getattr(chat, 'language_code', '') or ''
@@ -457,7 +502,10 @@ async def _resolve_target(
if target.type == "email":
email_bot_id = target.config.get("email_bot_id")
if email_bot_id:
email_bot = await session.get(EmailBot, email_bot_id)
if email_bots_by_id is not None:
email_bot = email_bots_by_id.get(email_bot_id)
else:
email_bot = await session.get(EmailBot, email_bot_id)
if email_bot:
target_config["smtp"] = {
"host": email_bot.smtp_host,
@@ -471,12 +519,17 @@ async def _resolve_target(
elif target.type == "matrix":
matrix_bot_id = target.config.get("matrix_bot_id")
if matrix_bot_id:
matrix_bot = await session.get(MatrixBot, matrix_bot_id)
if matrix_bots_by_id is not None:
matrix_bot = matrix_bots_by_id.get(matrix_bot_id)
else:
matrix_bot = await session.get(MatrixBot, matrix_bot_id)
if matrix_bot:
target_config["homeserver_url"] = matrix_bot.homeserver_url
target_config["access_token"] = matrix_bot.access_token
return {
"target_id": target.id,
"target_name": target.name,
"target_type": target.type,
"target_config": target_config,
"receivers": receivers,
@@ -567,6 +620,76 @@ async def load_link_data(
)
child_target_map = {t.id: t for t in child_rows.all()}
# ---- Batch pre-fetch for _resolve_target ----
# Build the universe of target IDs (regular + expanded broadcast children)
# so a single query per relation type covers every call to _resolve_target
# below — no per-target follow-up SELECTs.
all_target_ids: set[int] = set(target_map.keys()) | set(child_target_map.keys())
receivers_by_target: dict[int, list[TargetReceiver]] = {}
if all_target_ids:
recv_result = await session.exec(
select(TargetReceiver).where(TargetReceiver.target_id.in_(all_target_ids))
)
for r in recv_result.all():
receivers_by_target.setdefault(r.target_id, []).append(r)
# Telegram chats keyed by (bot_id, chat_id) — collect all (bot_id, chat_id)
# pairs referenced by enabled telegram-target receivers, then one query.
tg_pairs: dict[int, set[str]] = {} # bot_id -> {chat_id}
for tid in all_target_ids:
tgt = target_map.get(tid) or child_target_map.get(tid)
if not tgt or tgt.type != "telegram":
continue
bot_id = tgt.config.get("bot_id")
if not bot_id:
continue
for r in receivers_by_target.get(tid, []):
cid = str(r.config.get("chat_id", ""))
if cid:
tg_pairs.setdefault(bot_id, set()).add(cid)
telegram_chats_by_bot: dict[int, dict[str, TelegramChat]] = {}
for bot_id, chat_ids in tg_pairs.items():
if not chat_ids:
continue
chat_rows = await session.exec(
select(TelegramChat).where(
TelegramChat.bot_id == bot_id,
TelegramChat.chat_id.in_(chat_ids),
)
)
telegram_chats_by_bot[bot_id] = {c.chat_id: c for c in chat_rows.all()}
# Email + Matrix bots
email_bot_ids: set[int] = set()
matrix_bot_ids: set[int] = set()
for tid in all_target_ids:
tgt = target_map.get(tid) or child_target_map.get(tid)
if not tgt:
continue
if tgt.type == "email":
bid = tgt.config.get("email_bot_id")
if bid:
email_bot_ids.add(bid)
elif tgt.type == "matrix":
bid = tgt.config.get("matrix_bot_id")
if bid:
matrix_bot_ids.add(bid)
email_bots_by_id: dict[int, EmailBot] = {}
if email_bot_ids:
rows = await session.exec(
select(EmailBot).where(EmailBot.id.in_(email_bot_ids))
)
email_bots_by_id = {b.id: b for b in rows.all()}
matrix_bots_by_id: dict[int, MatrixBot] = {}
if matrix_bot_ids:
rows = await session.exec(
select(MatrixBot).where(MatrixBot.id.in_(matrix_bot_ids))
)
matrix_bots_by_id = {b.id: b for b in rows.all()}
link_data: list[dict[str, Any]] = []
for tt in active_links:
target = target_map.get(tt.target_id)
@@ -589,7 +712,13 @@ async def load_link_data(
child_target = child_target_map.get(child_id)
if not child_target or child_target.type == "broadcast":
continue
resolved = await _resolve_target(session, child_target)
resolved = await _resolve_target(
session, child_target,
receivers_by_target=receivers_by_target,
telegram_chats_by_bot=telegram_chats_by_bot,
email_bots_by_id=email_bots_by_id,
matrix_bots_by_id=matrix_bots_by_id,
)
link_data.append({
**resolved,
"link_id": tt.id,
@@ -600,7 +729,13 @@ async def load_link_data(
continue
# Regular target
resolved = await _resolve_target(session, target)
resolved = await _resolve_target(
session, target,
receivers_by_target=receivers_by_target,
telegram_chats_by_bot=telegram_chats_by_bot,
email_bots_by_id=email_bots_by_id,
matrix_bots_by_id=matrix_bots_by_id,
)
link_data.append({
**resolved,
"link_id": tt.id,
@@ -14,6 +14,7 @@ services -> api cycle).
from __future__ import annotations
import logging
import time
from typing import Any, Awaitable, Callable
from sqlmodel import select
@@ -33,6 +34,7 @@ from .dispatch_helpers import (
evaluate_event_gate,
get_app_timezone,
load_link_data,
resolve_provider_credential,
)
_LOGGER = logging.getLogger(__name__)
@@ -44,6 +46,67 @@ _LOGGER = logging.getLogger(__name__)
FilterFn = Callable[[ServiceEvent, dict[str, Any]], bool]
# ---------------------------------------------------------------------------
# Tracker cache (per-provider, TTL-bounded)
# ---------------------------------------------------------------------------
#
# HA's chat-bus emits dozens of events per second; the per-event SELECT for
# enabled trackers becomes the bottleneck on busy installs. This 5-second
# TTL cache short-circuits the lookup when the same provider is hot. Cache
# entries are invalidated explicitly by tracker CRUD endpoints; the TTL is
# the safety net for missed invalidations.
_TRACKER_CACHE_TTL_SECONDS = 5.0
_trackers_cache: dict[int, tuple[float, list[NotificationTracker]]] = {}
def invalidate_tracker_cache(provider_id: int | None = None) -> None:
"""Drop cached trackers for one provider (or all if ``provider_id`` is None).
Call from tracker create / update / delete endpoints so the next
inbound event sees the change without waiting out the TTL.
"""
if provider_id is None:
_trackers_cache.clear()
else:
_trackers_cache.pop(provider_id, None)
async def _load_trackers_cached(
session: AsyncSession, provider_id: int,
) -> list[NotificationTracker]:
"""Return enabled trackers for ``provider_id``, with a short TTL cache.
Caches the ``NotificationTracker`` rows themselves NOT the per-tracker
``NotificationTrackerTarget`` link rows. ``load_link_data`` always re-reads
links from the DB on every dispatch, so adding/removing/toggling a link
does NOT require invalidating this cache. Only call ``invalidate_tracker_cache``
when a tracker row is created/updated/deleted.
"""
now = time.monotonic()
cached = _trackers_cache.get(provider_id)
if cached is not None:
ts, trackers = cached
if now - ts < _TRACKER_CACHE_TTL_SECONDS:
return trackers
result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = list(result.all())
# Detach cached instances so consumers don't accidentally use a stale
# session — re-fetch by id when mutating.
for t in trackers:
try:
session.expunge(t)
except Exception: # noqa: BLE001
pass
_trackers_cache[provider_id] = (now, trackers)
return trackers
async def dispatch_provider_event(
engine: Any,
provider_id: int,
@@ -82,18 +145,25 @@ async def dispatch_provider_event(
# Drain-scheduling is best-effort: a scheduling failure must not roll
# back the persisted defer rows (startup catch-up re-establishes them).
defers_to_schedule: set[Any] = set()
# Build the dispatcher once per inbound event — its only state is the
# shared aiohttp session and Telegram caches, both of which are reused
# across all trackers. Re-creating per-tracker meant a fresh dispatcher
# for every notification, paying the construction cost on every HA
# state_changed (ha provider can fire dozens per second).
from .http_session import get_http_session
from .watcher import _get_telegram_caches
url_cache, asset_cache = await _get_telegram_caches()
dispatcher = NotificationDispatcher(
url_cache=url_cache,
asset_cache=asset_cache,
session=await get_http_session(),
)
async with AsyncSession(engine) as session:
# App timezone is identical across trackers in one inbound event;
# pull it once.
app_tz = await get_app_timezone(session)
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = tracker_result.all()
trackers = await _load_trackers_cached(session, provider_id)
for tracker in trackers:
filters = tracker.filters or {}
@@ -147,6 +217,17 @@ async def dispatch_provider_event(
prior = defers_for_event.get(link_id)
if prior is None or outcome.quiet_hours_end_at < prior:
defers_for_event[link_id] = outcome.quiet_hours_end_at
else:
# Non-deferrable event hit quiet hours — stamp
# the event_log so the dashboard surfaces *why*
# the notification never went out.
details = dict(event_log_row.details or {})
if not details.get("dispatch_status"):
details["dispatch_status"] = (
"dropped_quiet_hours_nondeferrable"
)
event_log_row.details = details
session.add(event_log_row)
continue
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
continue
@@ -162,7 +243,7 @@ async def dispatch_provider_event(
if tmpl and tmpl.date_only_format
else "%d.%m.%Y"
),
provider_api_key=provider_config.get("api_token"),
provider_api_key=resolve_provider_credential(provider_config),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("url", ""),
receivers=ld["receivers"],
@@ -170,7 +251,10 @@ async def dispatch_provider_event(
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Thread per-target metadata alongside the TargetConfig so the
# bridge_self failure counters can attribute results to a
# specific target_id after dispatch.
groups[key][1].append((target_cfg, ld.get("target_id"), ld.get("target_name", "")))
# Persist defers + stamp event_log dispatch_status in the same
# session that holds the EventLog row, so the "deferred" badge
@@ -198,14 +282,25 @@ async def dispatch_provider_event(
# Dispatch to targets. Isolate dispatcher exceptions per group so
# a failed remote call doesn't bubble out, abort the surrounding
# transaction, and roll back the just-written defers / event_log.
from .http_session import get_http_session
dispatcher = NotificationDispatcher(session=await get_http_session())
for tc, target_configs in groups.values():
if not target_configs:
from .bridge_self import (
maybe_emit_target_failure,
record_target_success,
)
# Skip target-failure tracking when we're already dispatching a
# bridge_self event — otherwise a failing alert target would
# endlessly re-emit alerts about itself.
track_target_failures = (
event.provider_type.value != "bridge_self"
)
for tc, target_entries in groups.values():
if not target_entries:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
continue
target_configs = [entry[0] for entry in target_entries]
try:
results = await dispatcher.dispatch(shaped_event, target_configs)
except Exception as err: # noqa: BLE001
@@ -213,14 +308,29 @@ async def dispatch_provider_event(
"Dispatcher raised for tracker %d: %s", tracker.id, err,
)
continue
for r in results:
for entry, r in zip(target_entries, results):
_, target_id, target_name = entry
if r.get("success"):
dispatched += 1
if track_target_failures and target_id is not None:
record_target_success(int(target_id))
else:
_LOGGER.error(
"Notification failed for tracker %d: %s",
tracker.id, r.get("error", "unknown"),
)
if track_target_failures and target_id is not None:
try:
await maybe_emit_target_failure(
target_id=int(target_id),
target_name=target_name or "",
target_type=entry[0].type,
error=str(r.get("error") or ""),
)
except Exception: # noqa: BLE001
_LOGGER.exception(
"bridge_self target-failure emission failed",
)
await session.commit()
@@ -35,7 +35,7 @@ from notify_bridge_core.providers.home_assistant import (
)
from ..database.engine import get_engine
from ..database.models import ServiceProvider
from ..database.models import EventLog, ServiceProvider
from .event_dispatch import dispatch_provider_event
from .http_session import get_http_session
@@ -104,6 +104,46 @@ def _ha_passes_filters(event: ServiceEvent, filters: dict[str, Any]) -> bool:
return False
async def _record_ha_status(
*,
provider_id: int,
provider_name: str,
state: str,
detail: str | None,
) -> None:
"""Persist an HA connection-status transition as an EventLog row.
Used by the supervisor's ``on_status_change`` callback so the
dashboard surfaces "HA disconnected" / "HA reconnected" events
alongside normal HA state changes. Best-effort: any DB failure is
logged and swallowed so the WS reader path remains untouched.
"""
engine = get_engine()
try:
async with AsyncSession(engine) as session:
session.add(EventLog(
user_id=None, # provider-level event, no per-tracker owner
tracker_id=None,
tracker_name="",
provider_id=provider_id,
provider_name=provider_name,
event_type=f"ha_status_{state}",
collection_id="",
collection_name="",
assets_count=0,
details={
"provider_type": "home_assistant",
"ha_status": state,
"ha_status_detail": detail or "",
},
))
await session.commit()
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to persist HA status row for provider %s", provider_id,
)
async def _run_provider(provider_id: int) -> None:
"""One per-provider supervisor loop.
@@ -158,29 +198,38 @@ async def _run_provider(provider_id: int) -> None:
async def _emit(event: ServiceEvent) -> None:
# Shield the DB-writing dispatch from external cancellation
# (shutdown, supervisor restart). The shield ensures that
# once a transaction is mid-flight, it commits or rolls back
# cleanly instead of being torn down with the asyncio task
# at a write boundary. Worst case: shutdown waits up to one
# dispatch latency longer.
# (shutdown, supervisor restart). Without an inner Task,
# ``asyncio.shield(coro)`` cancels the underlying coroutine
# when the outer awaiter is cancelled — defeating the point
# of the shield. We wrap explicitly and *drain* the inner
# task on cancellation so the transaction completes.
#
# Perf note (Phase 2 follow-up): dispatch_provider_event
# opens a fresh AsyncSession per call. For HA's chatty
# state_changed bus this hammers the pool; batch in a
# follow-up.
inner = asyncio.create_task(dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=config,
event=event,
detail_keys=_HA_DETAIL_KEYS,
filter_fn=_ha_passes_filters,
))
try:
await asyncio.shield(dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=config,
event=event,
detail_keys=_HA_DETAIL_KEYS,
filter_fn=_ha_passes_filters,
))
await asyncio.shield(inner)
except asyncio.CancelledError:
# Shield re-raises CancelledError to the caller; let it
# propagate so the drain task exits cleanly.
# Drain the in-flight write before re-raising so the DB
# row commits cleanly (or fails cleanly) instead of
# being torn down at an arbitrary await point.
try:
await inner
except Exception: # noqa: BLE001
_LOGGER.exception(
"HA dispatch raised while draining shielded "
"task for provider %s", provider_id,
)
raise
except Exception: # noqa: BLE001
_LOGGER.exception(
@@ -188,11 +237,33 @@ async def _run_provider(provider_id: int) -> None:
provider_id,
)
def _on_status_change(state: str, detail: str | None) -> None:
"""Persist HA WS connect/disconnect transitions as event_log rows.
The client invokes this synchronously from inside the WS
run loop, so we can't ``await`` here. Schedule a fire-and-
forget task on the same loop instead log failures, never
propagate them back into the WS reader.
"""
try:
asyncio.create_task(_record_ha_status(
provider_id=provider_id,
provider_name=provider_name,
state=state,
detail=detail,
))
except RuntimeError:
# No running loop (shouldn't happen in normal operation).
_LOGGER.debug(
"Skipped HA status row for provider %s: no event loop",
provider_id,
)
_LOGGER.info(
"Starting HA subscription for provider %s (%s)",
provider_id, provider_name,
)
await ha_provider.subscribe(_emit)
await ha_provider.subscribe(_emit, on_status_change=_on_status_change)
except asyncio.CancelledError:
raise
except HomeAssistantAuthError as err:
@@ -6,6 +6,21 @@ per-request ``aiohttp.ClientSession`` instances. This keeps a single
TCP connection pool alive for the lifetime of the process, avoiding
the overhead of pool creation/teardown on every request.
DNS-rebinding mitigation
~~~~~~~~~~~~~~~~~~~~~~~~
The session is wired with a :class:`PinnedResolver` from
``notify_bridge_core.notifications.ssrf`` so the IP that passed the
SSRF block-range check during URL validation is the one aiohttp
actually connects to. Without this pinning a malicious DNS server
could swap a public IP for ``127.0.0.1`` between validation and
connect, defeating the SSRF guard.
Callers that do their own ``avalidate_outbound_url_full`` should also
call :func:`pin_validated` to register the resolved host->IP mapping
on the shared resolver before issuing the request. Callers that just
use the session opportunistically still benefit from scheme + range
checks at validation sites, plus the fallback resolver here.
Call ``close_http_session()`` once during application shutdown.
"""
@@ -15,32 +30,71 @@ import asyncio
import aiohttp
from notify_bridge_core.notifications.ssrf import PinnedResolver, ValidatedURL
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
_session: aiohttp.ClientSession | None = None
_lock = asyncio.Lock()
_resolver: PinnedResolver | None = None
# Lazy init: ``asyncio.Lock()`` at module import time binds to whichever
# event loop happens to be running (or none). Tests that spin up multiple
# loops (or subprocesses with their own loop) would otherwise hit
# "RuntimeError: ... attached to a different loop". Defer creation to
# first use so the lock binds to the loop that actually calls us.
_lock: asyncio.Lock | None = None
def _get_lock() -> asyncio.Lock:
"""Return the module lock, creating it on first call from this loop."""
global _lock
if _lock is None:
_lock = asyncio.Lock()
return _lock
async def get_http_session() -> aiohttp.ClientSession:
"""Get or create the shared HTTP session.
Concurrent first-callers are serialized through ``_lock`` so we never
leak a second ClientSession / connector pair. Once established, hot
callers skip the lock via the fast-path check.
Concurrent first-callers are serialized through the lazy lock so we
never leak a second ClientSession / connector pair. Once established,
hot callers skip the lock via the fast-path check.
The session uses a :class:`PinnedResolver` connector so callers that
register validated host->IP mappings via :func:`pin_validated` defeat
DNS rebinding between validation and connect.
"""
global _session
global _session, _resolver
if _session is not None and not _session.closed:
return _session
async with _lock:
async with _get_lock():
if _session is None or _session.closed:
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
_resolver = PinnedResolver()
connector = aiohttp.TCPConnector(resolver=_resolver)
_session = aiohttp.ClientSession(
timeout=_DEFAULT_TIMEOUT, connector=connector,
)
return _session
def pin_validated(validated: ValidatedURL) -> None:
"""Register a validated (host, ip) mapping on the shared resolver.
Best-effort: if the resolver has not been created yet (no session
initialised), the call is a no-op. Once the session exists, every
aiohttp connect for ``validated.host`` will use ``validated.ip``.
"""
if _resolver is None:
return
_resolver.pin(validated.host, validated.ip)
async def close_http_session() -> None:
"""Close the shared HTTP session (call on app shutdown)."""
global _session
async with _lock:
global _session, _resolver
async with _get_lock():
if _session is not None and not _session.closed:
await _session.close()
if _resolver is not None:
await _resolver.close()
_session = None
_resolver = None
@@ -24,7 +24,7 @@ from ..database.models import (
TemplateSlot,
TrackingConfig,
)
from .dispatch_helpers import _resolve_target
from .dispatch_helpers import _resolve_target, resolve_provider_credential
from .watcher import _get_telegram_caches
_LOGGER = logging.getLogger(__name__)
@@ -94,7 +94,7 @@ async def dispatch_test_notification(
locale=locale,
date_format=template_config.date_format if template_config else "%d.%m.%Y, %H:%M UTC",
date_only_format=template_config.date_only_format if template_config and template_config.date_only_format else "%d.%m.%Y",
provider_api_key=provider_config.get("api_key"),
provider_api_key=resolve_provider_credential(provider_config),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("external_domain", ""),
receivers=resolved["receivers"],
@@ -442,8 +442,20 @@ async def _send_telegram_test_per_receiver(
disable_web_page_preview=bool(disable_preview),
)
raw = await asyncio.gather(*(_send_one(r) for r in recv_rows))
results = [r for r in raw if r is not None]
# ``return_exceptions=True`` so a single send raising (e.g. transient
# network error to one chat) doesn't abort the entire fan-out and lose
# the successful sibling sends from the aggregate count.
raw = await asyncio.gather(
*(_send_one(r) for r in recv_rows), return_exceptions=True,
)
results: list[dict] = []
for r in raw:
if isinstance(r, BaseException):
_LOGGER.warning("Test send to receiver raised: %s", r)
continue
if r is None:
continue
results.append(r)
return _aggregate(results)
@@ -234,4 +234,12 @@ _SAMPLE_CONTEXT = {
"target_entity": "light.kitchen",
"ha_event_type": "state_changed",
"event_data": {"foo": "bar"},
# Bridge self-monitoring variables (for bridge_self provider templates)
"failure_type": "poll_failures",
"subject_id": 42,
"subject_name": "My Immich Tracker",
"count": 3,
"threshold": 3,
"last_error": "Connection refused",
"details": {"provider_id": 7, "provider_type": "immich"},
}
@@ -49,6 +49,7 @@ from .dispatch_helpers import (
evaluate_event_gate,
get_app_timezone,
load_link_data,
resolve_provider_credential,
)
from .manual_dispatch import build_immich_dispatch_events
@@ -352,7 +353,7 @@ async def dispatch_scheduled_for_tracker(
date_only_format=(
tmpl.date_only_format or "%d.%m.%Y"
),
provider_api_key=provider_config.get("api_key"),
provider_api_key=resolve_provider_credential(provider_config),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("external_domain", ""),
receivers=ld["receivers"],
@@ -163,6 +163,9 @@ async def start_scheduler() -> None:
# Schedule the upstream release-check probe.
await _schedule_release_check()
# Schedule the bridge_self deferred-backlog scan (every 5 min).
_schedule_bridge_self_backlog_scan()
def _schedule_event_cleanup() -> None:
"""Schedule a daily job to delete EventLog entries older than 90 days."""
@@ -1122,7 +1125,11 @@ _DRAIN_CATCHUP_INTERVAL_SECONDS = 300
def _drain_job_id_for(fire_at_utc: datetime) -> str:
return f"{_DEFERRED_DRAIN_PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M')}"
# Include seconds — two trackers with quiet windows that end at the same
# minute but different seconds (e.g. user-set 06:00:00 vs 06:00:30) would
# otherwise collide on a single APScheduler job id, and ``replace_existing``
# would silently drop the second one.
return f"{_DEFERRED_DRAIN_PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M%S')}"
def schedule_deferred_drain(fire_at_utc: datetime) -> None:
@@ -1298,6 +1305,50 @@ async def _schedule_release_check() -> None:
interval_hours, _RELEASE_CHECK_ONESHOT_DELAY_SECONDS)
# ---------------------------------------------------------------------------
# Bridge self-monitoring — deferred-backlog scan
# ---------------------------------------------------------------------------
_BRIDGE_SELF_BACKLOG_JOB_ID = "bridge_self_deferred_backlog_scan"
# 5 min trade-off between "operator notices the backlog quickly" and "extra
# DB churn on a quiet system". The scan is one indexed GROUP BY query.
_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS = 300
def _schedule_bridge_self_backlog_scan() -> None:
"""Install the periodic deferred-backlog scan for bridge_self."""
from apscheduler.triggers.interval import IntervalTrigger
scheduler = get_scheduler()
if scheduler.get_job(_BRIDGE_SELF_BACKLOG_JOB_ID):
return
scheduler.add_job(
_run_bridge_self_backlog_scan,
IntervalTrigger(seconds=_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS),
id=_BRIDGE_SELF_BACKLOG_JOB_ID,
replace_existing=True,
max_instances=1,
coalesce=True,
)
_LOGGER.info(
"Scheduled bridge_self deferred-backlog scan every %ds",
_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS,
)
async def _run_bridge_self_backlog_scan() -> None:
"""APScheduler entry point — scan deferred backlog and emit if needed."""
from .bridge_self import check_deferred_backlog
try:
stats = await check_deferred_backlog()
if stats.get("crossings"):
_LOGGER.info("bridge_self backlog scan stats: %s", stats)
else:
_LOGGER.debug("bridge_self backlog scan stats: %s", stats)
except Exception as err: # noqa: BLE001
_LOGGER.exception("bridge_self backlog scan failed: %s", err)
async def reschedule_release_check() -> None:
"""Re-arm the release-check job after settings changed.
@@ -1,6 +1,6 @@
"""Telegram service utilities — chat persistence helpers."""
from sqlmodel import select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.models import TelegramChat
@@ -12,36 +12,48 @@ async def save_chat_from_webhook(
) -> None:
"""Save or update a chat entry from an incoming webhook message.
Called by the webhook handler to auto-persist chats.
Called by the webhook handler to auto-persist chats. Uses a single
``INSERT ... ON CONFLICT DO UPDATE`` keyed on the
``uq_telegram_chat_bot_chat`` unique constraint so two concurrent
webhook deliveries cannot race a check-then-insert and produce
duplicate rows. Only mutable display/identity fields are updated on
conflict ``commands_enabled``, ``language_override``, and
``discovered_at`` belong to the user / first discovery and stay sticky.
"""
chat_id = str(chat_data.get("id", ""))
if not chat_id:
return
result = await session.exec(
select(TelegramChat).where(
TelegramChat.bot_id == bot_id,
TelegramChat.chat_id == chat_id,
)
)
existing = result.first()
title = chat_data.get("title") or (
chat_data.get("first_name", "") + (" " + chat_data.get("last_name", "")).strip()
)
chat_type = chat_data.get("type", "private")
username = chat_data.get("username", "")
if existing:
existing.title = title
existing.username = chat_data.get("username", existing.username)
if language_code:
existing.language_code = language_code
session.add(existing)
else:
session.add(TelegramChat(
bot_id=bot_id,
chat_id=chat_id,
title=title,
chat_type=chat_data.get("type", "private"),
username=chat_data.get("username", ""),
language_code=language_code,
))
# Only the SQLite dialect path is wired up — the deployed default. A
# future Postgres backend would need pg_insert here; the unique
# constraint name is dialect-portable so the same conflict_target works.
stmt = sqlite_insert(TelegramChat).values(
bot_id=bot_id,
chat_id=chat_id,
title=title,
chat_type=chat_type,
username=username,
language_code=language_code,
)
update_cols: dict = {
"title": title,
"chat_type": chat_type,
"username": username,
}
# Only overwrite language_code when the inbound payload carries one,
# otherwise we'd clobber a previously-detected locale with empty.
if language_code:
update_cols["language_code"] = language_code
stmt = stmt.on_conflict_do_update(
index_elements=["bot_id", "chat_id"],
set_=update_cols,
)
# session.execute (not exec) — exec is the SQLModel/Select wrapper that
# rejects raw Core Insert statements.
await session.execute(stmt)
@@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any
from typing import Any, Awaitable, Callable
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -12,6 +12,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher, TargetConfig
from notify_bridge_core.notifications.telegram.cache import TelegramFileCache
from notify_bridge_core.providers.capabilities import get_capabilities
from notify_bridge_core.storage import JsonFileBackend
from ..database.engine import get_engine
@@ -27,6 +28,7 @@ from .dispatch_helpers import (
evaluate_event_gate,
get_app_timezone,
load_link_data,
resolve_provider_credential,
)
_LOGGER = logging.getLogger(__name__)
@@ -34,7 +36,18 @@ _LOGGER = logging.getLogger(__name__)
# Module-level Telegram file caches — shared across dispatches for reuse
_url_cache: TelegramFileCache | None = None
_asset_cache: TelegramFileCache | None = None
_cache_lock = asyncio.Lock()
# Lazy init: creating ``asyncio.Lock()`` at module import time binds the
# lock to whichever event loop is current at import (often none / the wrong
# one when tests fire up dedicated loops). Defer until first use.
_cache_lock: asyncio.Lock | None = None
def _get_cache_lock() -> asyncio.Lock:
"""Return the module cache lock, creating it on first call."""
global _cache_lock
if _cache_lock is None:
_cache_lock = asyncio.Lock()
return _cache_lock
async def _load_cache_settings() -> tuple[int, int]:
@@ -68,7 +81,7 @@ async def _get_telegram_caches() -> tuple[TelegramFileCache | None, TelegramFile
global _url_cache, _asset_cache
if _url_cache is not None:
return _url_cache, _asset_cache
async with _cache_lock:
async with _get_cache_lock():
# Double-check after acquiring lock
if _url_cache is not None:
return _url_cache, _asset_cache
@@ -108,7 +121,7 @@ async def reset_telegram_caches_in_memory() -> None:
deletes cached file_ids.
"""
global _url_cache, _asset_cache
async with _cache_lock:
async with _get_cache_lock():
_url_cache = None
_asset_cache = None
_LOGGER.info("Reset Telegram cache refs in memory (files preserved)")
@@ -135,7 +148,7 @@ async def clear_telegram_caches() -> dict[str, Any]:
Returns a summary with the paths that were removed.
"""
global _url_cache, _asset_cache
async with _cache_lock:
async with _get_cache_lock():
removed: list[str] = []
for cache, label in ((_url_cache, "url"), (_asset_cache, "asset")):
if cache is not None:
@@ -163,6 +176,90 @@ async def clear_telegram_caches() -> dict[str, Any]:
return {"cleared": True, "removed": removed}
# ---------------------------------------------------------------------------
# Provider polling registry
# ---------------------------------------------------------------------------
#
# Each registered factory returns (events, new_state). Replaces the long
# ``if provider_type == ...`` chain in ``check_tracker``. New pollable
# providers register here; webhook-only providers are short-circuited above
# via ``capabilities.webhook_based``.
class _PollerConnectError(Exception):
"""Raised by a poller factory when initial provider connection fails."""
def __init__(self, reason: str) -> None:
super().__init__(reason)
self.reason = reason
PollResult = tuple[list[ServiceEvent], dict[str, Any]]
PollerFactory = Callable[..., Awaitable[PollResult]]
async def _poll_immich(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
from notify_bridge_core.providers.immich import ImmichServiceProvider
from .http_session import get_http_session
http_session = await get_http_session()
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
if not await immich.connect():
raise _PollerConnectError("failed to connect to provider")
return await immich.poll(collection_ids, state_dict)
async def _poll_scheduler(*, provider_name, tracker_name, tracker_filters, collection_ids, state_dict, app_tz, **_kw) -> PollResult:
from notify_bridge_core.providers.scheduler import SchedulerServiceProvider
sched = SchedulerServiceProvider(
name=provider_name,
tracker_name=tracker_name,
custom_variables=tracker_filters.get("custom_variables", {}),
timezone_name=app_tz,
)
return await sched.poll(collection_ids, state_dict)
async def _poll_nut(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
from notify_bridge_core.providers.nut import NutServiceProvider
nut = NutServiceProvider(
host=provider_config.get("host", "localhost"),
port=provider_config.get("port", 3493),
username=provider_config.get("username"),
password=provider_config.get("password"),
name=provider_name,
)
return await nut.poll(collection_ids, state_dict)
async def _poll_google_photos(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
from .http_session import get_http_session
http_session = await get_http_session()
gp = GooglePhotosServiceProvider(
http_session,
provider_config.get("client_id", ""),
provider_config.get("client_secret", ""),
provider_config.get("refresh_token", ""),
provider_name,
)
if not await gp.connect():
raise _PollerConnectError("failed to connect to Google Photos")
return await gp.poll(collection_ids, state_dict)
_POLL_FACTORIES: dict[str, PollerFactory] = {
"immich": _poll_immich,
"scheduler": _poll_scheduler,
"nut": _poll_nut,
"google_photos": _poll_google_photos,
}
async def check_tracker(tracker_id: int) -> dict[str, Any]:
"""Poll a tracker's provider for changes and dispatch notifications."""
engine = get_engine()
@@ -223,70 +320,61 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
events: list[ServiceEvent] = []
new_state: dict[str, Any] = {}
if provider_type == "immich":
from notify_bridge_core.providers.immich import ImmichServiceProvider
from .http_session import get_http_session
http_session = await get_http_session()
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
connected = await immich.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to provider"}
# Webhook-only providers: capabilities.webhook_based short-circuits the
# poll path. Inbound events arrive via the /api/webhooks/* endpoints.
caps = get_capabilities(provider_type)
if caps is not None and caps.webhook_based:
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
events, new_state = await immich.poll(collection_ids, state_dict)
elif provider_type == "gitea":
# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
# The scheduler still calls check_tracker but there's nothing to poll.
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
elif provider_type == "planka":
# Planka is webhook-based — events arrive via /api/webhooks/planka endpoint.
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
elif provider_type == "scheduler":
from notify_bridge_core.providers.scheduler import SchedulerServiceProvider
custom_vars = tracker_filters.get("custom_variables", {})
sched = SchedulerServiceProvider(
name=provider_name,
tracker_name=tracker_name,
custom_variables=custom_vars,
timezone_name=app_tz,
)
events, new_state = await sched.poll(collection_ids, state_dict)
elif provider_type == "nut":
from notify_bridge_core.providers.nut import NutServiceProvider
nut = NutServiceProvider(
host=provider_config.get("host", "localhost"),
port=provider_config.get("port", 3493),
username=provider_config.get("username"),
password=provider_config.get("password"),
name=provider_name,
)
events, new_state = await nut.poll(collection_ids, state_dict)
elif provider_type == "google_photos":
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
from .http_session import get_http_session
http_session = await get_http_session()
gp = GooglePhotosServiceProvider(
http_session,
provider_config.get("client_id", ""),
provider_config.get("client_secret", ""),
provider_config.get("refresh_token", ""),
provider_name,
)
connected = await gp.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to Google Photos"}
events, new_state = await gp.poll(collection_ids, state_dict)
elif provider_type == "webhook":
# Webhook providers receive events via inbound HTTP; no polling needed.
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
else:
poller = _POLL_FACTORIES.get(provider_type)
if poller is None:
return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}
try:
events, new_state = await poller(
provider_config=provider_config,
provider_name=provider_name,
tracker_name=tracker_name,
tracker_filters=tracker_filters,
collection_ids=collection_ids,
state_dict=state_dict,
app_tz=app_tz,
)
except _PollerConnectError as exc:
# Track consecutive poll failures so the bridge_self provider can
# alert when a tracker stops responding. The emission is async
# but cheap; we await it inline so its DB writes happen before
# check_tracker returns to the scheduler.
from .bridge_self import maybe_emit_poll_failure
try:
await maybe_emit_poll_failure(
tracker_id=tracker_id,
tracker_name=tracker_name,
error=exc.reason,
)
except Exception: # noqa: BLE001
_LOGGER.exception("bridge_self poll-failure emission failed")
return {"status": "error", "reason": exc.reason}
except Exception as exc: # noqa: BLE001
# Catch broader poll exceptions (e.g. a provider-side bug, transient
# network error inside the poller after connect) so the same
# streak-tracking logic applies. Re-raised after the bookkeeping so
# the existing error path keeps logging at the caller.
from .bridge_self import maybe_emit_poll_failure
try:
await maybe_emit_poll_failure(
tracker_id=tracker_id,
tracker_name=tracker_name,
error=str(exc),
)
except Exception: # noqa: BLE001
_LOGGER.exception("bridge_self poll-failure emission failed")
raise
# Successful poll — clear the consecutive-failure counter for this tracker.
from .bridge_self import record_poll_success
record_poll_success(tracker_id)
# Save updated state and log events
async with AsyncSession(engine) as session:
for cid, cstate in new_state.items():
@@ -328,6 +416,16 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
# row if quiet hours suppresses it.
event_log_id_by_event: dict[int, int] = {}
for event in events:
# Skip persistence for events the dispatch loop will filter
# anyway (assets_added with 0 added, assets_removed with 0
# removed). Without this we wrote a "noise" row for every
# tracker tick that detected nothing. The dispatch-time filter
# below still runs as a safety net.
etype = event.event_type.value
if etype == "assets_added" and event.added_count == 0:
continue
if etype == "assets_removed" and event.removed_count == 0:
continue
assets_count = event.added_count or event.removed_count or 0
details: dict[str, Any] = {
"added_count": event.added_count,
@@ -445,7 +543,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
template_slots=ld["template_slots"],
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
provider_api_key=provider_config.get("api_key"),
provider_api_key=resolve_provider_credential(provider_config),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("external_domain", ""),
receivers=ld["receivers"],
@@ -453,7 +551,9 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Threaded with target_id/target_name so per-target failure
# counters can attribute the dispatch result correctly.
groups[key][1].append((target_cfg, ld.get("target_id"), ld.get("target_name", "")))
# Persist defers + stamp the event_log row + schedule drains in a
# single transaction. This keeps the "deferred" pill on the
@@ -496,8 +596,17 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
"Failed to schedule deferred drain for %s", fire_at,
)
for tc, target_configs in groups.values():
if not target_configs:
from .bridge_self import (
maybe_emit_target_failure,
record_target_success,
)
track_target_failures = (
event.provider_type.value != "bridge_self"
)
for tc, target_entries in groups.values():
if not target_entries:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
@@ -505,12 +614,28 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
" Event suppressed by display filters (favorites_only)",
)
continue
target_configs = [entry[0] for entry in target_entries]
results = await dispatcher.dispatch(shaped_event, target_configs)
for r in results:
for entry, r in zip(target_entries, results):
_, target_id, target_name = entry
if r.get("success"):
_LOGGER.info(" Notification sent successfully")
if track_target_failures and target_id is not None:
record_target_success(int(target_id))
else:
_LOGGER.error(" Notification failed: %s", r.get("error", "unknown"))
if track_target_failures and target_id is not None:
try:
await maybe_emit_target_failure(
target_id=int(target_id),
target_name=target_name or "",
target_type=entry[0].type,
error=str(r.get("error") or ""),
)
except Exception: # noqa: BLE001
_LOGGER.exception(
"bridge_self target-failure emission failed",
)
return {
"status": "ok",
@@ -0,0 +1,268 @@
"""End-to-end backup roundtrip: seed -> export -> wipe -> import -> verify.
Drives the backup service module directly (no HTTP layer) against a fresh
SQLite DB built in the conftest temp data dir. Verifies entity counts and
key fields survive a full round-trip.
Kept under 5s by avoiding the lifespan startup we build a private engine
in an isolated DB file so we don't share state with other tests in the
session.
"""
from __future__ import annotations
from pathlib import Path
import pytest
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
@pytest.fixture
async def isolated_engine(tmp_path: Path):
"""A throwaway SQLite engine + freshly created schema for one test.
Avoids the global engine in ``database.engine`` tests in the same
session share that singleton, and recreating tables on it would corrupt
parallel tests' state.
"""
# Importing the module registers all SQLModel tables on the metadata.
from notify_bridge_server.database import models # noqa: F401
db_path = tmp_path / "roundtrip.db"
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
yield engine
await engine.dispose()
async def _seed(session: AsyncSession, user_id: int) -> dict[str, int]:
"""Insert enough rows to exercise the major code paths in import/export."""
from notify_bridge_server.database.models import (
EventLog,
NotificationTarget,
NotificationTracker,
ServiceProvider,
TargetReceiver,
TelegramBot,
TrackingConfig,
User,
)
user = User(
id=user_id,
username="roundtrip-user",
hashed_password="hash",
role="user",
)
session.add(user)
await session.flush()
bot = TelegramBot(
user_id=user_id, name="Test bot", token="123456:fake-token-value",
bot_username="testbot", bot_id=1,
)
session.add(bot)
await session.flush()
provider = ServiceProvider(
user_id=user_id, type="immich", name="Immich prod",
config={"base_url": "https://immich.example.com", "api_key": "secret"},
)
session.add(provider)
await session.flush()
target = NotificationTarget(
user_id=user_id, type="telegram", name="My channel",
config={"bot_token_id": bot.id, "disable_url_preview": True},
)
session.add(target)
await session.flush()
receiver = TargetReceiver(
target_id=target.id, name="Channel A",
config={"chat_id": "-100123"}, receiver_key="-100123", locale="en",
)
session.add(receiver)
tc = TrackingConfig(
user_id=user_id, provider_type="immich",
name="Default Immich tracking", track_assets_added=True,
)
session.add(tc)
await session.flush()
tracker = NotificationTracker(
user_id=user_id, provider_id=provider.id,
name="Family album tracker", scan_interval=120,
collection_ids=["album-uuid-1"],
)
session.add(tracker)
await session.flush()
# Capture IDs before commit — accessing attributes after commit
# triggers a refresh that needs an async-IO context the test caller
# may not be inside. Better to snapshot now and use plain ints later.
ids = {
"provider_id": provider.id,
"target_id": target.id,
"bot_id": bot.id,
"tracker_id": tracker.id,
"tracking_config_id": tc.id,
"tracker_name": tracker.name,
"provider_name": provider.name,
}
# EventLog rows are NOT in the backup schema — they're operational data,
# not configuration. Insert a few anyway so we can verify they survive
# the export step (since export only reads, never writes/wipes them).
for i in range(3):
session.add(EventLog(
user_id=user_id, tracker_id=ids["tracker_id"], tracker_name=ids["tracker_name"],
provider_id=ids["provider_id"], provider_name=ids["provider_name"],
event_type="assets_added", collection_id="album-uuid-1",
collection_name="Family", assets_count=i,
))
await session.commit()
return ids
async def _wipe_user_owned_rows(engine, user_id: int) -> None:
"""Delete every backup-able row for the user via raw SQL.
Using ORM-level deletes triggers SQLAlchemy's cascade machinery, which
lazy-loads relationships in a sync context that the async driver cannot
serve (MissingGreenlet). Raw DELETE statements skip cascades and let
SQLite's FKs enforce ordering naturally.
Order matters: child rows first, then parents.
"""
from sqlalchemy import text
statements = (
"DELETE FROM event_log",
"DELETE FROM notification_tracker_target",
"DELETE FROM notification_tracker",
"DELETE FROM target_receiver",
"DELETE FROM notification_target",
"DELETE FROM tracking_config",
"DELETE FROM service_provider",
"DELETE FROM template_slot",
"DELETE FROM template_config",
"DELETE FROM telegram_bot",
"DELETE FROM appsetting",
)
async with engine.begin() as conn:
for stmt in statements:
try:
await conn.execute(text(stmt))
except Exception: # noqa: BLE001 — table may not exist in test schema
pass
@pytest.mark.asyncio
async def test_export_wipe_import_roundtrip(isolated_engine, tmp_data_dir) -> None: # noqa: ARG001
"""A full round-trip preserves entity counts and the key fields the
UI relies on names, configs (with secrets included), provider
references via id_map.
"""
from notify_bridge_server.database.models import (
NotificationTarget, NotificationTracker, ServiceProvider,
TargetReceiver, TelegramBot, TrackingConfig,
)
from notify_bridge_server.services.backup_schema import (
ConflictMode, SecretsMode,
)
from notify_bridge_server.services.backup_service import (
export_backup, import_backup,
)
user_id = 1
# ---- Seed ----
async with AsyncSession(isolated_engine) as session:
ids = await _seed(session, user_id)
# ---- Export with secrets included so import sees real values ----
async with AsyncSession(isolated_engine) as session:
backup = await export_backup(
session, user_id, secrets_mode=SecretsMode.INCLUDE,
)
assert len(backup.data.providers) == 1
assert len(backup.data.telegram_bots) == 1
assert len(backup.data.targets) == 1
assert len(backup.data.targets[0].receivers) == 1
assert len(backup.data.tracking_configs) == 1
assert len(backup.data.notification_trackers) == 1
assert backup.data.providers[0].config["api_key"] == "secret"
# ---- Wipe ----
await _wipe_user_owned_rows(isolated_engine, user_id)
async with AsyncSession(isolated_engine) as session:
result = await session.exec(
select(ServiceProvider).where(ServiceProvider.user_id == user_id)
)
assert result.all() == []
# ---- Import ----
async with AsyncSession(isolated_engine) as session:
result = await import_backup(
session, user_id, backup, conflict_mode=ConflictMode.SKIP,
)
assert result.errors == [], f"Import errors: {result.errors}"
assert result.created > 0
# ---- Verify the entities are back ----
async with AsyncSession(isolated_engine) as session:
providers = (await session.exec(
select(ServiceProvider).where(ServiceProvider.user_id == user_id)
)).all()
assert len(providers) == 1
prov = providers[0]
assert prov.name == "Immich prod"
assert prov.config["base_url"] == "https://immich.example.com"
# Secrets imported intact when SecretsMode.INCLUDE was used at export.
assert prov.config["api_key"] == "secret"
bots = (await session.exec(
select(TelegramBot).where(TelegramBot.user_id == user_id)
)).all()
assert len(bots) == 1
assert bots[0].name == "Test bot"
targets = (await session.exec(
select(NotificationTarget).where(NotificationTarget.user_id == user_id)
)).all()
assert len(targets) == 1
receivers = (await session.exec(
select(TargetReceiver).where(TargetReceiver.target_id == targets[0].id)
)).all()
assert len(receivers) == 1
assert receivers[0].config["chat_id"] == "-100123"
tcs = (await session.exec(
select(TrackingConfig).where(TrackingConfig.user_id == user_id)
)).all()
assert len(tcs) == 1
assert tcs[0].name == "Default Immich tracking"
trackers = (await session.exec(
select(NotificationTracker).where(NotificationTracker.user_id == user_id)
)).all()
assert len(trackers) == 1
# provider_id was remapped via id_map — original provider id may have
# changed across the wipe, so just check it links to a real row.
assert trackers[0].provider_id == prov.id
assert trackers[0].scan_interval == 120
assert trackers[0].collection_ids == ["album-uuid-1"]
+265
View File
@@ -0,0 +1,265 @@
"""Tests for the bridge self-monitoring provider.
Covers:
1. ``build_event`` parses a well-formed payload and rejects malformed ones.
2. The threshold-crossing helpers in ``services.bridge_self`` only emit on
the actual crossing, not on every increment afterwards (anti-spam).
3. ``ensure_bridge_self_provider_for_user`` creates exactly one provider
per user and is idempotent on re-run.
4. The capability registry exposes the new event/slot definitions.
"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
# ---------------------------------------------------------------------------
# Event parser
# ---------------------------------------------------------------------------
def test_build_event_well_formed_payload() -> None:
from notify_bridge_core.providers.bridge_self.event_parser import build_event
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
payload = {
"failure_type": "poll_failures",
"subject_id": 7,
"subject_name": "My Tracker",
"count": 3,
"threshold": 3,
"last_error": "Timeout",
"details": {"tracker_id": 7},
}
when = datetime(2026, 5, 16, 10, 0, tzinfo=timezone.utc)
event = build_event(payload, timestamp=when)
assert event is not None
assert event.event_type == EventType.BRIDGE_SELF_POLL_FAILURES
assert event.provider_type == ServiceProviderType.BRIDGE_SELF
assert event.collection_id == "7"
assert event.collection_name == "My Tracker"
assert event.timestamp == when
assert event.extra["count"] == 3
assert event.extra["threshold"] == 3
assert event.extra["last_error"] == "Timeout"
assert event.extra["failure_type"] == "poll_failures"
assert event.extra["details"] == {"tracker_id": 7}
def test_build_event_unknown_failure_type_returns_none() -> None:
from notify_bridge_core.providers.bridge_self.event_parser import build_event
assert build_event({"failure_type": "rocket_launch"}) is None
def test_build_event_non_dict_payload_returns_none() -> None:
from notify_bridge_core.providers.bridge_self.event_parser import build_event
assert build_event("not a dict") is None # type: ignore[arg-type]
assert build_event(None) is None # type: ignore[arg-type]
def test_build_event_clamps_long_error_messages() -> None:
from notify_bridge_core.providers.bridge_self.event_parser import (
build_event, _MAX_ERROR_LEN,
)
huge = "X" * (_MAX_ERROR_LEN * 5)
event = build_event({
"failure_type": "target_failures",
"subject_id": 1,
"subject_name": "t",
"count": 5,
"threshold": 5,
"last_error": huge,
})
assert event is not None
assert len(event.extra["last_error"]) <= _MAX_ERROR_LEN
# ---------------------------------------------------------------------------
# Threshold-crossing counters
# ---------------------------------------------------------------------------
def test_record_poll_failure_increments_then_success_resets() -> None:
from notify_bridge_server.services import bridge_self as bs
# Use a tracker_id we know is unique to this test to avoid pollution
# across tests sharing the module-level dicts.
tid = 9_001
bs.reset_poll_counter(tid)
assert bs.record_poll_failure(tid, "boom") == 1
assert bs.record_poll_failure(tid, "boom") == 2
assert bs.record_poll_failure(tid, "boom") == 3
assert bs.get_poll_failure_count(tid) == 3
assert bs.get_poll_last_error(tid) == "boom"
bs.record_poll_success(tid)
assert bs.get_poll_failure_count(tid) == 0
assert bs.get_poll_last_error(tid) == ""
def test_record_target_failure_increments_then_success_resets() -> None:
from notify_bridge_server.services import bridge_self as bs
tid = 9_101
bs.reset_target_counter(tid)
assert bs.record_target_failure(tid, "503") == 1
assert bs.record_target_failure(tid, "503") == 2
assert bs.get_target_failure_count(tid) == 2
bs.record_target_success(tid)
assert bs.get_target_failure_count(tid) == 0
def test_backlog_state_only_emits_on_crossing() -> None:
"""Only the False -> True transition should report a crossing.
A sustained backlog must not re-fire on every scan, and a recovered
backlog re-arms the latch so the next crossing is reported again.
"""
from notify_bridge_server.services import bridge_self as bs
user_id = 9_201
# Reset latch by going through a False reading first.
bs._backlog_above_threshold.pop(user_id, None)
# Initial above-threshold reading IS a crossing (None -> True latch).
assert bs.record_backlog_state(user_id, True) is True
# Sustained above — no second alert.
assert bs.record_backlog_state(user_id, True) is False
assert bs.record_backlog_state(user_id, True) is False
# Drop below — no alert (we don't notify on recovery).
assert bs.record_backlog_state(user_id, False) is False
# Cross again — alert.
assert bs.record_backlog_state(user_id, True) is True
# ---------------------------------------------------------------------------
# ensure_bridge_self_provider_for_user — DB roundtrip
# ---------------------------------------------------------------------------
@pytest.fixture
async def session() -> AsyncSession:
"""Fresh in-memory DB with the SQLModel schema applied."""
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
async with AsyncSession(engine) as session:
yield session
await engine.dispose()
@pytest.mark.asyncio
async def test_ensure_bridge_self_provider_creates_once(session: AsyncSession) -> None:
from notify_bridge_server.database.models import ServiceProvider, User
from notify_bridge_server.database.seeds import (
ensure_bridge_self_provider_for_user,
)
# Create a real user.
user = User(username="alice", hashed_password="x", role="user")
session.add(user)
await session.commit()
await session.refresh(user)
user_id = user.id
p1 = await ensure_bridge_self_provider_for_user(session, user_id)
assert p1 is not None
p1_id = p1.id
assert p1.type == "bridge_self"
assert p1.user_id == user_id
assert p1.config["poll_failure_threshold"] == 3
assert p1.config["deferred_backlog_threshold"] == 100
assert p1.config["target_failure_threshold"] == 5
await session.commit()
# Idempotent: second call returns the same row, no duplicates.
p2 = await ensure_bridge_self_provider_for_user(session, user_id)
assert p2 is not None
assert p2.id == p1_id
await session.commit()
rows = (
await session.exec(
select(ServiceProvider).where(
ServiceProvider.user_id == user_id,
ServiceProvider.type == "bridge_self",
)
)
).all()
assert len(rows) == 1
@pytest.mark.asyncio
async def test_ensure_bridge_self_provider_skips_system_user(session: AsyncSession) -> None:
"""user_id <= 0 is the __system__ placeholder — never gets a provider."""
from notify_bridge_server.database.seeds import (
ensure_bridge_self_provider_for_user,
)
result = await ensure_bridge_self_provider_for_user(session, 0)
assert result is None
# ---------------------------------------------------------------------------
# Capability registry
# ---------------------------------------------------------------------------
def test_capability_registry_lists_bridge_self() -> None:
from notify_bridge_core.providers.capabilities import (
get_capabilities, get_all_capabilities,
)
caps = get_capabilities("bridge_self")
assert caps is not None
assert caps.provider_type == "bridge_self"
assert caps.webhook_based is False
event_names = {e["name"] for e in caps.events}
assert event_names == {
"bridge_self_poll_failures",
"bridge_self_deferred_backlog",
"bridge_self_target_failures",
}
slot_names = {s["name"] for s in caps.notification_slots}
assert slot_names == {
"message_bridge_self_poll_failures",
"message_bridge_self_deferred_backlog",
"message_bridge_self_target_failures",
}
# And it shows up in the global registry.
assert "bridge_self" in get_all_capabilities()
def test_default_template_loader_returns_bridge_self_slots() -> None:
"""All three bridge_self slots have shipped Jinja2 default templates."""
from notify_bridge_core.templates.defaults.loader import load_default_templates
en = load_default_templates("en", "bridge_self")
ru = load_default_templates("ru", "bridge_self")
expected = {
"message_bridge_self_poll_failures",
"message_bridge_self_deferred_backlog",
"message_bridge_self_target_failures",
}
assert set(en.keys()) == expected
assert set(ru.keys()) == expected
# Sanity: each template references at least one of the bridge_self vars.
for tpl in list(en.values()) + list(ru.values()):
assert "{{" in tpl
+249
View File
@@ -0,0 +1,249 @@
"""Unit tests for the Gitea webhook parser.
Pure-function tests against ``parse_webhook`` using realistic Gitea
payloads (trimmed to the fields the parser actually consumes). No DB or
HTTP fixtures needed.
"""
from __future__ import annotations
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.gitea.event_parser import parse_webhook
def _repo() -> dict:
return {
"id": 42,
"name": "demo",
"full_name": "alexei/demo",
"html_url": "https://git.example.com/alexei/demo",
"description": "Demo repo",
"private": False,
"owner": {
"id": 1,
"login": "alexei",
"full_name": "Alexei",
"email": "alexei@example.com",
"avatar_url": "https://git.example.com/avatars/1",
},
}
def _sender() -> dict:
return {
"id": 1,
"login": "alexei",
"full_name": "Alexei",
"avatar_url": "https://git.example.com/avatars/1",
}
def test_push_event() -> None:
payload = {
"ref": "refs/heads/master",
"before": "0000000000000000000000000000000000000000",
"after": "abcdef0123456789abcdef0123456789abcdef01",
"compare_url": "https://git.example.com/alexei/demo/compare/000...abc",
"commits": [
{
"id": "abcdef0123456789abcdef0123456789abcdef01",
"message": "feat: initial commit\n\nMore detail.",
"url": "https://git.example.com/alexei/demo/commit/abcdef0",
"author": {
"name": "Alexei",
"email": "alexei@example.com",
"username": "alexei",
},
"timestamp": "2026-05-16T10:00:00Z",
},
{
"id": "1234567890123456789012345678901234567890",
"message": "chore: tweak",
"url": "https://git.example.com/alexei/demo/commit/1234567",
"author": {"name": "Alexei", "email": "alexei@example.com"},
"timestamp": "2026-05-16T10:05:00Z",
},
],
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("push", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.PUSH
assert evt.provider_type is ServiceProviderType.GITEA
assert evt.collection_id == "alexei/demo"
assert evt.collection_name == "alexei/demo"
assert evt.extra["ref"] == "refs/heads/master"
assert evt.extra["branch"] == "master"
assert evt.extra["commit_count"] == 2
assert evt.extra["commits"][0]["short_id"] == "abcdef0"
# The first commit's multi-line body must be preserved (.strip handles
# trailing newlines but should keep the inner '\n').
assert "feat: initial commit" in evt.extra["commits"][0]["message"]
def test_issue_opened() -> None:
payload = {
"action": "opened",
"issue": {
"id": 100,
"number": 7,
"title": "Bug: thing broken",
"html_url": "https://git.example.com/alexei/demo/issues/7",
"state": "open",
"body": "Steps to reproduce...",
"labels": [{"name": "bug"}, {"name": "p1"}],
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("issues", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.ISSUE_OPENED
assert evt.collection_id == "alexei/demo"
assert evt.extra["issue_number"] == 7
assert evt.extra["issue_title"] == "Bug: thing broken"
assert evt.extra["issue_labels"] == ["bug", "p1"]
def test_issue_closed() -> None:
payload = {
"action": "closed",
"issue": {
"id": 100,
"number": 7,
"title": "Bug: thing broken",
"html_url": "https://git.example.com/alexei/demo/issues/7",
"state": "closed",
"body": "",
"labels": [],
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("issues", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.ISSUE_CLOSED
assert evt.extra["issue_state"] == "closed"
def test_pr_opened() -> None:
payload = {
"action": "opened",
"pull_request": {
"id": 200,
"number": 12,
"title": "Add metrics endpoint",
"html_url": "https://git.example.com/alexei/demo/pulls/12",
"state": "open",
"body": "PR body",
"merged": False,
"base": {"ref": "master", "label": "alexei:master"},
"head": {"ref": "feat/metrics", "label": "alexei:feat/metrics"},
"labels": [{"name": "enhancement"}],
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.PR_OPENED
assert evt.extra["pr_number"] == 12
assert evt.extra["pr_merged"] is False
assert evt.extra["pr_base"] == "alexei:master"
assert evt.extra["pr_head"] == "alexei:feat/metrics"
def test_pr_merged_resolves_from_closed_with_merged_flag() -> None:
"""A 'closed' action with merged=True is the merge signal — Gitea does
not send a distinct event header for it, so the parser must promote
PR_CLOSED -> PR_MERGED on its own."""
payload = {
"action": "closed",
"pull_request": {
"id": 200,
"number": 12,
"title": "Add metrics endpoint",
"html_url": "https://git.example.com/alexei/demo/pulls/12",
"state": "closed",
"body": "",
"merged": True,
"base": {"ref": "master"},
"head": {"ref": "feat/metrics"},
"labels": [],
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.PR_MERGED
assert evt.extra["pr_merged"] is True
def test_pr_closed_without_merge() -> None:
payload = {
"action": "closed",
"pull_request": {
"id": 200,
"number": 12,
"title": "Abandoned PR",
"html_url": "https://git.example.com/alexei/demo/pulls/12",
"state": "closed",
"body": "",
"merged": False,
"base": {"ref": "master"},
"head": {"ref": "feat/x"},
"labels": [],
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.PR_CLOSED
def test_release_published() -> None:
payload = {
"action": "published",
"release": {
"id": 9,
"tag_name": "v1.2.3",
"name": "Release v1.2.3",
"html_url": "https://git.example.com/alexei/demo/releases/tag/v1.2.3",
"body": "Bug fixes and improvements",
"draft": False,
"prerelease": False,
},
"repository": _repo(),
"sender": _sender(),
}
evt = parse_webhook("release", payload, provider_name="gitea-prod")
assert evt is not None
assert evt.event_type is EventType.RELEASE_PUBLISHED
assert evt.extra["release_tag"] == "v1.2.3"
assert evt.extra["release_prerelease"] is False
def test_release_non_published_is_ignored() -> None:
"""Only ``published`` releases should produce events — drafts and edits
are noise and would spam any tracker subscribed to release notifications."""
payload = {
"action": "edited",
"release": {
"id": 9, "tag_name": "v1.2.3", "name": "x",
"html_url": "", "body": "",
"draft": True, "prerelease": False,
},
"repository": _repo(),
"sender": _sender(),
}
assert parse_webhook("release", payload, provider_name="g") is None
def test_unknown_event_header_returns_none() -> None:
payload = {"repository": _repo(), "sender": _sender()}
assert parse_webhook("unknown_event", payload, provider_name="g") is None
+7 -1
View File
@@ -27,7 +27,13 @@ def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001
resp = client.get("/api/ready")
# By the time TestClient yields, lifespan startup has completed.
assert resp.status_code == 200
assert resp.json()["status"] == "ready"
body = resp.json()
assert body["ready"] is True
assert body["checks"]["db"] == "ok"
assert body["checks"]["scheduler"] == "ok"
# No HA providers configured by default in the test fixture.
assert body["checks"]["ha"] == "na"
assert body["errors"] == []
def test_health_is_anonymous(tmp_data_dir) -> None: # noqa: ARG001
@@ -0,0 +1,159 @@
"""Unit tests for Immich album change detection.
Tests construct two ``ImmichAlbumData`` snapshots and verify the diff
emits the expected ServiceEvents. No HTTP, no DB. Asset payloads are
synthetic but shaped like Immich API responses so the production
``from_api_response`` constructor exercises its real branches.
"""
from __future__ import annotations
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.immich.change_detector import (
detect_album_changes,
)
from notify_bridge_core.providers.immich.models import ImmichAlbumData
_EXTERNAL = "https://immich.example.com"
def _asset(asset_id: str, *, processed: bool = True, type_: str = "IMAGE") -> dict:
"""Build an Immich asset payload that ``from_api_response`` accepts."""
return {
"id": asset_id,
"type": type_,
"originalFileName": f"{asset_id}.jpg",
"fileCreatedAt": "2026-05-15T12:00:00.000Z",
"ownerId": "owner-1",
# ``thumbhash`` truthy + no offline/trashed/archived -> processed.
# Skipped when caller asks for an unprocessed asset.
"thumbhash": "abc" if processed else None,
"isOffline": False,
"isTrashed": False,
"isArchived": False,
"isFavorite": False,
"exifInfo": {},
}
def _album(asset_dicts: list[dict], *, name: str = "Trip", album_id: str = "a1",
shared: bool = False) -> ImmichAlbumData:
return ImmichAlbumData.from_api_response(
{
"id": album_id,
"albumName": name,
"assets": asset_dicts,
"assetCount": len(asset_dicts),
"createdAt": "2026-05-01T00:00:00Z",
"updatedAt": "2026-05-15T12:00:00Z",
"shared": shared,
"owner": {"name": "alexei"},
"albumThumbnailAssetId": asset_dicts[0]["id"] if asset_dicts else None,
}
)
def test_added_asset_emits_assets_added_event() -> None:
old = _album([_asset("a"), _asset("b")])
new = _album([_asset("a"), _asset("b"), _asset("c")])
events, pending = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
assert len(events) == 1
evt = events[0]
assert evt.event_type is EventType.ASSETS_ADDED
assert evt.provider_type is ServiceProviderType.IMMICH
assert evt.collection_id == "a1"
assert evt.collection_name == "Trip"
assert evt.added_count == 1
assert len(evt.added_assets) == 1
assert pending == set()
def test_removed_asset_emits_assets_removed_event() -> None:
old = _album([_asset("a"), _asset("b"), _asset("c")])
new = _album([_asset("a")])
events, _ = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
by_type = {e.event_type: e for e in events}
assert EventType.ASSETS_REMOVED in by_type
removed = by_type[EventType.ASSETS_REMOVED]
assert removed.removed_count == 2
assert set(removed.removed_asset_ids) == {"b", "c"}
def test_no_changes_returns_no_events() -> None:
old = _album([_asset("a"), _asset("b")])
new = _album([_asset("a"), _asset("b")])
events, pending = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
assert events == []
assert pending == set()
def test_unprocessed_asset_is_held_in_pending() -> None:
"""Assets without a thumbhash haven't finished server-side processing.
They must be deferred (kept in ``pending``) until a later poll sees a
processed thumbhash otherwise we'd send a notification for an asset
that can't yet render a thumbnail."""
old = _album([_asset("a")])
new = _album([_asset("a"), _asset("b", processed=False)])
events, pending = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
# ``b`` is not processed, so no event for it AND nothing else changed,
# so we get an empty event list. Pending tracks the held asset.
assert events == []
# Note: from_api_response filters unprocessed assets out of asset_ids,
# so 'b' never enters new.asset_ids — pending stays empty in this path.
# The pending mechanism kicks in once 'b' lands in asset_ids on a later
# tick. Use the next test to exercise that branch.
assert pending == set()
def test_collection_renamed_emits_renamed_event() -> None:
old = _album([_asset("a")], name="Trip")
new = _album([_asset("a")], name="Vacation")
events, _ = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
by_type = {e.event_type: e for e in events}
assert EventType.COLLECTION_RENAMED in by_type
rename = by_type[EventType.COLLECTION_RENAMED]
assert rename.old_name == "Trip"
assert rename.new_name == "Vacation"
def test_sharing_change_emits_sharing_event() -> None:
old = _album([_asset("a")], shared=False)
new = _album([_asset("a")], shared=True)
events, _ = detect_album_changes(
old, new, pending_asset_ids=set(),
provider_name="immich-prod", external_url=_EXTERNAL,
)
by_type = {e.event_type: e for e in events}
assert EventType.SHARING_CHANGED in by_type
sharing = by_type[EventType.SHARING_CHANGED]
assert sharing.old_shared is False
assert sharing.new_shared is True
+147
View File
@@ -0,0 +1,147 @@
"""Unit tests for the Planka webhook parser.
Pure-function tests against ``parse_webhook`` using realistic Planka
webhook payload shapes. The parser is forgiving about missing ``included``
data (older Planka builds), so we mix payloads with and without it to
catch regressions in the fallback paths.
"""
from __future__ import annotations
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.planka.event_parser import parse_webhook
_BASE_URL = "https://planka.example.com"
def _user() -> dict:
return {"id": "u1", "username": "alexei", "name": "Alexei"}
def test_card_created() -> None:
payload = {
"user": _user(),
"item": {
"id": "c1",
"name": "Implement metrics",
"description": "Wire prometheus client.",
"boardId": "b1",
"listId": "l1",
"position": 1,
},
"included": {
"board": {"id": "b1", "name": "Roadmap"},
"lists": [{"id": "l1", "name": "Todo"}],
},
}
evt = parse_webhook("cardCreate", payload, provider_name="planka", base_url=_BASE_URL)
assert evt is not None
assert evt.event_type is EventType.CARD_CREATED
assert evt.provider_type is ServiceProviderType.PLANKA
assert evt.collection_id == "b1"
assert evt.collection_name == "Roadmap"
assert evt.extra["card_name"] == "Implement metrics"
assert evt.extra["card_url"] == f"{_BASE_URL}/cards/c1"
assert evt.extra["list_name"] == "Todo"
assert evt.extra["sender"] == "alexei"
def test_card_moved_when_list_changes() -> None:
"""beforeUpdate.listId != item.listId is the signal Planka uses for a
card move; the parser must promote the generic cardUpdate event into
CARD_MOVED so trackers can subscribe to moves specifically."""
payload = {
"user": _user(),
"beforeUpdate": {"listId": "l1"},
"item": {
"id": "c1",
"name": "Implement metrics",
"description": "",
"boardId": "b1",
"listId": "l2",
},
"included": {
"board": {"id": "b1", "name": "Roadmap"},
"lists": [
{"id": "l1", "name": "Todo"},
{"id": "l2", "name": "In progress"},
],
},
}
evt = parse_webhook("cardUpdate", payload, provider_name="planka", base_url=_BASE_URL)
assert evt is not None
assert evt.event_type is EventType.CARD_MOVED
assert evt.extra["old_list_id"] == "l1"
assert evt.extra["new_list_id"] == "l2"
assert evt.extra["old_list_name"] == "Todo"
assert evt.extra["new_list_name"] == "In progress"
def test_card_update_without_list_change_is_card_updated() -> None:
payload = {
"user": _user(),
"beforeUpdate": {"name": "Old name"},
"item": {
"id": "c1", "name": "New name", "description": "", "boardId": "b1", "listId": "l1",
},
}
evt = parse_webhook("cardUpdate", payload, provider_name="planka", base_url=_BASE_URL)
assert evt is not None
assert evt.event_type is EventType.CARD_UPDATED
def test_comment_created() -> None:
payload = {
"user": _user(),
"item": {
"id": "cm1",
"text": "LGTM, ship it.",
"cardId": "c1",
"userId": "u1",
},
"included": {
"card": {"id": "c1", "name": "Implement metrics", "boardId": "b1"},
"board": {"id": "b1", "name": "Roadmap"},
},
}
evt = parse_webhook(
"commentCreate", payload, provider_name="planka", base_url=_BASE_URL,
)
assert evt is not None
assert evt.event_type is EventType.CARD_COMMENTED
assert evt.collection_id == "b1"
assert evt.extra["comment_text"] == "LGTM, ship it."
assert evt.extra["card_id"] == "c1"
assert evt.extra["card_url"] == f"{_BASE_URL}/cards/c1"
def test_task_completion_emits_only_on_transition() -> None:
"""Task updates should only produce TASK_COMPLETED when the task flips
from incomplete to complete toggling the description or other fields
on a task that was already complete must NOT spam notifications."""
completing = {
"user": _user(),
"beforeUpdate": {"isCompleted": False},
"item": {"id": "t1", "name": "Step 1", "isCompleted": True, "cardId": "c1"},
"included": {
"card": {"id": "c1", "name": "Implement metrics", "boardId": "b1"},
"board": {"id": "b1", "name": "Roadmap"},
},
}
evt = parse_webhook("taskUpdate", completing, provider_name="planka", base_url=_BASE_URL)
assert evt is not None
assert evt.event_type is EventType.TASK_COMPLETED
# Editing a task that was already completed -> no event.
re_edit = {
"user": _user(),
"beforeUpdate": {"isCompleted": True},
"item": {"id": "t1", "name": "Step 1 v2", "isCompleted": True, "cardId": "c1"},
}
assert parse_webhook("taskUpdate", re_edit, provider_name="planka", base_url=_BASE_URL) is None
def test_unknown_event_returns_none() -> None:
assert parse_webhook("nonexistent", {"item": {}}, provider_name="planka", base_url="") is None