feat: production readiness — security, perf, bug fixes, bridge self-monitoring
Comprehensive multi-area pass driven by a parallel 8-agent production
review, covering frontend, backend, database, security, performance, and
operations, plus a new self-monitoring feature.
## Critical fixes
- Planka webhook: reads bounded raw body (was NameError on every call)
- HA quiet hours: ha_state_changed/automation_triggered/service_called/
event_fired added to deferrable set (were silently dropped)
- DNS-rebinding SSRF: PinnedResolver wired into shared aiohttp session
- Telegram inbound webhook: secret now mandatory (401 without)
- Generic webhook: auth_mode="none" requires explicit
acknowledge_unauthenticated=true; per-IP rate limit 60/min
- svelte-check: 5 null-narrowing errors in EventDetailModal fixed
- Provider hardcoding: Immich-only block extracted to descriptor
featureDiscoveryHint
- command_sync: snapshot+expunge bot before exiting AsyncSession
## Bug fixes
- notifier asyncio.gather(return_exceptions=True) — one bad chat no longer
cancels peer sends
- NotificationDispatcher hoisted out of per-tracker loop
- Provider credential resolution unified across all 5 dispatch sites
- HA asyncio.shield now drains inner task on cancellation
- Provider construction switched from if/elif ladder to factory registry
- NUT first poll seeds silently (no spurious ups_on_battery)
- Quiet-hours gate: event-type-disabled now wins over deferral
- APScheduler drain job ID resolution upgraded to seconds
- HA on_status_change wired through to EventLog
- Webhook payload rollback failures now logged (not swallowed)
- Batched receivers/chats/bots in load_link_data (was per-target N+1)
- flag_modified on JSON column reassignments in deferred_dispatch
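
A minimal sketch of the `asyncio.gather(return_exceptions=True)` item above. The names `send_to_chat` and `notify_all` are hypothetical and are not the notifier's real functions; the sketch only shows how a failing send becomes a per-chat result instead of an exception that escapes the fan-out.

```python
import asyncio


async def send_to_chat(chat_id: int) -> str:
    # Hypothetical sender: chat 2 always fails, the others succeed.
    await asyncio.sleep(0)
    if chat_id == 2:
        raise RuntimeError(f"chat {chat_id} unreachable")
    return f"sent:{chat_id}"


async def notify_all(chat_ids: list[int]) -> list[dict[str, object]]:
    # return_exceptions=True turns each failure into a result value, so one
    # failing send does not propagate out of gather and discard the others.
    results = await asyncio.gather(
        *(send_to_chat(c) for c in chat_ids), return_exceptions=True
    )
    return [
        {
            "chat_id": c,
            "success": not isinstance(r, BaseException),
            "detail": str(r),
        }
        for c, r in zip(chat_ids, results)
    ]


outcome = asyncio.run(notify_all([1, 2, 3]))
```

Each entry in `outcome` carries its own success flag, so the caller can log or count the failed chat while the other sends are still reported.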
## Database
- UNIQUE indexes on service_provider.webhook_token,
telegram_bot.webhook_path_id, partial UNIQUE on telegram_bot.bot_id,
telegram_chat(bot_id, chat_id), notification_tracker_target unique link,
partial UNIQUE on bridge_self provider per user
- Composite ix_event_log_user_event_type_created index
- save_chat_from_webhook switched to ON CONFLICT DO UPDATE
- ondelete=CASCADE on user-id FKs (model annotation; app-side cascade
delete added for existing data)
- delete_notification_tracker converted from N+1 to bulk DELETE/UPDATE
- Module-level asyncio.Lock replaced with lazy _get_lock() pattern
- VACUUM INTO snapshot now PRAGMA integrity_check verified
## Performance
- Jinja2 template compilation LRU cached (lru_cache maxsize=512)
- Per-locale render cache in NotificationDispatcher (skips re-rendering
identical content for receivers sharing a locale)
- Tracker list cached per provider_id with 5s TTL + explicit invalidation
on tracker CRUD (relieves HA chat-bus rate query pressure)
- Nav-counts collapsed from 16 round-trips to single UNION ALL
- HA event_log: skip persisting empty assets_added/removed events
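
A rough sketch of the "5s TTL + explicit invalidation" idea from the tracker-list item above. `TTLCache`, `load_trackers`, and the injectable clock are assumptions made for illustration; the real cache in the bridge may be shaped differently.

```python
import time
from typing import Callable


class TTLCache:
    """Tiny per-key cache with a time-to-live and explicit invalidation."""

    def __init__(
        self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic
    ) -> None:
        self._ttl = ttl_seconds
        self._clock = clock  # injectable so expiry can be exercised without sleeping
        self._entries: dict[int, tuple[float, list[str]]] = {}

    def get(self, provider_id: int, loader: Callable[[int], list[str]]) -> list[str]:
        now = self._clock()
        entry = self._entries.get(provider_id)
        if entry is not None and now - entry[0] < self._ttl:
            return entry[1]  # still fresh: skip the loader (the DB query)
        value = loader(provider_id)
        self._entries[provider_id] = (now, value)
        return value

    def invalidate(self, provider_id: int) -> None:
        # Called from create/update/delete paths so edits show up immediately.
        self._entries.pop(provider_id, None)


calls: list[int] = []
now = [0.0]


def load_trackers(provider_id: int) -> list[str]:
    # Hypothetical loader standing in for the tracker query.
    calls.append(provider_id)
    return [f"tracker-{provider_id}-{len(calls)}"]


cache = TTLCache(ttl_seconds=5.0, clock=lambda: now[0])
first = cache.get(7, load_trackers)
second = cache.get(7, load_trackers)  # served from cache
now[0] = 6.0
third = cache.get(7, load_trackers)   # TTL expired, reloaded
cache.invalidate(7)
fourth = cache.get(7, load_trackers)  # invalidated, reloaded
```

The short TTL bounds staleness for writes that bypass the invalidation hook, while the explicit `invalidate` keeps normal CRUD edits immediately visible.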
## Security hardening
- Mass-assignment guard on Action create/update; cron sub-minute reject
- Backup JSON depth/node-count cap (depth ≤ 10, nodes ≤ 100k)
- _sanitize_config extended to all JSON-typed fields on backup import
- Telegram _safe_get walks redirects manually with SSRF revalidation
- Bcrypt 72-byte password length cap with clear 422
- Webhook payload body redaction; sensitive substring set extended with
oauth/client_secret/webhook_secret/csrf in both header filter and
template extras filter
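
One way the backup depth/node-count cap above could be enforced, sketched under assumptions: `check_json_limits` is a hypothetical helper, and counting the root container as depth 1 is a choice made here, not something the commit states.

```python
import json
from typing import Any

MAX_DEPTH = 10        # from the commit message: depth <= 10
MAX_NODES = 100_000   # from the commit message: nodes <= 100k


def check_json_limits(
    root: Any, max_depth: int = MAX_DEPTH, max_nodes: int = MAX_NODES
) -> None:
    """Raise ValueError if decoded JSON is nested too deeply or has too many nodes."""
    nodes = 0
    # Explicit stack instead of recursion, so hostile input cannot exhaust
    # the Python call stack during the check itself.
    stack: list[tuple[Any, int]] = [(root, 1)]
    while stack:
        value, depth = stack.pop()
        nodes += 1
        if nodes > max_nodes:
            raise ValueError(f"too many nodes (> {max_nodes})")
        if isinstance(value, (dict, list)):
            if depth > max_depth:
                raise ValueError(f"nesting too deep (> {max_depth})")
            children = value.values() if isinstance(value, dict) else value
            stack.extend((child, depth + 1) for child in children)


backup = json.loads('{"providers": [{"name": "immich", "config": {"url": null}}]}')
check_json_limits(backup)  # well within both limits, so no exception
```

Running the check right after decoding and before any import logic keeps oversized documents away from the code that walks and persists them.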
## Frontend
- 76 catch (err: any) sites converted to errMsg(err) helper
- globalProviderFilter: pure getter; reconciliation moved to one-time
$effect in +layout
- Provider-filter binding: removed paired $effects + _syncingFilter flag,
now one-way derived
- entity-cache: separate _refreshing flag for background re-fetches
- api.ts 401 handling: AuthRedirectError class + dedup _redirecting flag,
goto() instead of window.location.href
- a11y: aria-expanded on mobile More, role=switch + aria-checked on
Telegram bot toggles
## Tests & operations
- CI pytest gate added to .gitea/workflows/build.yml + release.yml
(wheel-built install to dodge editable-install slowness)
- /api/ready upgraded to deep healthcheck (db SELECT 1, scheduler.running,
HA supervisor presence) returning {ready, checks, errors, version}
- /api/metrics endpoint with prometheus_client (deferred_pending,
event_log_total, dispatch_duration, poll_failures, send_failures)
- New OPERATIONS.md covering deploy, healthchecks, metrics, backup/restore
procedures, log handling, common scenarios, upgrade flow
- New tests: test_bridge_self (11), test_gitea_parser (9),
test_planka_parser (6), test_immich_change_detector (6),
test_backup_roundtrip (1)
## New feature: bridge self-monitoring
- New bridge_self provider type — internal sink for bridge health events
- Three event types: bridge_self_poll_failures (consecutive tracker poll
failures), bridge_self_deferred_backlog (pending count crosses
threshold), bridge_self_target_failures (consecutive 5xx/network
failures per target)
- Per-user thresholds (defaults: 3 / 100 / 5) configurable via the
provider config form
- Auto-seeded on user create + /setup + boot backfill for existing users
- Anti-spam: counters reset after emission; backlog uses transition latch
- Self-loop guard: bridge_self failures don't count toward target-failure
thresholds (logged only). Wire the provider to your own
Telegram/Email/Matrix target to get notified when polls/dispatches/sends
fail
- 6 default templates (3 events × 2 locales), tracking config columns
with backfill migration, frontend descriptor (excluded from "create
provider" wizard since auto-managed)
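
The anti-spam rules above (counters reset after emission, backlog behind a transition latch) can be sketched roughly as follows. `FailureCounter` and `BacklogLatch` are hypothetical names, and treating "crosses" as `>=` is an assumption.

```python
from dataclasses import dataclass


@dataclass
class FailureCounter:
    """Counts consecutive failures and fires once per full run of failures."""

    threshold: int = 3
    count: int = 0

    def record(self, ok: bool) -> bool:
        if ok:
            self.count = 0  # any success breaks the consecutive run
            return False
        self.count += 1
        if self.count >= self.threshold:
            self.count = 0  # reset after emission: next alert needs a full run again
            return True
        return False


@dataclass
class BacklogLatch:
    """Fires only on the transition from below the threshold to at or above it."""

    threshold: int = 100
    tripped: bool = False

    def observe(self, pending: int) -> bool:
        above = pending >= self.threshold
        fire = above and not self.tripped
        self.tripped = above  # re-arms once the backlog drops back below
        return fire


counter = FailureCounter(threshold=3)
poll_alerts = [counter.record(ok=False) for _ in range(7)]

latch = BacklogLatch(threshold=100)
backlog_alerts = [latch.observe(n) for n in (50, 120, 150, 80, 100)]
```

With these rules a tracker that keeps failing alerts once per three failures rather than on every poll, and a backlog that stays high alerts once until it recovers and climbs again.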
Operator-visible behavior changes (call out in release notes):
- NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET now REQUIRED for webhook mode
- Existing webhook providers with auth_mode="none" need explicit opt-in
- Generic webhook endpoint rate-limited 60/min per source IP
- HA disconnect/reconnect writes ha_status_* EventLog rows
- Every user gets a bridge_self provider — wire it to a target to
receive failure alerts
Pre-existing test failures (test_ssrf, test_release_provider) on
Python 3.13 are unrelated; CI runs on 3.12.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -71,6 +71,12 @@ class EventType(str, Enum):
     HA_SERVICE_CALLED = "ha_service_called"
     HA_EVENT_FIRED = "ha_event_fired"

+    # Bridge self-monitoring events — emitted by the bridge itself when
+    # internal failures cross configured thresholds.
+    BRIDGE_SELF_POLL_FAILURES = "bridge_self_poll_failures"
+    BRIDGE_SELF_DEFERRED_BACKLOG = "bridge_self_deferred_backlog"
+    BRIDGE_SELF_TARGET_FAILURES = "bridge_self_target_failures"
+

 @dataclass
 class ServiceEvent:
@@ -107,6 +107,12 @@ class NotificationDispatcher:
         # Optional shared session owned by the caller; when supplied we reuse
         # its connection pool instead of opening a fresh per-dispatch session.
         self._shared_session = session
+        # Per-dispatch render cache, keyed by locale. Populated by
+        # ``_send_to_target`` and consumed inside ``_message_for_receiver``
+        # so a 100-receiver fan-out renders each unique locale once.
+        # Initialized to empty so handlers called outside the normal
+        # dispatch path (tests) still see a valid dict.
+        self._render_cache: dict[str, str] = {}

     @contextlib.asynccontextmanager
     async def _session_ctx(self) -> AsyncIterator[aiohttp.ClientSession]:
@@ -198,20 +204,49 @@ class NotificationDispatcher:
     def _message_for_receiver(
         self, receiver: Receiver, default_message: str,
         event: ServiceEvent, target: TargetConfig,
+        cache: dict[str, str] | None = None,
     ) -> str:
-        if receiver.locale and receiver.locale != target.locale:
-            return self._render_message(event, target, receiver.locale)
-        return default_message
+        """Render message respecting receiver locale, with optional cache.
+
+        The ``cache`` dict (typically created in ``_send_to_target`` and
+        threaded through the per-channel ``_send_*`` handlers) memoizes
+        per-locale renders so a 100-receiver fan-out with two locales
+        renders twice instead of one hundred times.
+        """
+        loc = receiver.locale or target.locale
+        if loc == target.locale:
+            return default_message
+        if cache is not None:
+            cached = cache.get(loc)
+            if cached is not None:
+                return cached
+            rendered = self._render_message(event, target, loc)
+            cache[loc] = rendered
+            return rendered
+        return self._render_message(event, target, loc)

     async def _send_to_target(
         self, event: ServiceEvent, target: TargetConfig
     ) -> dict[str, Any]:
-        """Dispatch to a single target via the registered handler."""
+        """Dispatch to a single target via the registered handler.
+
+        Builds a per-locale render cache once and threads it through the
+        send handler. The cache is keyed by receiver locale; the default
+        locale's render lives in ``default_message`` and is short-circuited
+        before any cache lookup.
+        """
         default_message = self._render_message(event, target, target.locale)
         send_method = _PROVIDER_HANDLERS.get(target.type)
         if send_method is None:
             return {"success": False, "error": f"Unknown target type: {target.type}"}
-        return await send_method(self, target, default_message, event)
+        # Stash the cache on the dispatcher instance for the duration of
+        # this dispatch — handlers pick it up via _message_for_receiver.
+        # Avoids changing every _send_* signature.
+        self._render_cache: dict[str, str] = {}
+        try:
+            return await send_method(self, target, default_message, event)
+        finally:
+            self._render_cache = {}

     # ------------------------------------------------------------------
     # Asset preload (Telegram-specific)
@@ -352,7 +387,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, TelegramReceiver) or not receiver.chat_id:
                 return {"success": False, "error": "Invalid telegram receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             text_result = await client.send_message(
                 chat_id=receiver.chat_id,
                 text=message,
@@ -407,7 +442,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, WebhookReceiver) or not receiver.url:
                 return {"success": False, "error": "Invalid webhook receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             payload = {
                 "message": message,
                 "event_type": event.event_type.value,
@@ -450,7 +485,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, EmailReceiver) or not receiver.email:
                 return {"success": False, "error": "Invalid email receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             # body_html=None lets EmailClient build a safely-escaped HTML
             # alternative from body_text instead of trusting user content.
             return await email_client.send(
@@ -479,7 +514,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, DiscordReceiver) or not receiver.webhook_url:
                 return {"success": False, "error": "Invalid discord receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             return await client.send(receiver.webhook_url, message, username=username)

         results = await self._fan_out(target.receivers, send_one)
@@ -501,7 +536,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, SlackReceiver) or not receiver.webhook_url:
                 return {"success": False, "error": "Invalid slack receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             return await client.send(receiver.webhook_url, message, username=username)

         results = await self._fan_out(target.receivers, send_one)
@@ -530,7 +565,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, NtfyReceiver) or not receiver.topic:
                 return {"success": False, "error": "Invalid ntfy receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             return await client.send(
                 server_url, receiver.topic, message,
                 title=title, priority=receiver.priority, auth_token=auth_token,
@@ -563,7 +598,7 @@ class NotificationDispatcher:
         async def send_one(receiver: Receiver) -> dict[str, Any]:
             if not isinstance(receiver, MatrixReceiver) or not receiver.room_id:
                 return {"success": False, "error": "Invalid matrix receiver"}
-            message = self._message_for_receiver(receiver, default_message, event, target)
+            message = self._message_for_receiver(receiver, default_message, event, target, cache=self._render_cache)
             # body_html is the same plain text — Matrix accepts the
             # raw message as both ``body`` and ``formatted_body``.
             # If templates emit HTML in the future, generate a
@@ -222,21 +222,48 @@ class TelegramClient:
         """SSRF-guarded GET that returns ``(data, error)``.

         Validates the URL via ``avalidate_outbound_url`` before any HTTP
-        traffic. Errors are returned (not raised) and stripped of any
-        embedded secrets before they propagate to the operator-visible
-        result dict.
+        traffic. Redirects are walked manually so each ``Location`` is
+        re-validated — without this an attacker-controlled origin could
+        302 to a private-IP target after the initial guard passed.
+        Errors are returned (not raised) and stripped of any embedded
+        secrets before they propagate to the operator-visible result
+        dict.
         """
+        max_redirects = 3
+        current_url = url
         try:
-            await avalidate_outbound_url(url)
+            await avalidate_outbound_url(current_url)
         except UnsafeURLError as err:
             return None, f"Unsafe URL: {redact_exc(err)}"
         try:
-            async with self._session.get(
-                url, headers=headers or {}, timeout=_DOWNLOAD_TIMEOUT,
-            ) as resp:
-                if resp.status != 200:
-                    return None, f"HTTP {resp.status}"
-                return await resp.read(), None
+            for _ in range(max_redirects + 1):
+                async with self._session.get(
+                    current_url,
+                    headers=headers or {},
+                    timeout=_DOWNLOAD_TIMEOUT,
+                    allow_redirects=False,
+                ) as resp:
+                    if resp.status in (301, 302, 303, 307, 308):
+                        loc = resp.headers.get("Location")
+                        if not loc:
+                            return None, f"HTTP {resp.status} without Location header"
+                        # ``resp.url`` is a yarl.URL; ``.join`` resolves
+                        # relative redirects (``/foo/bar``) against it.
+                        from yarl import URL as _URL
+                        try:
+                            next_url = str(resp.url.join(_URL(loc)))
+                        except (ValueError, TypeError):
+                            return None, "Malformed redirect Location"
+                        try:
+                            await avalidate_outbound_url(next_url)
+                        except UnsafeURLError as err:
+                            return None, f"Unsafe redirect: {redact_exc(err)}"
+                        current_url = next_url
+                        continue
+                    if resp.status != 200:
+                        return None, f"HTTP {resp.status}"
+                    return await resp.read(), None
+            return None, f"Too many redirects (>{max_redirects})"
         except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
             return None, redact_exc(err)

@@ -22,6 +22,7 @@ class ServiceProviderType(str, Enum):
     GOOGLE_PHOTOS = "google_photos"
     WEBHOOK = "webhook"
     HOME_ASSISTANT = "home_assistant"
+    BRIDGE_SELF = "bridge_self"


 # Callback signature for push-style providers: a coroutine that accepts a
@@ -0,0 +1,39 @@
+"""Bridge self-monitoring service provider.
+
+Unlike external providers (Immich, Gitea, NUT, ...), the ``bridge_self``
+provider does not connect to any remote service. Its sole purpose is to
+give operators a configurable surface (thresholds + notification slots
++ trackers + targets) for events that the bridge itself emits when its
+internal subsystems fail.
+
+Three failure conditions are surfaced as :class:`ServiceEvent` instances
+through the same dispatch pipeline that all other providers use:
+
+* ``bridge_self_poll_failures`` — N consecutive poll failures for
+  any tracker exceed the configured threshold.
+* ``bridge_self_deferred_backlog`` — pending ``deferred_dispatch`` row
+  count crosses the configured threshold.
+* ``bridge_self_target_failures`` — N consecutive 5xx / network failures
+  for a single notification target.
+
+Events are constructed by ``services/bridge_self.py`` on the server side
+(it owns DB access for looking up the bridge_self provider per user)
+and then fed into ``dispatch_provider_event`` like any other event.
+"""
+
+from notify_bridge_core.providers.base import ServiceProviderType
+from notify_bridge_core.templates.variables import registry
+
+from .event_parser import build_event
+from .provider import BRIDGE_SELF_VARIABLES, BridgeSelfServiceProvider
+
+# Register variables so the validator and template-vars API see them.
+registry.register_provider_variables(
+    ServiceProviderType.BRIDGE_SELF, BRIDGE_SELF_VARIABLES,
+)
+
+__all__ = [
+    "BRIDGE_SELF_VARIABLES",
+    "BridgeSelfServiceProvider",
+    "build_event",
+]
@@ -0,0 +1,89 @@
+"""Bridge self-monitoring event parser.
+
+The bridge generates these events from internal subsystems (watcher,
+scheduler, dispatcher) — the parser turns a flat payload dict into the
+generic :class:`ServiceEvent` shape that the rest of the dispatch
+pipeline expects.
+
+Payload shape::
+
+    {
+        "failure_type": "poll_failures" | "deferred_backlog" | "target_failures",
+        "subject_id": int, # tracker_id, target_id, or 0
+        "subject_name": str,
+        "count": int, # consecutive failures or pending count
+        "threshold": int,
+        "last_error": str, # may be empty
+        "details": dict[str, Any], # extra context
+    }
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Any
+
+from notify_bridge_core.models.events import EventType, ServiceEvent
+from notify_bridge_core.providers.base import ServiceProviderType
+
+
+# Defensive cap on the persisted error message; very long tracebacks would
+# bloat the EventLog details JSON column otherwise.
+_MAX_ERROR_LEN = 1000
+
+
+_FAILURE_TYPE_TO_EVENT: dict[str, EventType] = {
+    "poll_failures": EventType.BRIDGE_SELF_POLL_FAILURES,
+    "deferred_backlog": EventType.BRIDGE_SELF_DEFERRED_BACKLOG,
+    "target_failures": EventType.BRIDGE_SELF_TARGET_FAILURES,
+}
+
+
+def build_event(
+    payload: dict[str, Any],
+    *,
+    provider_name: str = "Bridge Self-Monitoring",
+    timestamp: datetime | None = None,
+) -> ServiceEvent | None:
+    """Convert a self-monitoring payload dict into a ServiceEvent.
+
+    Returns None for malformed payloads (unknown failure_type or missing
+    keys) — the caller drops without raising so a misbehaving emitter
+    can never tip over the dispatch pipeline.
+    """
+    if not isinstance(payload, dict):
+        return None
+    failure_type = payload.get("failure_type")
+    event_type = _FAILURE_TYPE_TO_EVENT.get(str(failure_type) if failure_type else "")
+    if event_type is None:
+        return None
+
+    subject_id = int(payload.get("subject_id") or 0)
+    subject_name = str(payload.get("subject_name") or "")
+    count = int(payload.get("count") or 0)
+    threshold = int(payload.get("threshold") or 0)
+    last_error = str(payload.get("last_error") or "")[:_MAX_ERROR_LEN]
+    details = payload.get("details") if isinstance(payload.get("details"), dict) else {}
+
+    when = timestamp or datetime.now(timezone.utc)
+
+    return ServiceEvent(
+        event_type=event_type,
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+        provider_name=provider_name,
+        # ``collection_id`` / ``collection_name`` are required fields on
+        # ServiceEvent; we use the subject so quiet-hours / dedupe logic
+        # treats different subjects as distinct streams.
+        collection_id=str(subject_id),
+        collection_name=subject_name or str(failure_type),
+        timestamp=when,
+        extra={
+            "failure_type": str(failure_type),
+            "subject_id": subject_id,
+            "subject_name": subject_name,
+            "count": count,
+            "threshold": threshold,
+            "last_error": last_error,
+            "details": dict(details),
+        },
+    )
@@ -0,0 +1,148 @@
+"""Bridge self-monitoring service provider — emits internal-failure events.
+
+This is a passive provider: it does not connect to anything, never polls,
+and never subscribes. It exists so the rest of the bridge's CRUD / config /
+template / target plumbing has a single ``ServiceProvider`` to attach
+self-monitoring trackers and notification slots to.
+
+Events are constructed by the server-side helper
+``services/bridge_self.emit_bridge_self_event`` and pushed into
+``dispatch_provider_event`` directly — the provider itself is not asked
+to produce events.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from notify_bridge_core.models.events import ServiceEvent
+from notify_bridge_core.providers.base import (
+    ServiceProvider,
+    ServiceProviderType,
+)
+from notify_bridge_core.templates.variables import TemplateVariableDefinition
+
+
+# Configuration keys recognised on the bridge_self provider's ``config`` JSON.
+DEFAULT_POLL_FAILURE_THRESHOLD = 3
+DEFAULT_DEFERRED_BACKLOG_THRESHOLD = 100
+DEFAULT_TARGET_FAILURE_THRESHOLD = 5
+
+
+# Template variables exposed to bridge_self templates.
+BRIDGE_SELF_VARIABLES: list[TemplateVariableDefinition] = [
+    TemplateVariableDefinition(
+        name="failure_type",
+        type="string",
+        description="Which self-monitoring condition fired",
+        example="poll_failures",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="subject_id",
+        type="int",
+        description="ID of the affected entity (tracker_id, target_id, or 0)",
+        example="42",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="subject_name",
+        type="string",
+        description="Human-readable name of the affected entity",
+        example="My Immich Tracker",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="count",
+        type="int",
+        description="Consecutive failure count or current backlog size",
+        example="3",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="threshold",
+        type="int",
+        description="Configured threshold that was crossed",
+        example="3",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="last_error",
+        type="string",
+        description="Last underlying error message (truncated)",
+        example="Connection refused",
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+    TemplateVariableDefinition(
+        name="details",
+        type="dict",
+        description="Extra structured context for the event",
+        example='{"provider_id": 7}',
+        provider_type=ServiceProviderType.BRIDGE_SELF,
+    ),
+]
+
+
+class BridgeSelfServiceProvider(ServiceProvider):
+    """Passive provider — exposes nothing remote, holds only thresholds.
+
+    Polling is a no-op and ``connect`` always succeeds; the bridge itself
+    is what generates events for this provider.
+    """
+
+    provider_type = ServiceProviderType.BRIDGE_SELF
+    supports_subscription = False
+
+    def __init__(self, name: str = "Bridge Self-Monitoring") -> None:
+        self._name = name
+
+    async def connect(self) -> bool:
+        return True
+
+    async def disconnect(self) -> None:
+        return None
+
+    async def poll(
+        self,
+        collection_ids: list[str],
+        tracker_state: dict[str, Any],
+    ) -> tuple[list[ServiceEvent], dict[str, Any]]:
+        # No external service to poll. Returning empty keeps the contract
+        # so accidental scheduling no-ops cleanly.
+        return [], tracker_state
+
+    def get_available_variables(self) -> list[TemplateVariableDefinition]:
+        return list(BRIDGE_SELF_VARIABLES)
+
+    def get_provider_config_schema(self) -> dict[str, Any]:
+        return {
+            "type": "object",
+            "properties": {
+                "poll_failure_threshold": {
+                    "type": "integer",
+                    "minimum": 1,
+                    "default": DEFAULT_POLL_FAILURE_THRESHOLD,
+                    "description": "Consecutive tracker poll failures before alerting",
+                },
+                "deferred_backlog_threshold": {
+                    "type": "integer",
+                    "minimum": 1,
+                    "default": DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
+                    "description": "Pending deferred_dispatch rows before alerting",
+                },
+                "target_failure_threshold": {
+                    "type": "integer",
+                    "minimum": 1,
+                    "default": DEFAULT_TARGET_FAILURE_THRESHOLD,
+                    "description": "Consecutive target send failures before alerting",
+                },
+            },
+            "required": [],
+        }
+
+    async def list_collections(self) -> list[dict[str, Any]]:
+        # No collection concept — operators don't pick anything for this provider.
+        return []
+
+    async def test_connection(self) -> dict[str, Any]:
+        return {"ok": True, "message": "Bridge self-monitoring is always available"}
@@ -514,6 +514,39 @@ HOME_ASSISTANT_CAPABILITIES = ProviderCapabilities(
 )


+# ---------------------------------------------------------------------------
+# Bridge self-monitoring capabilities
+# ---------------------------------------------------------------------------
+
+BRIDGE_SELF_CAPABILITIES = ProviderCapabilities(
+    provider_type="bridge_self",
+    display_name="Bridge Self-Monitoring",
+    webhook_based=False,
+    supported_filters=[],
+    notification_slots=[
+        {
+            "name": "message_bridge_self_poll_failures",
+            "description": "Tracker poll failures crossed threshold",
+        },
+        {
+            "name": "message_bridge_self_deferred_backlog",
+            "description": "Deferred dispatch backlog crossed threshold",
+        },
+        {
+            "name": "message_bridge_self_target_failures",
+            "description": "Target send failures crossed threshold",
+        },
+    ],
+    events=[
+        {"name": "bridge_self_poll_failures", "description": "Tracker poll failures"},
+        {"name": "bridge_self_deferred_backlog", "description": "Deferred backlog high"},
+        {"name": "bridge_self_target_failures", "description": "Target send failures"},
+    ],
+    command_slots=[],
+    commands=[],
+)
+
+
 # ---------------------------------------------------------------------------
 # Registry
 # ---------------------------------------------------------------------------
@@ -527,6 +560,7 @@ _REGISTRY: dict[str, ProviderCapabilities] = {
     "google_photos": GOOGLE_PHOTOS_CAPABILITIES,
     "webhook": WEBHOOK_CAPABILITIES,
     "home_assistant": HOME_ASSISTANT_CAPABILITIES,
+    "bridge_self": BRIDGE_SELF_CAPABILITIES,
 }

@@ -10,7 +10,7 @@ arrive. The lifecycle is owned by the server-side subscription manager
 from __future__ import annotations

 import logging
-from typing import Any
+from typing import Any, Callable

 import aiohttp

@@ -25,6 +25,12 @@ from notify_bridge_core.templates.variables import TemplateVariableDefinition
 from .client import HomeAssistantWSClient
 from .event_parser import parse_event

+
+# Status callback signature: ``(state, detail)`` where ``state`` is one of
+# ``"connected"`` / ``"disconnected"`` and ``detail`` is an optional already-
+# redacted reason string (or None on connect).
+StatusChangeCallback = Callable[[str, str | None], None]
+
 _LOGGER = logging.getLogger(__name__)

@@ -229,7 +235,11 @@ class HomeAssistantServiceProvider(ServiceProvider):
         # — the subscription manager owns this provider's lifecycle instead.
         return [], tracker_state

-    async def subscribe(self, emit: EventEmitCallback) -> None:
+    async def subscribe(
+        self,
+        emit: EventEmitCallback,
+        on_status_change: StatusChangeCallback | None = None,
+    ) -> None:
         async def _on_event(ha_event: dict[str, Any]) -> None:
             event = parse_event(
                 ha_event,
@@ -252,6 +262,7 @@ class HomeAssistantServiceProvider(ServiceProvider):
             on_event=_on_event,
             event_types=self._event_types,
             refresh_areas=_refresh_areas,
+            on_status_change=on_status_change,
         )

     def get_available_variables(self) -> list[TemplateVariableDefinition]:
@@ -29,10 +29,21 @@ _LOGGER = logging.getLogger(__name__)
|
||||
# calls per poll cycle. TTL is conservative (1h) and a hashed key keeps the
|
||||
# raw api_key out of dict keys in case of a memory dump.
|
||||
_USERS_CACHE_TTL_SECONDS = 3600
|
||||
_users_cache_lock = asyncio.Lock()
|
||||
# Lazy init: ``asyncio.Lock()`` at module import binds to whichever event
|
||||
# loop is current at import time (often none, or the wrong one when tests
|
||||
# spin up dedicated loops). Defer creation to first use.
|
||||
_users_cache_lock: asyncio.Lock | None = None
|
||||
_users_cache: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
|
||||
def _get_users_cache_lock() -> asyncio.Lock:
|
||||
"""Return the module users-cache lock, creating it on first call."""
|
||||
global _users_cache_lock
|
||||
if _users_cache_lock is None:
|
||||
_users_cache_lock = asyncio.Lock()
|
||||
return _users_cache_lock
|
||||
|
||||
|
||||
def _users_cache_key(url: str, api_key: str) -> str:
|
||||
digest = hashlib.sha256(f"{url}|{api_key}".encode("utf-8")).hexdigest()
|
||||
return digest[:32]
|
||||
@@ -51,7 +62,7 @@ async def _get_cached_users(
|
||||
if entry is not None and (now - entry[0]) < _USERS_CACHE_TTL_SECONDS:
|
||||
return entry[1]
|
||||
|
||||
-    async with _users_cache_lock:
+    async with _get_users_cache_lock():
|
||||
# Re-check after acquiring the lock — another coroutine may have
|
||||
# refreshed the entry while we waited.
|
||||
entry = _users_cache.get(key)
|
||||
|
||||
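The lazy lock and the re-check after acquiring it can be exercised in isolation. The following is a minimal sketch, not the module's actual code: `fetch_users`, `get_cached_users`, the plain string key, and the `fetch_calls` counter are assumptions made for the demo; only the lock handling follows the hunks above.

```python
from __future__ import annotations

import asyncio
import time

_TTL_SECONDS = 3600  # same value as _USERS_CACHE_TTL_SECONDS in the diff
_lock: asyncio.Lock | None = None
_cache: dict[str, tuple[float, dict[str, str]]] = {}
fetch_calls = 0  # demo instrumentation only


def _get_lock() -> asyncio.Lock:
    """Create the lock on first use rather than at import time."""
    global _lock
    if _lock is None:
        _lock = asyncio.Lock()
    return _lock


async def fetch_users(key: str) -> dict[str, str]:
    """Hypothetical upstream call standing in for the real HTTP request."""
    global fetch_calls
    fetch_calls += 1
    await asyncio.sleep(0.01)
    return {"1": f"user-for-{key}"}


async def get_cached_users(key: str) -> dict[str, str]:
    entry = _cache.get(key)
    if entry is not None and (time.monotonic() - entry[0]) < _TTL_SECONDS:
        return entry[1]  # fast path, no lock taken
    async with _get_lock():
        # Re-check: another coroutine may have refreshed while we waited.
        entry = _cache.get(key)
        if entry is not None and (time.monotonic() - entry[0]) < _TTL_SECONDS:
            return entry[1]
        users = await fetch_users(key)
        _cache[key] = (time.monotonic(), users)
        return users


async def demo() -> list[dict[str, str]]:
    # Ten concurrent callers should trigger exactly one upstream fetch.
    return await asyncio.gather(*(get_cached_users("k") for _ in range(10)))


results = asyncio.run(demo())
```

Without the second lookup inside the lock, every waiter would repeat the fetch once the first caller released it.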
@@ -200,10 +200,28 @@ class NutServiceProvider(ServiceProvider):
|
||||
try:
|
||||
for ups_name in collection_ids:
|
||||
prev = tracker_state.get(ups_name, {})
|
||||
# First-ever observation has no baseline — emitting transition
|
||||
# events for whatever flags the device happens to carry would
|
||||
# spam the user with "OB"/"LB"/"REPLBATT" alerts on every fresh
|
||||
# tracker even when nothing changed. Seed state silently and
|
||||
# skip event emission until the next poll provides a baseline.
|
||||
is_first_observation = ups_name not in tracker_state
|
||||
try:
|
||||
variables = await client.list_var(ups_name)
|
||||
data = NutUpsData.from_variables(ups_name, variables)
|
||||
|
||||
if is_first_observation:
|
||||
new_state[ups_name] = {
|
||||
"name": data.description or ups_name,
|
||||
"status": data.status,
|
||||
"battery_charge": data.battery_charge,
|
||||
"comms_ok": True,
|
||||
"asset_ids": [],
|
||||
"pending_asset_ids": [],
|
||||
"shared": False,
|
||||
}
|
||||
continue
|
||||
|
||||
# Check for comms restored
|
||||
if not prev.get("comms_ok", True):
|
||||
events.append(self._make_event(
|
||||
|
||||
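The seed-then-compare behaviour described in the comment above can be reduced to a pure function. This is a simplified sketch under stated assumptions: `diff_ups_flags`, the `flag_set:` event names, and the state layout are hypothetical, and the real provider tracks more fields than the status string.

```python
from __future__ import annotations


def diff_ups_flags(
    tracker_state: dict[str, dict],
    observed: dict[str, str],
) -> tuple[list[tuple[str, str]], dict[str, dict]]:
    """Return (events, new_state) for one poll.

    ``observed`` maps a UPS name to its raw status string, e.g. "OL" or "OB LB".
    """
    events: list[tuple[str, str]] = []
    new_state: dict[str, dict] = {}
    for ups_name, status in observed.items():
        flags = set(status.split())
        if ups_name not in tracker_state:
            # First observation: record a baseline and emit nothing.
            new_state[ups_name] = {"status": status}
            continue
        prev_flags = set(tracker_state[ups_name]["status"].split())
        for flag in sorted(flags - prev_flags):
            events.append((ups_name, f"flag_set:{flag}"))
        new_state[ups_name] = {"status": status}
    return events, new_state


# Poll 1: the device already reports "OB LB", but there is no baseline yet.
events1, state1 = diff_ups_flags({}, {"ups0": "OB LB"})
# Poll 2: nothing changed, so there is still nothing to report.
events2, state2 = diff_ups_flags(state1, {"ups0": "OB LB"})
# Poll 3: REPLBATT appears, which is a genuine transition.
events3, state3 = diff_ups_flags(state2, {"ups0": "OB LB REPLBATT"})
```

The trade-off is that a UPS which is already on battery when the tracker is created stays silent until its flags next change.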
@@ -35,6 +35,10 @@ _SENSITIVE_EXTRA_TOKENS: tuple[str, ...] = (
|
||||
"bearer",
|
||||
"private_key",
|
||||
"access_key",
|
||||
"oauth",
|
||||
"client_secret",
|
||||
"webhook_secret",
|
||||
"csrf",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
⚠️ <b>Deferred dispatch backlog high</b>
|
||||
Pending notifications: <b>{{ count }}</b>
|
||||
Threshold: <b>{{ threshold }}</b>
|
||||
{%- if last_error %}
|
||||
<i>Note:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,6 @@
|
||||
🚨 <b>Tracker poll failures</b>
|
||||
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
|
||||
<b>{{ count }}</b> consecutive failures (threshold {{ threshold }})
|
||||
{%- if last_error %}
|
||||
<i>Last error:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,6 @@
|
||||
📡 <b>Target send failures</b>
|
||||
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
|
||||
<b>{{ count }}</b> consecutive failures (threshold {{ threshold }})
|
||||
{%- if last_error %}
|
||||
<i>Last error:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -79,6 +79,11 @@ PROVIDER_SLOT_FILE_MAP: dict[str, dict[str, str]] = {
|
||||
"message_ha_service_called": "ha_service_called.jinja2",
|
||||
"message_ha_event_fired": "ha_event_fired.jinja2",
|
||||
},
|
||||
"bridge_self": {
|
||||
"message_bridge_self_poll_failures": "bridge_self_poll_failures.jinja2",
|
||||
"message_bridge_self_deferred_backlog": "bridge_self_deferred_backlog.jinja2",
|
||||
"message_bridge_self_target_failures": "bridge_self_target_failures.jinja2",
|
||||
},
|
||||
}
|
||||
|
||||
# Backward-compatible alias
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
⚠️ <b>Очередь отложенной отправки растёт</b>
|
||||
Ожидают отправки: <b>{{ count }}</b>
|
||||
Порог: <b>{{ threshold }}</b>
|
||||
{%- if last_error %}
|
||||
<i>Примечание:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,6 @@
|
||||
🚨 <b>Сбои опроса трекера</b>
|
||||
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
|
||||
Подряд сбоев: <b>{{ count }}</b> (порог {{ threshold }})
|
||||
{%- if last_error %}
|
||||
<i>Последняя ошибка:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,6 @@
|
||||
📡 <b>Сбои отправки в адресат</b>
|
||||
<b>{{ subject_name }}</b> (id <code>{{ subject_id }}</code>)
|
||||
Подряд сбоев: <b>{{ count }}</b> (порог {{ threshold }})
|
||||
{%- if last_error %}
|
||||
<i>Последняя ошибка:</i> <code>{{ last_error }}</code>
|
||||
{%- endif %}
|
||||
@@ -13,6 +13,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import jinja2
|
||||
@@ -27,6 +28,19 @@ RENDER_TIMEOUT_SECONDS = 2.0
|
||||
_env = SandboxedEnvironment(autoescape=True)
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def _compile_cached(template_str: str) -> jinja2.Template:
|
||||
"""Compile + cache Jinja2 templates by source text.
|
||||
|
||||
Hot paths (NotificationDispatcher fan-out, periodic dispatch) re-render
|
||||
the same template string for every event; ``_env.from_string`` parses
|
||||
the source from scratch each time (~ms each). The 512-entry cache is
|
||||
large enough to hold every template across a busy install while
|
||||
keeping memory bounded.
|
||||
"""
|
||||
return _env.from_string(template_str)
|
||||
|
||||
|
||||
class TemplateRenderTimeout(jinja2.TemplateError):
|
||||
"""Raised when a template exceeds the configured render budget."""
|
||||
|
||||
@@ -74,7 +88,7 @@ def render_template(template_str: str, context: dict[str, Any]) -> str:
|
||||
)
|
||||
return "[Template too large]"
|
||||
try:
|
||||
-        compiled = _env.from_string(template_str)
+        compiled = _compile_cached(template_str)
|
||||
output = _render_with_timeout(compiled, context)
|
||||
except TemplateRenderTimeout as e:
|
||||
_LOGGER.error("Template render timeout: %s", e)
|
||||
|
||||
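The effect of caching compiled templates by source text can be seen with a standard-library stand-in. The sketch below uses `string.Template` in place of Jinja2 purely so it runs without third-party packages; `compile_cached`, `render`, and the `compile_calls` counter are assumptions for the demo, not names from the diff.

```python
from functools import lru_cache
from string import Template

compile_calls = 0  # demo instrumentation only


@lru_cache(maxsize=512)
def compile_cached(template_str: str) -> Template:
    """Parse a template once per distinct source string."""
    global compile_calls
    compile_calls += 1
    return Template(template_str)


def render(template_str: str, context: dict[str, object]) -> str:
    # The cache key is the full source text, so an edited template misses
    # the cache naturally and is compiled again.
    return compile_cached(template_str).safe_substitute(context)


outputs = [render("Pending: $count", {"count": n}) for n in range(3)]
info = compile_cached.cache_info()
```

Three renders of the same source compile it once. The least recently used entry is evicted once more than 512 distinct sources have been seen.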
@@ -27,6 +27,9 @@ def validate_template(
|
||||
"has_oversized_videos", "max_video_size", "max_video_size_mb",
|
||||
"added_assets", "assets", "albums",
|
||||
"raw_payload", "event_type_raw", "source_ip",
|
||||
# bridge_self self-monitoring variables.
|
||||
"failure_type", "subject_id", "subject_name", "count",
|
||||
"threshold", "last_error", "details",
|
||||
}
|
||||
allowed = available | runtime_vars
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ dependencies = [
|
||||
"slowapi>=0.1.9",
|
||||
"cachetools>=5.3",
|
||||
"python-multipart>=0.0.9",
|
||||
"prometheus_client>=0.20",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Action management API routes — CRUD, execute, dry-run, executions."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
@@ -54,6 +55,58 @@ class ActionUpdate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
+# Allowlist of fields a CRUD client may set on Action. Mirrors ActionCreate /
+# ActionUpdate but enforced server-side so a tampered request body cannot
+# overwrite ``user_id``, ``last_run_at``, ``created_at``, etc. via ``**dump``.
+_ALLOWED_ACTION_CREATE_FIELDS = frozenset({
+    "provider_id", "name", "icon", "action_type", "config",
+    "schedule_type", "schedule_interval", "schedule_cron", "enabled",
+})
+_ALLOWED_ACTION_UPDATE_FIELDS = frozenset({
+    "name", "icon", "config",
+    "schedule_type", "schedule_interval", "schedule_cron", "enabled",
+})
+
+# 5 fields = standard cron; 6 or more fields = cron with a leading seconds
+# column (Quartz-style). Reject any seconds-bearing form unless the seconds
+# field is the literal ``0``, so nothing fires more often than once per
+# minute. Also reject a leading ``*/0`` step, which is never valid.
+_DISALLOWED_CRON_PATTERNS = (
+    re.compile(r"^\s*\*/0\s+"),  # */0 in the first field
+)
+
+
+def _validate_cron(expr: str) -> None:
+    """Reject schedule_cron strings that fire more often than once per minute.
+
+    Without croniter as a hard dep we apply a conservative check: a
+    sub-minute cadence requires a 6+ field expression with seconds, so any
+    such expression must pin the seconds field to ``0``. A leading ``*/0``
+    step is rejected separately by ``_DISALLOWED_CRON_PATTERNS``.
+    """
+    if not expr or not expr.strip():
+        return
+    parts = expr.split()
+    if len(parts) >= 6:
+        # Seconds field present (Quartz-style or 6-field). Forbid
+        # second-level fires entirely; minute-cadence is the floor.
+        seconds_field = parts[0]
+        if seconds_field != "0":
+            raise HTTPException(
+                status_code=400,
+                detail=(
+                    "schedule_cron with a sub-minute cadence is not allowed; "
+                    "set the seconds field to 0 or use a standard 5-field cron"
+                ),
+            )
+    for pattern in _DISALLOWED_CRON_PATTERNS:
+        if pattern.search(expr):
+            raise HTTPException(
+                status_code=400,
+                detail="schedule_cron contains a disallowed pattern",
+            )
|
||||
|
||||
|
||||
async def _action_response(session: AsyncSession, action: Action) -> dict:
|
||||
"""Build response dict with rules inlined."""
|
||||
result = await session.exec(
|
||||
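The rules enforced by `_validate_cron` in the hunk above are easier to check as a framework-free function. This is a sketch under assumptions: `check_cron` and `is_allowed` are hypothetical names, `ValueError` replaces `HTTPException`, and the first column of a 6+ field expression is taken to be seconds, as the diff does.

```python
from __future__ import annotations


def check_cron(expr: str | None) -> None:
    """Raise ValueError for cron strings that could fire more than once a minute."""
    if not expr or not expr.strip():
        return  # no schedule configured
    parts = expr.split()
    if len(parts) >= 6 and parts[0] != "0":
        # A seconds column is present and is not pinned to 0.
        raise ValueError("sub-minute cadence is not allowed")
    if parts[0] == "*/0":
        raise ValueError("*/0 is not a valid step")


def is_allowed(expr: str | None) -> bool:
    """Boolean wrapper so the rules can be tabulated."""
    try:
        check_cron(expr)
    except ValueError:
        return False
    return True


examples = {
    "*/5 * * * *": is_allowed("*/5 * * * *"),          # plain 5-field cron
    "0 */5 * * * *": is_allowed("0 */5 * * * *"),      # seconds pinned to 0
    "*/10 * * * * *": is_allowed("*/10 * * * * *"),    # fires every 10 seconds
    "*/0 * * * *": is_allowed("*/0 * * * *"),          # zero step
}
```

Whether a sixth field means seconds or a trailing year depends on the cron dialect the scheduler parses, so this check is only as accurate as that assumption.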
@@ -127,7 +180,15 @@ async def create_action(
|
||||
detail=f"Invalid action type '{body.action_type}' for provider type '{provider.type}'",
|
||||
)
|
||||
|
||||
-    action = Action(user_id=user.id, **body.model_dump())
+    _validate_cron(body.schedule_cron)
+
+    # Project only allowlisted fields so a tampered body can't write
+    # ``user_id``, ``id``, ``last_run_at``, etc. via ``**dump``.
+    payload = {
+        k: v for k, v in body.model_dump().items()
+        if k in _ALLOWED_ACTION_CREATE_FIELDS
+    }
+    action = Action(user_id=user.id, **payload)
|
||||
session.add(action)
|
||||
await session.commit()
|
||||
await session.refresh(action)
|
||||
@@ -168,7 +229,13 @@ async def update_action(
|
||||
raise HTTPException(status_code=404, detail="Action not found")
|
||||
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
+    if "schedule_cron" in updates:
+        _validate_cron(updates["schedule_cron"] or "")
+    # Drop any field outside the update allowlist so a tampered request
+    # can't mutate ``user_id`` / ``provider_id`` / ``action_type`` etc.
     for key, value in updates.items():
+        if key not in _ALLOWED_ACTION_UPDATE_FIELDS:
+            continue
         setattr(action, key, value)
|
||||
session.add(action)
|
||||
await session.commit()
|
||||
|
||||
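The allowlist projection used in `create_action` and `update_action` can be shown with a plain dataclass. A minimal sketch, assuming a cut-down `Action` with three editable fields; `apply_updates` and its return value are hypothetical additions so that dropped keys are visible.

```python
from __future__ import annotations

from dataclasses import dataclass

# Shortened for the demo; the diff's allowlist carries more fields.
ALLOWED_UPDATE_FIELDS = frozenset({"name", "icon", "enabled"})


@dataclass
class Action:
    id: int
    user_id: int
    name: str
    icon: str = ""
    enabled: bool = True


def apply_updates(action: Action, updates: dict[str, object]) -> list[str]:
    """Apply allowlisted keys and return the names that were ignored."""
    ignored: list[str] = []
    for key, value in updates.items():
        if key not in ALLOWED_UPDATE_FIELDS:
            ignored.append(key)  # e.g. an injected user_id
            continue
        setattr(action, key, value)
    return ignored


action = Action(id=1, user_id=7, name="old")
ignored = apply_updates(action, {"name": "new", "user_id": 999, "enabled": False})
```

Returning the ignored keys is optional. The diff drops them silently, which keeps the API tolerant of extra fields at the cost of hiding client mistakes.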
@@ -48,6 +48,40 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/backup", tags=["backup"])
|
||||
|
||||
|
||||
# Hard caps on uploaded backup file shape — defend against parser DoS
|
||||
# (deeply nested or pathologically wide JSON) before we hand the
|
||||
# structure to the import pipeline.
|
||||
_MAX_BACKUP_DEPTH = 10
|
||||
_MAX_BACKUP_NODES = 100_000
|
||||
|
||||
|
||||
def _validate_backup_shape(value: object, depth: int = 0, count: list[int] | None = None) -> None:
|
||||
"""Walk ``value`` and reject anything beyond the depth/node caps.
|
||||
|
||||
Raises HTTPException(400) on overflow. Cheap O(n) walk; runs once
|
||||
per upload.
|
||||
"""
|
||||
if count is None:
|
||||
count = [0]
|
||||
if depth > _MAX_BACKUP_DEPTH:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Backup file too deeply nested (max depth {_MAX_BACKUP_DEPTH})",
|
||||
)
|
||||
count[0] += 1
|
||||
if count[0] > _MAX_BACKUP_NODES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Backup file has too many nodes (max {_MAX_BACKUP_NODES})",
|
||||
)
|
||||
if isinstance(value, dict):
|
||||
for v in value.values():
|
||||
_validate_backup_shape(v, depth + 1, count)
|
||||
elif isinstance(value, list):
|
||||
for v in value:
|
||||
_validate_backup_shape(v, depth + 1, count)
|
||||
|
||||
MAX_UPLOAD_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
|
||||
|
||||
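The depth and node caps can be tried outside FastAPI. The sketch below swaps the recursive walk from the diff for an explicit stack and raises `ValueError` instead of `HTTPException`; `check_shape` and `nested` are hypothetical names, and the caps copy `_MAX_BACKUP_DEPTH` and `_MAX_BACKUP_NODES`.

```python
MAX_DEPTH = 10
MAX_NODES = 100_000


def check_shape(value: object, max_depth: int = MAX_DEPTH, max_nodes: int = MAX_NODES) -> int:
    """Return the node count, or raise ValueError when a cap is exceeded."""
    count = 0
    stack = [(value, 0)]
    while stack:
        node, depth = stack.pop()
        if depth > max_depth:
            raise ValueError(f"too deeply nested (max depth {max_depth})")
        count += 1
        if count > max_nodes:
            raise ValueError(f"too many nodes (max {max_nodes})")
        if isinstance(node, dict):
            stack.extend((child, depth + 1) for child in node.values())
        elif isinstance(node, list):
            stack.extend((child, depth + 1) for child in node)
    return count


def nested(levels: int) -> object:
    """Build a list nested ``levels`` deep around a scalar."""
    value: object = 0
    for _ in range(levels):
        value = [value]
    return value


small = check_shape({"a": [1, 2], "b": {"c": 3}})
at_limit = check_shape(nested(10))
```

Because the depth cap is 10, the recursive form in the diff stays far from Python's recursion limit; the explicit stack is only an alternative shape for the same check. In both versions the check runs after `json.loads`, so it bounds the import pipeline's work rather than the parser's.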
@@ -181,6 +215,8 @@ async def validate_config(
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
_validate_backup_shape(raw)
|
||||
|
||||
result = validate_backup(raw)
|
||||
return result.model_dump()
|
||||
|
||||
@@ -204,6 +240,8 @@ async def import_config(
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
_validate_backup_shape(raw)
|
||||
|
||||
# Validate first
|
||||
validation = validate_backup(raw)
|
||||
if not validation.valid:
|
||||
@@ -259,6 +297,8 @@ async def prepare_restore(
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
_validate_backup_shape(raw)
|
||||
|
||||
validation = validate_backup(raw)
|
||||
if not validation.valid:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -504,11 +504,14 @@ async def delete_config(
|
||||
if config.user_id == 0 and user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="Cannot delete system default configs")
|
||||
raise_if_used(await check_command_template_config(session, config.id), config.name)
|
||||
-    slot_result = await session.exec(
-        select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config.id)
+    # Bulk delete slot rows so the round-trip count stays O(1) regardless
+    # of how many locale/slot combinations the config carries.
+    from sqlalchemy import delete as sa_delete
+    await session.execute(
+        sa_delete(CommandTemplateSlot).where(
+            CommandTemplateSlot.config_id == config.id
+        )
     )
-    for slot in slot_result.all():
-        await session.delete(slot)
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -162,17 +162,26 @@ async def delete_command_tracker(
|
||||
from ..services.command_sync import mark_dirty_for_tracker
|
||||
await mark_dirty_for_tracker(tracker.id)
|
||||
|
||||
-    # Delete associated listeners, collecting bot IDs for polling cleanup
+    # First read the listeners we're about to delete so we can collect the
+    # set of telegram_bot IDs whose polling state may need to be re-checked.
+    # Then issue a single bulk DELETE instead of N per-row deletes.
+    from sqlalchemy import delete as sa_delete
+
     result = await session.exec(
         select(CommandTrackerListener).where(
             CommandTrackerListener.command_tracker_id == tracker_id
         )
     )
-    bot_ids_to_check: set[int] = set()
-    for listener in result.all():
-        if listener.listener_type == "telegram_bot":
-            bot_ids_to_check.add(listener.listener_id)
-        await session.delete(listener)
+    bot_ids_to_check: set[int] = {
+        listener.listener_id
+        for listener in result.all()
+        if listener.listener_type == "telegram_bot"
+    }
+    await session.execute(
+        sa_delete(CommandTrackerListener).where(
+            CommandTrackerListener.command_tracker_id == tracker_id
+        )
+    )
|
||||
|
||||
await session.delete(tracker)
|
||||
await session.commit()
|
||||
|
||||
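The read-then-bulk-delete sequence in `delete_command_tracker` can be reproduced with the standard-library `sqlite3` module. This is an illustrative sketch only: the `listener` table and its columns are simplified assumptions, and raw SQL stands in for the SQLAlchemy `delete()` construct.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE listener (id INTEGER PRIMARY KEY, tracker_id INTEGER, "
    "listener_type TEXT, listener_id INTEGER)"
)
conn.executemany(
    "INSERT INTO listener (tracker_id, listener_type, listener_id) VALUES (?, ?, ?)",
    [(1, "telegram_bot", 10), (1, "telegram_bot", 11), (1, "other", 12), (2, "telegram_bot", 13)],
)

tracker_id = 1
# Step 1: read what is needed from the rows before they disappear.
bot_ids = {
    row[0]
    for row in conn.execute(
        "SELECT listener_id FROM listener WHERE tracker_id = ? AND listener_type = ?",
        (tracker_id, "telegram_bot"),
    )
}
# Step 2: one DELETE statement, however many rows match.
deleted = conn.execute(
    "DELETE FROM listener WHERE tracker_id = ?", (tracker_id,)
).rowcount
conn.commit()
remaining = conn.execute("SELECT count(*) FROM listener").fetchone()[0]
```

A statement-level DELETE does not load the rows, so any per-object cleanup that relied on ORM cascades has to be handled by the database or done explicitly.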
@@ -0,0 +1,161 @@
|
||||
"""Prometheus metrics endpoint and central registry.
|
||||
|
||||
Exposes operational metrics via ``GET /api/metrics`` in the standard
|
||||
Prometheus text format. Unauthenticated by design — Prometheus scrapers do
|
||||
not authenticate. If the API port crosses a trust boundary, disable via
|
||||
``NOTIFY_BRIDGE_METRICS_ENABLED=false``.
|
||||
|
||||
Metrics are defined as module-level singletons so the rest of the codebase
|
||||
can ``from notify_bridge_server.api.metrics import metrics`` and call
|
||||
``metrics.dispatch_duration.labels(channel="telegram").observe(0.42)``
|
||||
without re-creating the underlying objects.
|
||||
|
||||
Other modules MUST NOT ``import prometheus_client`` directly. Route every
|
||||
metric through :data:`metrics` (a :class:`MetricsRegistry`) so we have one
|
||||
place to swap implementations or add labels.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from starlette.responses import Response
|
||||
|
||||
from prometheus_client import (
|
||||
CONTENT_TYPE_LATEST,
|
||||
CollectorRegistry,
|
||||
Counter,
|
||||
Gauge,
|
||||
Histogram,
|
||||
generate_latest,
|
||||
)
|
||||
|
||||
from ..config import settings as _settings
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metric definitions
|
||||
# ---------------------------------------------------------------------------
|
||||
# Use a dedicated CollectorRegistry instead of the global default registry so
|
||||
# tests can construct the module repeatedly without ``Duplicated timeseries``
|
||||
# errors and so we never accidentally export Python GC / process metrics that
|
||||
# aren't part of the documented surface in OPERATIONS.md.
|
||||
|
||||
_REGISTRY: Final[CollectorRegistry] = CollectorRegistry()
|
||||
|
||||
|
||||
class MetricsRegistry:
|
||||
"""Singleton holder for module-level Prometheus collectors.
|
||||
|
||||
Instantiated once at import time as :data:`metrics`. Keep collectors as
|
||||
instance attributes so call sites get IDE autocomplete and so swapping
|
||||
the collector type (e.g. Counter -> Summary) is a one-line change here.
|
||||
"""
|
||||
|
||||
def __init__(self, registry: CollectorRegistry) -> None:
|
||||
self.registry = registry
|
||||
|
||||
# Gauge: populated on every scrape via the collector hook below.
|
||||
self.deferred_pending = Gauge(
|
||||
"notify_bridge_deferred_pending",
|
||||
"Count of deferred_dispatch rows awaiting drain.",
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
# Counter: incremented after each event_log row is persisted.
|
||||
self.event_log_total = Counter(
|
||||
"notify_bridge_event_log_total",
|
||||
"Total events written to event_log, partitioned by status and event_type.",
|
||||
["status", "event_type"],
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
# Histogram: observed wall-clock seconds per outbound dispatch attempt.
|
||||
self.dispatch_duration = Histogram(
|
||||
"notify_bridge_dispatch_duration_seconds",
|
||||
"Wall-clock duration of one dispatch attempt to a notification channel.",
|
||||
["channel"],
|
||||
registry=registry,
|
||||
buckets=(0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 30.0),
|
||||
)
|
||||
|
||||
# Counter: each polling provider that fails a tick increments by 1.
|
||||
self.provider_poll_failures = Counter(
|
||||
"notify_bridge_provider_poll_failures_total",
|
||||
"Polling provider failures partitioned by provider type.",
|
||||
["provider_type"],
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
# Counter: each rejected delivery to a target increments by 1.
|
||||
self.target_send_failures = Counter(
|
||||
"notify_bridge_target_send_failures_total",
|
||||
"Failed sends to a target partitioned by target type and HTTP status.",
|
||||
["target_type", "status_code"],
|
||||
registry=registry,
|
||||
)
|
||||
|
||||
|
||||
metrics: Final[MetricsRegistry] = MetricsRegistry(_REGISTRY)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scrape hook: refresh dynamic gauges on demand
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _refresh_deferred_pending_gauge() -> None:
|
||||
"""Populate ``deferred_pending`` by counting pending rows in the DB.
|
||||
|
||||
Called from the request handler before serializing — we don't poll the
|
||||
DB on a fixed cadence to avoid a steady-state cost when nothing is
|
||||
scraping. Kept tolerant: a DB error logs and leaves the previous value.
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
|
||||
from ..database.engine import get_engine
|
||||
|
||||
engine = get_engine()
|
||||
async with engine.connect() as conn:
|
||||
result = await conn.execute(
|
||||
text("SELECT count(*) FROM deferred_dispatch WHERE status='pending'")
|
||||
)
|
||||
row = result.first()
|
||||
count = int(row[0]) if row else 0
|
||||
metrics.deferred_pending.set(count)
|
||||
except Exception as exc: # noqa: BLE001 — never fail the scrape over this
|
||||
_LOGGER.debug("deferred_pending refresh skipped: %s", exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(tags=["metrics"])
|
||||
|
||||
|
||||
@router.get("/api/metrics")
|
||||
async def metrics_endpoint() -> Response:
|
||||
"""Expose collected metrics in Prometheus text format.
|
||||
|
||||
No auth by design — Prometheus scrapers don't authenticate. Gate the
|
||||
endpoint via ``NOTIFY_BRIDGE_METRICS_ENABLED=false`` when the API port
|
||||
is reachable from outside the trust boundary.
|
||||
"""
|
||||
if not _settings.metrics_enabled:
|
||||
raise HTTPException(status_code=404, detail="Metrics disabled")
|
||||
|
||||
await _refresh_deferred_pending_gauge()
|
||||
|
||||
# Zero-value increment: materialises the labelled series so the endpoint
# reports non-empty data even before callers wire instrumentation. Remove
# once the code paths are instrumented. The labels intentionally use a
# sentinel value so dashboards can filter the noise out: ``status="bootstrap"``.
|
||||
metrics.event_log_total.labels(status="bootstrap", event_type="metrics_scrape").inc(0)
|
||||
|
||||
payload = generate_latest(_REGISTRY)
|
||||
return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
|
||||
@@ -152,6 +152,10 @@ async def create_notification_tracker(
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
# Drop the cached enabled-trackers list so the next inbound event
|
||||
# (HA / webhook) sees the new tracker without waiting out the TTL.
|
||||
from ..services.event_dispatch import invalidate_tracker_cache
|
||||
invalidate_tracker_cache(tracker.provider_id)
|
||||
if tracker.enabled:
|
||||
await schedule_tracker(
|
||||
tracker.id, tracker.scan_interval,
|
||||
@@ -184,6 +188,8 @@ async def update_notification_tracker(
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
from ..services.event_dispatch import invalidate_tracker_cache
|
||||
invalidate_tracker_cache(tracker.provider_id)
|
||||
if tracker.enabled:
|
||||
await schedule_tracker(
|
||||
tracker.id, tracker.scan_interval,
|
||||
@@ -201,28 +207,39 @@ async def delete_notification_tracker(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
+    """Delete a tracker and its child rows in three bulk statements.
+
+    The previous implementation issued one DELETE per child row plus one
+    UPDATE per event_log row, which scaled linearly with the tracker's
+    history (an old, busy tracker could hit thousands of round-trips).
+    Bulk DELETE/UPDATE collapses that to three SQL statements regardless
+    of size.
+    """
+    from sqlalchemy import delete as sa_delete, update as sa_update
+
     tracker = await _get_user_tracker(session, tracker_id, user.id)
-    # Delete associated tracker-target links
-    result = await session.exec(
-        select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
+    # Junction rows: direct dependents of the tracker.
+    await session.execute(
+        sa_delete(NotificationTrackerTarget).where(
+            NotificationTrackerTarget.tracker_id == tracker_id
+        )
     )
-    for tt in result.all():
-        await session.delete(tt)
-    # Delete associated tracker state
-    state_result = await session.exec(
-        select(NotificationTrackerState).where(NotificationTrackerState.tracker_id == tracker_id)
+    # Persisted scan state for this tracker.
+    await session.execute(
+        sa_delete(NotificationTrackerState).where(
+            NotificationTrackerState.tracker_id == tracker_id
+        )
     )
-    for ts in state_result.all():
-        await session.delete(ts)
-    # Nullify event log references
-    event_result = await session.exec(
-        select(EventLog).where(EventLog.tracker_id == tracker_id)
+    # Preserve the audit trail in event_log; just null the back-reference
+    # so the tracker row can be removed without an FK violation.
+    await session.execute(
+        sa_update(EventLog).where(EventLog.tracker_id == tracker_id).values(tracker_id=None)
     )
-    for el in event_result.all():
-        el.tracker_id = None
-        session.add(el)
+    provider_id_for_cache = tracker.provider_id
     await session.delete(tracker)
     await session.commit()
+    from ..services.event_dispatch import invalidate_tracker_cache
+    invalidate_tracker_cache(provider_id_for_cache)
|
||||
await unschedule_tracker(tracker_id)
|
||||
await reschedule_immich_dispatch_jobs()
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Service provider management API routes."""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
-from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator
+from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator, model_validator
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from typing import Any
|
||||
@@ -94,14 +95,36 @@ class PayloadMapping(BaseModel):
|
||||
|
||||
|
||||
class WebhookProviderConfig(BaseModel):
|
||||
-    auth_mode: str = "none"
+    # Default to bearer to avoid silently creating an open relay. Operators
+    # who genuinely want an unauthenticated endpoint must set
+    # ``acknowledge_unauthenticated=True`` to opt in explicitly.
+    auth_mode: str = "bearer_token"
     webhook_secret: str | None = None
+    # Explicit opt-in required for ``auth_mode="none"``. Without this flag
+    # an unauthenticated webhook is rejected at validation time so a
+    # mis-clicked dropdown can't expose the bridge to arbitrary internet
+    # traffic.
+    acknowledge_unauthenticated: bool = False
     payload_mappings: list[PayloadMapping] = []
     event_type_path: str | None = None
     collection_path: str | None = None
     store_payloads: bool = True
     max_stored_payloads: int = 20  # 1-100
 
+    @model_validator(mode="after")
+    def _check_auth(self) -> "WebhookProviderConfig":
+        if self.auth_mode == "none" and not self.acknowledge_unauthenticated:
+            raise ValueError(
+                "auth_mode='none' creates an open webhook endpoint; set "
+                "acknowledge_unauthenticated=true to confirm this is intentional"
+            )
+        if self.auth_mode in ("bearer_token", "hmac_sha256") and not self.webhook_secret:
+            # Auto-generate a strong secret if the operator forgot to supply
+            # one; better than rejecting an otherwise-valid config and far
+            # better than silently leaving the endpoint open.
+            self.webhook_secret = secrets.token_urlsafe(32)
+        return self
|
||||
|
||||
|
||||
class HomeAssistantProviderConfig(BaseModel):
|
||||
url: str
|
||||
|
||||
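The opt-in rule and the secret auto-generation in `WebhookProviderConfig` can be mirrored without pydantic. The sketch below is a stand-in built on a standard-library dataclass: `WebhookConfigSketch` is a hypothetical name, and `__post_init__` plays the role of the `model_validator(mode="after")` hook.

```python
from __future__ import annotations

import secrets
from dataclasses import dataclass


@dataclass
class WebhookConfigSketch:
    auth_mode: str = "bearer_token"
    webhook_secret: str | None = None
    acknowledge_unauthenticated: bool = False

    def __post_init__(self) -> None:
        if self.auth_mode == "none" and not self.acknowledge_unauthenticated:
            raise ValueError(
                "auth_mode='none' requires acknowledge_unauthenticated=true"
            )
        if self.auth_mode in ("bearer_token", "hmac_sha256") and not self.webhook_secret:
            # Generate a secret rather than leave the endpoint without one.
            self.webhook_secret = secrets.token_urlsafe(32)


default_cfg = WebhookConfigSketch()
open_cfg = WebhookConfigSketch(auth_mode="none", acknowledge_unauthenticated=True)
```

A generated secret helps only if it is persisted and shown to the operator. Validating a stored config that lacks one would otherwise mint a different value each time.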
@@ -291,15 +291,19 @@ async def get_nav_counts(
|
||||
):
|
||||
"""Return entity counts for sidebar navigation badges.
|
||||
|
||||
-    Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
-    for concurrent use within a single session (no asyncio.gather). We
-    minimise round-trips by combining user + system counts and per-type
-    target counts into single aggregate queries where possible.
+    Combines user-owned counts, system-owned shared counts, and per-type
+    target counts into a single round-trip via a UNION ALL of label + count
+    rows. SQLAlchemy AsyncSession is not safe for concurrent use, so we
+    cannot asyncio.gather; collapsing 16 SELECTs into one is the optimisation.
     """
+    from sqlalchemy import literal, union_all
+
     counts: dict[str, int] = {}
 
-    # --- 1) User-owned entity counts (one query per model) ---
-    for model, key in [
+    user_id = user.id
+
+    # User-owned counts: one (label, count) per model.
+    user_models = [
|
||||
(ServiceProvider, "providers"),
|
||||
(NotificationTracker, "notification_trackers"),
|
||||
(TrackingConfig, "tracking_configs"),
|
||||
@@ -311,40 +315,52 @@ async def get_nav_counts(
|
||||
(CommandTracker, "command_trackers"),
|
||||
(CommandConfig, "command_configs"),
|
||||
(CommandTemplateConfig, "command_template_configs"),
|
||||
]:
|
||||
count = (await session.exec(
|
||||
select(func.count()).select_from(model).where(model.user_id == user.id)
|
||||
)).one()
|
||||
counts[key] = count
|
||||
|
||||
# --- 2) Add system-owned counts (user_id=0) for shared entities ---
|
||||
for model, key in [
|
||||
]
|
||||
# System-owned shared counts (user_id=0) folded back into the same key.
|
||||
system_models = [
|
||||
(TemplateConfig, "template_configs"),
|
||||
(CommandTemplateConfig, "command_template_configs"),
|
||||
(TrackingConfig, "tracking_configs"),
|
||||
(CommandConfig, "command_configs"),
|
||||
]:
|
||||
system_count = (await session.exec(
|
||||
select(func.count()).select_from(model).where(model.user_id == 0)
|
||||
)).one()
|
||||
counts[key] += system_count
|
||||
|
||||
# --- 3) Per-type target counts in a single query using conditional aggregation ---
|
||||
]
|
||||
target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
|
||||
type_counts_result = (await session.exec(
|
||||
select(
|
||||
NotificationTarget.type,
|
||||
func.count(),
|
||||
|
||||
# Initialise counts to 0 so missing UNION rows surface as zeroes
|
||||
# instead of KeyErrors when a category has no rows.
|
||||
for _model, key in user_models:
|
||||
counts[key] = 0
|
||||
for ttype in target_types:
|
||||
counts[f"targets_{ttype}"] = 0
|
||||
|
||||
queries = []
|
||||
for model, key in user_models:
|
||||
queries.append(
|
||||
select(literal(key).label("k"), func.count().label("c"))
|
||||
.select_from(model).where(model.user_id == user_id)
|
||||
)
|
||||
.where(
|
||||
NotificationTarget.user_id == user.id,
|
||||
NotificationTarget.type.in_(target_types),
|
||||
for model, key in system_models:
|
||||
queries.append(
|
||||
select(literal(f"__sys__:{key}").label("k"), func.count().label("c"))
|
||||
.select_from(model).where(model.user_id == 0)
|
||||
)
|
||||
.group_by(NotificationTarget.type)
|
||||
)).all()
|
||||
type_counts_map = dict(type_counts_result)
|
||||
for target_type in target_types:
|
||||
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
|
||||
for ttype in target_types:
|
||||
queries.append(
|
||||
select(literal(f"target:{ttype}").label("k"), func.count().label("c"))
|
||||
.select_from(NotificationTarget).where(
|
||||
NotificationTarget.user_id == user_id,
|
||||
NotificationTarget.type == ttype,
|
||||
)
|
||||
)
|
||||
|
||||
union_q = union_all(*queries)
|
||||
rows = (await session.execute(union_q)).all()
|
||||
for label, value in rows:
|
||||
if label.startswith("__sys__:"):
|
||||
counts[label.removeprefix("__sys__:")] += int(value or 0)
|
||||
elif label.startswith("target:"):
|
||||
counts[f"targets_{label.removeprefix('target:')}"] = int(value or 0)
|
||||
else:
|
||||
counts[label] = int(value or 0)
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
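The label-plus-count UNION ALL used by `get_nav_counts` can be tried against an in-memory SQLite database. This is a sketch under assumptions: the three tables and their rows are invented for the demo, and raw SQL replaces the SQLAlchemy `union_all` construct. Unlike the diff, it folds every row in with `+=`, so the result does not depend on the order in which the UNION branches come back.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE provider (id INTEGER PRIMARY KEY, user_id INTEGER);
    CREATE TABLE template_config (id INTEGER PRIMARY KEY, user_id INTEGER);
    CREATE TABLE target (id INTEGER PRIMARY KEY, user_id INTEGER, type TEXT);
    INSERT INTO provider (user_id) VALUES (5), (5), (6);
    INSERT INTO template_config (user_id) VALUES (5), (0), (0);
    INSERT INTO target (user_id, type) VALUES (5, 'telegram'), (5, 'telegram'), (5, 'email');
    """
)

sql = """
SELECT 'providers' AS k, count(*) AS c FROM provider WHERE user_id = :uid
UNION ALL
SELECT 'template_configs', count(*) FROM template_config WHERE user_id = :uid
UNION ALL
SELECT '__sys__:template_configs', count(*) FROM template_config WHERE user_id = 0
UNION ALL
SELECT 'target:telegram', count(*) FROM target WHERE user_id = :uid AND type = 'telegram'
UNION ALL
SELECT 'target:slack', count(*) FROM target WHERE user_id = :uid AND type = 'slack'
"""

# Pre-seed every key with 0 so each row can simply be added in.
counts = {"providers": 0, "template_configs": 0, "targets_telegram": 0, "targets_slack": 0}
for label, value in conn.execute(sql, {"uid": 5}):
    if label.startswith("__sys__:"):
        key = label[len("__sys__:"):]
    elif label.startswith("target:"):
        key = "targets_" + label[len("target:"):]
    else:
        key = label
    counts[key] += int(value or 0)
```

Each branch is an aggregate without GROUP BY, so it returns exactly one row even when nothing matches, and empty categories come back as zero.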
@@ -287,6 +287,8 @@ async def get_template_variables(
|
||||
**_nut_variables(),
|
||||
# --- Home Assistant slots ---
|
||||
**_home_assistant_variables(),
|
||||
# --- Bridge self-monitoring slots ---
|
||||
**_bridge_self_variables(),
|
||||
# --- Scheduler slots ---
|
||||
"message_scheduled_message": {
|
||||
"description": "Notification for scheduled message events",
|
||||
@@ -487,6 +489,32 @@ def _home_assistant_variables() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _bridge_self_variables() -> dict:
|
||||
common = {
|
||||
"failure_type": "Which condition fired (poll_failures, deferred_backlog, target_failures)",
|
||||
"subject_id": "Affected entity ID (tracker_id, target_id, or 0 for backlog)",
|
||||
"subject_name": "Human-readable name of the affected entity",
|
||||
"count": "Consecutive failure count or current backlog size",
|
||||
"threshold": "Configured threshold that was crossed",
|
||||
"last_error": "Last underlying error message (truncated)",
|
||||
"details": "Extra structured context dict (use {{ details | tojson }})",
|
||||
}
|
||||
return {
|
||||
"message_bridge_self_poll_failures": {
|
||||
"description": "Tracker poll failures crossed threshold",
|
||||
"variables": common,
|
||||
},
|
||||
"message_bridge_self_deferred_backlog": {
|
||||
"description": "Deferred dispatch backlog crossed threshold",
|
||||
"variables": common,
|
||||
},
|
||||
"message_bridge_self_target_failures": {
|
||||
"description": "Target send failures crossed threshold",
|
||||
"variables": common,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_config(
|
||||
body: TemplateConfigCreate,
|
||||
|
||||
@@ -64,9 +64,19 @@ async def create_user(
     admin: User = Depends(require_admin),
     session: AsyncSession = Depends(get_session),
 ):
-    """Create a new user (admin only)."""
+    """Create a new user (admin only).
+
+    Username is normalised to ``strip().lower()`` so "Admin" and "admin"
+    cannot coexist. We do not add a CHECK constraint at the DB level — that
+    would require rebuilding the table on SQLite — so the application is
+    the single source of truth for normalisation.
+    """
+    # Normalise so case-only variants collide with existing accounts.
+    username = (body.username or "").strip().lower()
+    if not username:
+        raise HTTPException(status_code=400, detail="Username cannot be empty")
     # Check for duplicate username
-    result = await session.exec(select(User).where(User.username == body.username))
+    result = await session.exec(select(User).where(User.username == username))
     if result.first():
         raise HTTPException(status_code=409, detail="Username already exists")
 
@@ -74,13 +84,25 @@ async def create_user(
         raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
 
     user = User(
-        username=body.username,
+        username=username,
         hashed_password=await _hash_password(body.password),
         role=body.role if body.role in ("admin", "user") else "user",
     )
     session.add(user)
     await session.commit()
     await session.refresh(user)
+
+    # Auto-create the bridge_self provider so the new user immediately gets
+    # internal-failure notifications without manual setup. Best-effort —
+    # a seeding hiccup must not fail the user creation itself.
+    try:
+        from ..database.seeds import ensure_bridge_self_provider_for_user
+        await ensure_bridge_self_provider_for_user(session, user.id)
+        await session.commit()
+    except Exception:  # noqa: BLE001
+        _LOGGER.exception("Failed to auto-seed bridge_self provider for user %s", user.id)
+        await session.rollback()
+
     return {"id": user.id, "username": user.username, "role": user.role}
 
 
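A minimal sketch of the normalisation rule used by `create_user`, assuming the same `strip().lower()` expression shown in the hunk above. The helper names `normalise_username` and `is_taken`, and the in-memory `existing` set standing in for the `User` table, are hypothetical and exist only for illustration.

```python
from __future__ import annotations


def normalise_username(raw: str | None) -> str:
    """Trim surrounding whitespace and lower-case, mirroring the diff's expression."""
    return (raw or "").strip().lower()


# Hypothetical stand-in for the usernames already stored in the database.
existing = {"admin"}


def is_taken(raw: str | None) -> bool:
    """A candidate collides when its normalised form is already stored."""
    return normalise_username(raw) in existing


padded = normalise_username("  Admin ")
collides = is_taken("ADMIN")
empty = normalise_username(None)
```

Because the comparison runs on the normalised form, "ADMIN" and " Admin " both collide with a stored "admin", and a missing or whitespace-only username normalises to the empty string, which the endpoint rejects with a 400.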
@@ -103,14 +125,19 @@ async def update_user(
     identity_changed = False
 
     if body.username is not None and body.username != user.username:
-        new_username = body.username.strip()
+        # Normalise to match the case-insensitive uniqueness rule applied
+        # at user creation. Comparing the normalised form against the
+        # stored username also avoids false-positive "no change" when a
+        # legacy mixed-case account is being renamed to its lower form.
+        new_username = (body.username or "").strip().lower()
         if not new_username:
             raise HTTPException(status_code=400, detail="Username cannot be empty")
-        dup = await session.exec(select(User).where(User.username == new_username))
-        if dup.first():
-            raise HTTPException(status_code=409, detail="Username already exists")
-        user.username = new_username
-        identity_changed = True
+        if new_username != user.username:
+            dup = await session.exec(select(User).where(User.username == new_username))
+            if dup.first():
+                raise HTTPException(status_code=409, detail="Username already exists")
+            user.username = new_username
+            identity_changed = True
 
     if body.role is not None and body.role != user.role:
         if body.role not in ("admin", "user"):
@@ -191,11 +218,139 @@ async def delete_user(
|
||||
admin: User = Depends(require_admin),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Delete a user (admin only, cannot delete self)."""
|
||||
"""Delete a user (admin only, cannot delete self).
|
||||
|
||||
Cascades through every user-owned table by hand. The model declares
|
||||
``ondelete=CASCADE`` on each FK, but SQLite only enforces FK actions
|
||||
on tables created *after* the ondelete clause was added — existing
|
||||
installs upgraded from older schemas need this Python-side cascade
|
||||
instead of a multi-step table rebuild.
|
||||
|
||||
TODO: drop this manual cascade once we ship a real
|
||||
rebuild-with-FK-actions migration for legacy SQLite installs (or
|
||||
once Postgres becomes the default deployment target).
|
||||
"""
|
||||
from sqlalchemy import delete as sa_delete, update as sa_update
|
||||
|
||||
if user_id == admin.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete yourself")
|
||||
user = await session.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
# Lazy import to avoid circulars.
|
||||
from ..database.models import (
|
||||
Action,
|
||||
ActionExecution,
|
||||
ActionRule,
|
||||
CommandConfig,
|
||||
CommandTracker,
|
||||
CommandTrackerListener,
|
||||
DeferredDispatch,
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerState,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
TelegramChat,
|
||||
TrackingConfig,
|
||||
EmailBot,
|
||||
MatrixBot,
|
||||
)
|
||||
|
||||
# Wrap the entire cascade in one transaction so a failure mid-way
|
||||
# cannot leave dangling child rows pointing at a missing user.
|
||||
try:
|
||||
# Order: leaves first, then their parents, finally the user. This
|
||||
# matters even with FKs disabled — it's the natural dependency
|
||||
# graph and avoids accidental constraint trips on engines that do
|
||||
# enforce FKs (Postgres).
|
||||
|
||||
# Resolve tracker ids first (needed for state + link cleanup
|
||||
# before the parent rows themselves are deleted further down).
|
||||
from sqlmodel import select as _select
|
||||
tracker_ids = list((await session.exec(
|
||||
_select(NotificationTracker.id).where(NotificationTracker.user_id == user_id)
|
||||
)).all())
|
||||
if tracker_ids:
|
||||
await session.execute(
|
||||
sa_delete(NotificationTrackerState).where(
|
||||
NotificationTrackerState.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
sa_delete(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
sa_delete(DeferredDispatch).where(
|
||||
DeferredDispatch.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# Action children: rules and execution log.
|
||||
action_ids = list((await session.exec(
|
||||
_select(Action.id).where(Action.user_id == user_id)
|
||||
)).all())
|
||||
if action_ids:
|
||||
await session.execute(
|
||||
sa_delete(ActionRule).where(ActionRule.action_id.in_(action_ids))
|
||||
)
|
||||
await session.execute(
|
||||
sa_delete(ActionExecution).where(
|
||||
ActionExecution.action_id.in_(action_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# Command tracker children: listeners.
|
||||
cmd_tracker_ids = list((await session.exec(
|
||||
_select(CommandTracker.id).where(CommandTracker.user_id == user_id)
|
||||
)).all())
|
||||
if cmd_tracker_ids:
|
||||
await session.execute(
|
||||
sa_delete(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id.in_(cmd_tracker_ids)
|
||||
)
|
||||
)
|
||||
|
||||
# Telegram bot children: chats.
|
||||
bot_ids = list((await session.exec(
|
||||
_select(TelegramBot.id).where(TelegramBot.user_id == user_id)
|
||||
)).all())
|
||||
if bot_ids:
|
||||
await session.execute(
|
||||
sa_delete(TelegramChat).where(TelegramChat.bot_id.in_(bot_ids))
|
||||
)
|
||||
|
||||
# Owned top-level entities (user is a direct owner).
|
||||
for model in (
|
||||
NotificationTracker,
|
||||
NotificationTarget,
|
||||
CommandTracker,
|
||||
CommandConfig,
|
||||
TrackingConfig,
|
||||
Action,
|
||||
TelegramBot,
|
||||
EmailBot,
|
||||
MatrixBot,
|
||||
ServiceProvider,
|
||||
):
|
||||
await session.execute(
|
||||
sa_delete(model).where(model.user_id == user_id)
|
||||
)
|
||||
|
||||
# EventLog: keep the audit trail but null the owner reference so
|
||||
# the rows survive the user delete (matches the SET NULL semantic
|
||||
# declared on the model).
|
||||
await session.execute(
|
||||
sa_update(EventLog).where(EventLog.user_id == user_id).values(user_id=None)
|
||||
)
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
@@ -12,6 +12,8 @@ from fastapi import APIRouter, HTTPException, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..auth.routes import limiter
|
||||
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook
|
||||
from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook
|
||||
@@ -240,6 +242,10 @@ async def planka_webhook(token: str, request: Request):
|
||||
if not _verify_planka_token(webhook_secret, request):
|
||||
raise HTTPException(status_code=403, detail="Invalid token")
|
||||
|
||||
# Read body AFTER auth check so an attacker without the bearer token
|
||||
# can't force an unbounded read. Token is in the header, not the body.
|
||||
raw_body = await _read_bounded_body(request)
|
||||
|
||||
# Parse payload from the bounded raw_body we already read.
|
||||
try:
|
||||
payload = json.loads(raw_body.decode("utf-8"))
|
||||
@@ -320,6 +326,8 @@ def _verify_generic_webhook_auth(
|
||||
_SENSITIVE_HEADER_SUBSTR = (
|
||||
"token", "auth", "key", "secret", "signature", "password", "credential",
|
||||
"cookie", "x-api", "x-hub-signature",
|
||||
# Extended for per-key body redaction; harmless extras for header check.
|
||||
"oauth", "client_secret", "webhook_secret", "csrf",
|
||||
)
|
||||
|
||||
|
||||
@@ -328,6 +336,28 @@ def _is_sensitive_header(name: str) -> bool:
|
||||
return any(s in n for s in _SENSITIVE_HEADER_SUBSTR)
|
||||
|
||||
|
||||
_REDACTED_PLACEHOLDER = "[REDACTED]"
|
||||
|
||||
|
||||
def _redact_sensitive_body(value: object) -> object:
|
||||
"""Walk a parsed JSON body and redact values for sensitive-named keys.
|
||||
|
||||
Returns a defensively-copied structure so the caller's object is
|
||||
never mutated (callers downstream still consume the original).
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
cleaned: dict[str, object] = {}
|
||||
for k, v in value.items():
|
||||
if isinstance(k, str) and _is_sensitive_header(k):
|
||||
cleaned[k] = _REDACTED_PLACEHOLDER
|
||||
else:
|
||||
cleaned[k] = _redact_sensitive_body(v)
|
||||
return cleaned
|
||||
if isinstance(value, list):
|
||||
return [_redact_sensitive_body(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def _filter_headers(raw_headers: dict[str, str]) -> dict[str, str]:
|
||||
"""Keep only safe headers for logging (strip Authorization, signatures, tokens).
|
||||
|
||||
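The recursive redaction added in `_redact_sensitive_body` can be sketched in isolation as follows. This is an illustrative approximation, not the module's code: the substring list is shortened, and `is_sensitive` assumes the name is lower-cased before matching, which the visible part of `_is_sensitive_header` suggests but does not show in full.

```python
# Shortened, illustrative subset of the substrings listed in the diff.
SENSITIVE_SUBSTRINGS = ("token", "auth", "key", "secret", "password")
REDACTED = "[REDACTED]"


def is_sensitive(name: str) -> bool:
    """Substring match on the lower-cased key name (assumed behaviour)."""
    lowered = name.lower()
    return any(part in lowered for part in SENSITIVE_SUBSTRINGS)


def redact(value: object) -> object:
    """Return a copy with values of sensitive-named keys replaced."""
    if isinstance(value, dict):
        return {
            k: REDACTED if isinstance(k, str) and is_sensitive(k) else redact(v)
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [redact(item) for item in value]
    return value


payload = {
    "repo": "demo",
    "access_token": "abc",
    "sender": {"login": "alice", "api_key": "xyz"},
    "items": [{"secret": 1}, 2],
    "author": "bob",
}
cleaned = redact(payload)
```

Two properties are worth noting. The input is never mutated, so downstream consumers still see the original payload. Substring matching also over-redacts: the key `author` contains `auth`, so its value is replaced as well, which trades some log detail for a lower risk of leaking a secret.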
@@ -358,11 +388,15 @@ async def _save_webhook_log(
     """Insert a webhook payload log entry and prune old ones."""
     try:
         body_json = body if isinstance(body, dict) else {}
+        # Strip sensitive values before persistence — webhook payloads
+        # routinely include OAuth tokens / secrets in the body, and the
+        # log is admin-readable but not need-to-know for the operator.
+        safe_body = _redact_sensitive_body(body_json) if body_json else {}
         session.add(WebhookPayloadLog(
             provider_id=provider_id,
             method=method,
             headers=headers,
-            body=body_json,
+            body=safe_body,
             status=status,
             extracted_fields=extracted_fields or {},
             error_message=error_message,
@@ -386,13 +420,19 @@ async def _save_webhook_log(
         _LOGGER.warning("Failed to save webhook payload log for provider %d", provider_id, exc_info=True)
         try:
             await session.rollback()
-        except Exception:
-            pass
+        except Exception:  # noqa: BLE001
+            _LOGGER.exception("Rollback after payload-log save failed")
 
 
 @router.post("/webhook/{token}")
+@limiter.limit("60/minute")
 async def generic_webhook(token: str, request: Request):
-    """Receive a generic webhook, extract variables via JSONPath, and dispatch notifications."""
+    """Receive a generic webhook, extract variables via JSONPath, and dispatch notifications.
+
+    Per-IP rate limit (60/min) caps blast radius from a single source —
+    legitimate providers send well below this; anything higher is either
+    a misconfigured retry loop or abuse.
+    """
     engine = get_engine()
 
     # --- Load provider and validate auth ---
@@ -50,7 +50,12 @@ class RefreshRequest(BaseModel):
 
 
 async def _hash_password(password: str) -> str:
-    """bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop."""
+    """bcrypt.hashpw is CPU-bound (~200-500ms); never run it on the event loop.
+
+    Caller is responsible for length-validating ``password`` against the
+    72-byte bcrypt cap before calling — bcrypt silently truncates beyond
+    that, which is a correctness footgun, not a security one.
+    """
 
     def _work() -> str:
         return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
@@ -58,6 +63,24 @@ async def _hash_password(password: str) -> str:
     return await asyncio.to_thread(_work)
 
 
+# bcrypt's algorithm cap — the underlying primitive truncates input
+# beyond this so two distinct passwords sharing a 72-byte prefix would
+# verify identically. We reject up-front with a clear 422 message.
+_BCRYPT_MAX_PASSWORD_BYTES = 72
+
+
+def _check_bcrypt_length(password: str) -> None:
+    if len(password.encode("utf-8")) > _BCRYPT_MAX_PASSWORD_BYTES:
+        raise HTTPException(
+            status_code=422,
+            detail=(
+                f"Password too long; bcrypt limit is "
+                f"{_BCRYPT_MAX_PASSWORD_BYTES} bytes (longer passwords would "
+                "be silently truncated)"
+            ),
+        )
+
+
 async def _verify_password(password: str, hashed: str) -> bool:
     def _work() -> bool:
         try:
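The length check above counts UTF-8 bytes, not characters, which matters for non-ASCII passwords. A small sketch of that distinction, using a hypothetical helper `exceeds_bcrypt_limit` that returns a boolean instead of raising `HTTPException`:

```python
BCRYPT_MAX_PASSWORD_BYTES = 72  # same cap as _BCRYPT_MAX_PASSWORD_BYTES in the diff


def exceeds_bcrypt_limit(password: str) -> bool:
    """True when the UTF-8 encoding is longer than bcrypt's input cap."""
    return len(password.encode("utf-8")) > BCRYPT_MAX_PASSWORD_BYTES


ascii_at_cap = exceeds_bcrypt_limit("a" * 72)       # 72 bytes
ascii_over_cap = exceeds_bcrypt_limit("a" * 73)     # 73 bytes
accented = exceeds_bcrypt_limit("\u00e9" * 37)      # 37 characters, 74 bytes
```

A 37-character password made of two-byte characters already exceeds the cap, so a check based on `len(password)` alone would let it through and bcrypt would hash only the first 72 bytes.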
@@ -74,6 +97,7 @@ async def _verify_password(password: str, hashed: str) -> bool:
|
||||
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
_check_bcrypt_length(body.password)
|
||||
# Compute hash BEFORE opening the transaction so we don't hold a writer lock
|
||||
# during the CPU-bound bcrypt work.
|
||||
hashed = await _hash_password(body.password)
|
||||
@@ -97,6 +121,16 @@ async def setup(request: Request, body: SetupRequest, session: AsyncSession = De
|
||||
session.add(user)
|
||||
await session.refresh(user)
|
||||
|
||||
# Auto-create the bridge_self provider for the new admin so internal-
|
||||
# failure notifications work out of the box. Best-effort — a seeding
|
||||
# failure should not abort setup.
|
||||
try:
|
||||
from ..database.seeds import ensure_bridge_self_provider_for_user
|
||||
await ensure_bridge_self_provider_for_user(session, user.id)
|
||||
await session.commit()
|
||||
except Exception: # noqa: BLE001
|
||||
await session.rollback()
|
||||
|
||||
return TokenResponse(
|
||||
access_token=create_access_token(user.id, user.role, user.token_version),
|
||||
refresh_token=create_refresh_token(user.id, user.token_version),
|
||||
@@ -170,6 +204,7 @@ async def change_password(
|
||||
raise HTTPException(status_code=400, detail="Current password is incorrect")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
|
||||
_check_bcrypt_length(body.new_password)
|
||||
user.hashed_password = await _hash_password(body.new_password)
|
||||
user.token_version = (user.token_version or 1) + 1
|
||||
session.add(user)
|
||||
|
||||
@@ -43,10 +43,21 @@ async def telegram_webhook(
     session: AsyncSession = Depends(get_session),
 ):
     """Handle incoming Telegram messages — route commands to handlers."""
-    # Validate webhook secret if configured
-    if _webhook_secret:
-        if not hmac.compare_digest(x_telegram_bot_api_secret_token or "", _webhook_secret):
-            raise HTTPException(status_code=403, detail="Invalid webhook secret")
+    # Telegram webhook secret is MANDATORY: without it any peer that knows
+    # the opaque webhook URL could inject arbitrary updates as if Telegram
+    # had sent them. Refuse to handle updates if no secret is configured.
+    if not _webhook_secret:
+        _LOGGER.error(
+            "Refusing Telegram webhook update for %s — webhook secret not configured "
+            "(set NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET)",
+            webhook_id,
+        )
+        raise HTTPException(
+            status_code=401,
+            detail="Telegram webhook secret not configured on this server",
+        )
+    if not hmac.compare_digest(x_telegram_bot_api_secret_token or "", _webhook_secret):
+        raise HTTPException(status_code=403, detail="Invalid webhook secret")
 
     # Find bot by opaque webhook path ID (not by token — token must not appear in URLs)
     bot_result = await session.exec(
@@ -161,7 +172,17 @@ async def telegram_webhook(
 
 
 async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
-    """Register webhook URL with Telegram Bot API via TelegramClient."""
+    """Register webhook URL with Telegram Bot API via TelegramClient.
+
+    Refuses to register without a secret: a webhook without a secret
+    accepts any unauthenticated POST as a valid Telegram update, so we
+    never want one in production.
+    """
+    if not secret:
+        raise ValueError(
+            "Telegram webhook registration requires a secret token "
+            "(set NOTIFY_BRIDGE_TELEGRAM_WEBHOOK_SECRET)"
+        )
     from ..services.http_session import get_http_session
     http = await get_http_session()
     client = TelegramClient(http, bot_token)
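The fail-closed ordering of the new Telegram check (no configured secret gives 401, wrong secret gives 403) can be sketched as a pure function. `webhook_auth_status` is a hypothetical helper that returns the status code instead of raising `HTTPException`; it also encodes both values to bytes before comparing, an assumption made here so the comparison accepts non-ASCII input.

```python
from __future__ import annotations

import hmac


def webhook_auth_status(configured_secret: str | None, presented: str | None) -> int:
    """Return the HTTP status the handler would produce for this request."""
    if not configured_secret:
        # Fail closed: an unconfigured server refuses every update.
        return 401
    # Constant-time comparison; a missing header is treated as an empty value.
    if not hmac.compare_digest(
        (presented or "").encode("utf-8"), configured_secret.encode("utf-8")
    ):
        return 403
    return 200


unconfigured = webhook_auth_status(None, "anything")
missing_header = webhook_auth_status("s3cret", None)
wrong_secret = webhook_auth_status("s3cret", "guess")
accepted = webhook_auth_status("s3cret", "s3cret")
```

Checking for the missing configuration first is what makes the secret mandatory: under the previous `if _webhook_secret:` guard, an unconfigured server skipped the comparison entirely and accepted the update.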
@@ -76,6 +76,13 @@ class Settings(BaseSettings):
|
||||
before migrations run using SQLite's ``VACUUM INTO`` (atomic, consistent).
|
||||
"""
|
||||
|
||||
metrics_enabled: bool = True
|
||||
"""Expose the Prometheus ``/api/metrics`` endpoint. Disable on hardened
|
||||
deployments where the API port is exposed beyond the trust boundary —
|
||||
metrics are unauthenticated and can leak operational information about
|
||||
queue depth, dispatch rates, and provider failures.
|
||||
"""
|
||||
|
||||
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
|
||||
@@ -309,6 +309,22 @@ async def migrate_schema(engine: AsyncEngine) -> None:
|
||||
)
|
||||
logger.info("Added %s column to tracking_config table", col_name)
|
||||
|
||||
# Add Bridge self-monitoring tracking flags to tracking_config if missing.
|
||||
# All three default ON — the bridge_self provider exists specifically
|
||||
# to surface these conditions, so silencing one would defeat the point.
|
||||
if await _has_table(conn, "tracking_config"):
|
||||
bridge_self_flags = [
|
||||
("track_bridge_self_poll_failures", "INTEGER DEFAULT 1"),
|
||||
("track_bridge_self_deferred_backlog", "INTEGER DEFAULT 1"),
|
||||
("track_bridge_self_target_failures", "INTEGER DEFAULT 1"),
|
||||
]
|
||||
for col_name, col_type in bridge_self_flags:
|
||||
if not await _has_column(conn, "tracking_config", col_name):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} {col_type}")
|
||||
)
|
||||
logger.info("Added %s column to tracking_config table", col_name)
|
||||
|
||||
# Add quiet hours to tracking_config if missing.
|
||||
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
|
||||
if await _has_table(conn, "tracking_config"):
|
||||
@@ -1361,6 +1377,12 @@ _INDEXES: list[tuple[str, str, str]] = [
|
||||
("ix_action_provider_id", "action", "provider_id"),
|
||||
# Dashboard: SELECT event_log WHERE user_id = ? ORDER BY created_at DESC
|
||||
("ix_event_log_user_created", "event_log", "user_id, created_at DESC"),
|
||||
# Dashboard "events of type X for me, recent first" filter.
|
||||
(
|
||||
"ix_event_log_user_event_type_created",
|
||||
"event_log",
|
||||
"user_id, event_type, created_at DESC",
|
||||
),
|
||||
("ix_event_log_provider_id", "event_log", "provider_id"),
|
||||
("ix_event_log_notification_tracker_id", "event_log", "notification_tracker_id"),
|
||||
("ix_event_log_action_id", "event_log", "action_id"),
|
||||
@@ -1543,6 +1565,269 @@ async def migrate_chat_action_to_column(engine: AsyncEngine) -> None:
|
||||
logger.info("Migrated chat_action from config JSON to column where present")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Uniqueness + dedupe migrations for webhook hot paths.
|
||||
#
|
||||
# These backfill missing UNIQUE indexes on webhook tokens, webhook path IDs,
|
||||
# bot_id (with sentinel guard), (bot_id, chat_id), and tracker-target links.
|
||||
# Every CREATE UNIQUE INDEX is preceded by a dedupe pass that keeps the
|
||||
# canonical row (lowest id, or oldest created_at where specified) and removes
|
||||
# the rest, logging a WARNING with the dropped count so operators can audit.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _dedupe_by_columns(
|
||||
conn,
|
||||
table: str,
|
||||
cols: list[str],
|
||||
*,
|
||||
keep: str = "min_id",
|
||||
label: str = "",
|
||||
) -> int:
|
||||
"""Delete duplicate rows leaving one survivor per ``cols`` group.
|
||||
|
||||
``keep`` chooses the survivor:
|
||||
- ``"min_id"`` keeps the row with the lowest ``id`` (default — used
|
||||
when there is no semantic "first" row to preserve).
|
||||
- ``"min_created_at"`` keeps the row with the oldest ``created_at``,
|
||||
falling back to the lowest id on ties — preferred for tracker-target
|
||||
links so the original link wins.
|
||||
|
||||
Returns the number of rows deleted. All identifiers flow through
|
||||
``_assert_ident`` to neutralise SQL injection from any caller mistake.
|
||||
"""
|
||||
_assert_ident(table, "table")
|
||||
for c in cols:
|
||||
_assert_ident(c, "column")
|
||||
group_by = ", ".join(cols)
|
||||
where_cols = " AND ".join(f"{c} = g.{c}" for c in cols)
|
||||
if keep == "min_created_at":
|
||||
# Tie-break on id so the survivor is deterministic even if two rows
|
||||
# share the same created_at (insert-batches commonly do).
|
||||
survivor_sql = (
|
||||
f"SELECT id FROM {table} "
|
||||
f"WHERE {where_cols} "
|
||||
f"ORDER BY created_at ASC, id ASC LIMIT 1"
|
||||
)
|
||||
elif keep == "min_id":
|
||||
survivor_sql = f"SELECT MIN(id) FROM {table} WHERE {where_cols}"
|
||||
else:
|
||||
raise ValueError(f"Unknown keep strategy: {keep!r}")
|
||||
|
||||
delete_sql = (
|
||||
f"DELETE FROM {table} WHERE id IN ("
|
||||
f" SELECT t.id FROM {table} t "
|
||||
f" JOIN ("
|
||||
f" SELECT {group_by} FROM {table} "
|
||||
f" GROUP BY {group_by} HAVING COUNT(*) > 1"
|
||||
f" ) g ON {' AND '.join(f't.{c} = g.{c}' for c in cols)} "
|
||||
f" WHERE t.id NOT IN ({survivor_sql})"
|
||||
f")"
|
||||
)
|
||||
result = await conn.execute(text(delete_sql))
|
||||
deleted = int(getattr(result, "rowcount", 0) or 0)
|
||||
if deleted:
|
||||
logger.warning(
|
||||
"Removed %d duplicate row(s) from %s on (%s)%s",
|
||||
deleted, table, ", ".join(cols),
|
||||
f" — {label}" if label else "",
|
||||
)
|
||||
return deleted
|
||||
|
||||
|
||||
async def migrate_uniqueness_constraints(engine: AsyncEngine) -> None:
|
||||
"""Backfill missing UNIQUE indexes on webhook hot paths.
|
||||
|
||||
SQLite cannot ALTER an existing column to add a UNIQUE constraint, but
|
||||
a UNIQUE INDEX is functionally equivalent and can be created with
|
||||
``IF NOT EXISTS`` on every boot. Each index is preceded by a dedupe
|
||||
pass so the index creation does not fail on existing duplicates.
|
||||
|
||||
Indexes added:
|
||||
- service_provider.webhook_token (full unique)
|
||||
- telegram_bot.webhook_path_id (full unique)
|
||||
- telegram_bot.bot_id (partial unique WHERE bot_id != 0; 0 is a
|
||||
sentinel meaning "not yet validated")
|
||||
- telegram_chat (bot_id, chat_id) (full unique composite)
|
||||
- notification_tracker_target (notification_tracker_id, target_id)
|
||||
(full unique composite)
|
||||
"""
|
||||
# Skip on non-SQLite engines — they enforce UNIQUE via the model
|
||||
# metadata (create_all) and don't have sqlite_master introspection.
|
||||
if not str(engine.url).startswith("sqlite"):
|
||||
return
|
||||
async with engine.begin() as conn:
|
||||
# service_provider.webhook_token
|
||||
if await _has_table(conn, "service_provider") and await _has_column(
|
||||
conn, "service_provider", "webhook_token",
|
||||
):
|
||||
await _dedupe_by_columns(
|
||||
conn, "service_provider", ["webhook_token"],
|
||||
keep="min_id", label="webhook_token uniqueness",
|
||||
)
|
||||
await conn.execute(text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS "
|
||||
"uq_service_provider_webhook_token "
|
||||
"ON service_provider(webhook_token)"
|
||||
))
|
||||
|
||||
# telegram_bot.webhook_path_id (full unique)
|
||||
# telegram_bot.bot_id (partial unique excluding sentinel 0)
|
||||
if await _has_table(conn, "telegram_bot"):
|
||||
if await _has_column(conn, "telegram_bot", "webhook_path_id"):
|
||||
await _dedupe_by_columns(
|
||||
conn, "telegram_bot", ["webhook_path_id"],
|
||||
keep="min_id", label="webhook_path_id uniqueness",
|
||||
)
|
||||
await conn.execute(text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS "
|
||||
"uq_telegram_bot_webhook_path_id "
|
||||
"ON telegram_bot(webhook_path_id)"
|
||||
))
|
||||
if await _has_column(conn, "telegram_bot", "bot_id"):
|
||||
# Dedupe only non-sentinel rows. Two unverified bots both
|
||||
# carrying bot_id=0 is legitimate — only collisions among
|
||||
# validated bot_ids signal a real corruption to clean up.
|
||||
deleted = await conn.execute(text(
|
||||
"DELETE FROM telegram_bot WHERE id IN ("
|
||||
" SELECT t.id FROM telegram_bot t "
|
||||
" JOIN ("
|
||||
" SELECT bot_id FROM telegram_bot "
|
||||
" WHERE bot_id != 0 GROUP BY bot_id HAVING COUNT(*) > 1"
|
||||
" ) g ON t.bot_id = g.bot_id "
|
||||
" WHERE t.id NOT IN ("
|
||||
" SELECT MIN(id) FROM telegram_bot WHERE bot_id = g.bot_id"
|
||||
" )"
|
||||
")"
|
||||
))
|
||||
rc = int(getattr(deleted, "rowcount", 0) or 0)
|
||||
if rc:
|
||||
logger.warning(
|
||||
"Removed %d duplicate telegram_bot row(s) on bot_id "
|
||||
"(non-sentinel collisions)", rc,
|
||||
)
|
||||
# Plain INDEX for the lookup-by-bot_id path.
|
||||
await conn.execute(text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_telegram_bot_bot_id "
|
||||
"ON telegram_bot(bot_id)"
|
||||
))
|
||||
# Partial UNIQUE excluding the sentinel.
|
||||
await conn.execute(text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS "
|
||||
"uq_telegram_bot_bot_id_nonzero "
|
||||
"ON telegram_bot(bot_id) WHERE bot_id != 0"
|
||||
))
|
||||
|
||||
# telegram_chat (bot_id, chat_id) — keep the survivor with the oldest
|
||||
# discovered_at so the original discovery row wins. _dedupe_by_columns
|
||||
# only handles created_at; do this one inline.
|
||||
if await _has_table(conn, "telegram_chat"):
|
||||
res = await conn.execute(text(
|
||||
"DELETE FROM telegram_chat WHERE id IN ("
|
||||
" SELECT t.id FROM telegram_chat t "
|
||||
" JOIN ("
|
||||
" SELECT bot_id, chat_id FROM telegram_chat "
|
||||
" GROUP BY bot_id, chat_id HAVING COUNT(*) > 1"
|
||||
" ) g ON t.bot_id = g.bot_id AND t.chat_id = g.chat_id "
|
||||
" WHERE t.id NOT IN ("
|
||||
" SELECT id FROM telegram_chat "
|
||||
" WHERE bot_id = g.bot_id AND chat_id = g.chat_id "
|
||||
" ORDER BY discovered_at ASC, id ASC LIMIT 1"
|
||||
" )"
|
||||
")"
|
||||
))
|
||||
rc = int(getattr(res, "rowcount", 0) or 0)
|
||||
if rc:
|
||||
logger.warning(
|
||||
"Removed %d duplicate telegram_chat row(s) on (bot_id, chat_id)",
|
||||
rc,
|
||||
)
|
||||
await conn.execute(text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS uq_telegram_chat_bot_chat "
|
||||
"ON telegram_chat(bot_id, chat_id)"
|
||||
))
|
||||
await conn.execute(text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_telegram_chat_bot_chat "
|
||||
"ON telegram_chat(bot_id, chat_id)"
|
||||
))
|
||||
|
||||
+        # notification_tracker_target (notification_tracker_id, target_id)
+        # — keep the oldest created_at link so the original wins.
+        if await _has_table(conn, "notification_tracker_target") and await _has_column(
+            conn, "notification_tracker_target", "notification_tracker_id",
+        ):
+            await _dedupe_by_columns(
+                conn,
+                "notification_tracker_target",
+                ["notification_tracker_id", "target_id"],
+                keep="min_created_at",
+                label="tracker-target link uniqueness",
+            )
+            await conn.execute(text(
+                "CREATE UNIQUE INDEX IF NOT EXISTS uq_ntt_tracker_target "
+                "ON notification_tracker_target(notification_tracker_id, target_id)"
+            ))
+
+        # service_provider partial unique on (user_id) WHERE type='bridge_self'.
+        # Bridge-self is special: exactly one row per user, auto-seeded at boot,
+        # at user-create, and on /setup. Without this guard, a concurrent boot
+        # backfill + POST /api/users could double-insert. Dedupe keeps the
+        # oldest row so any user-customised thresholds on it survive.
+        if await _has_table(conn, "service_provider"):
+            res = await conn.execute(text(
+                "DELETE FROM service_provider WHERE id IN ("
+                " SELECT t.id FROM service_provider t "
+                " JOIN ("
+                " SELECT user_id FROM service_provider "
+                " WHERE type='bridge_self' GROUP BY user_id HAVING COUNT(*) > 1"
+                " ) g ON t.user_id = g.user_id "
+                " WHERE t.type='bridge_self' AND t.id NOT IN ("
+                " SELECT MIN(id) FROM service_provider "
+                " WHERE type='bridge_self' AND user_id = g.user_id"
+                " )"
+                ")"
+            ))
+            rc = int(getattr(res, "rowcount", 0) or 0)
+            if rc:
+                logger.warning(
+                    "Removed %d duplicate bridge_self service_provider row(s) "
+                    "on user_id", rc,
+                )
+            await conn.execute(text(
+                "CREATE UNIQUE INDEX IF NOT EXISTS "
+                "uq_service_provider_bridge_self_per_user "
+                "ON service_provider(user_id) WHERE type='bridge_self'"
+            ))
+
+
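The dedupe-then-index sequence for `bridge_self` rows can be tried in isolation. The following is a minimal sketch, assuming the standard-library `sqlite3` module in place of the project's async engine and a table trimmed to three columns. The helper name `enforce_single_bridge_self` is hypothetical, and its DELETE is a simplified form of the statement in the migration, not a copy of it.

```python
import sqlite3


def enforce_single_bridge_self(conn: sqlite3.Connection) -> int:
    """Remove duplicate bridge_self rows (lowest id per user survives),
    then create the partial unique index. Returns the number of rows removed."""
    cur = conn.execute(
        "DELETE FROM service_provider "
        "WHERE type = 'bridge_self' AND id NOT IN ("
        "  SELECT MIN(id) FROM service_provider "
        "  WHERE type = 'bridge_self' GROUP BY user_id"
        ")"
    )
    removed = cur.rowcount
    # The WHERE clause makes the index partial: only bridge_self rows are
    # constrained, other provider types may repeat per user.
    conn.execute(
        "CREATE UNIQUE INDEX IF NOT EXISTS uq_service_provider_bridge_self_per_user "
        "ON service_provider(user_id) WHERE type = 'bridge_self'"
    )
    return removed


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE service_provider ("
    "id INTEGER PRIMARY KEY, user_id INTEGER NOT NULL, type TEXT NOT NULL)"
)
conn.executemany(
    "INSERT INTO service_provider (user_id, type) VALUES (?, ?)",
    [
        (1, "bridge_self"),
        (1, "bridge_self"),  # duplicate that the dedupe should remove
        (1, "immich"),
        (1, "immich"),       # duplicates of other types are left alone
        (2, "bridge_self"),
    ],
)

removed = enforce_single_bridge_self(conn)
remaining = conn.execute(
    "SELECT id, user_id, type FROM service_provider ORDER BY id"
).fetchall()

# With the index in place, a second bridge_self row for user 1 is rejected.
try:
    conn.execute(
        "INSERT INTO service_provider (user_id, type) VALUES (1, 'bridge_self')"
    )
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True
```

Running the dedupe before `CREATE UNIQUE INDEX` matters: on a table that still holds duplicates the index creation itself would fail.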
+async def migrate_eventlog_provider_fk(engine: AsyncEngine) -> None:
+    """Document the EventLog.provider_id FK situation.
+
+    SQLite cannot ALTER a column to add a foreign-key constraint without
+    rebuilding the table. The model annotation now declares
+    ``ondelete=SET NULL`` which only takes effect on freshly created
+    tables (i.e. brand-new installs). For existing installs we rely on
+    application-side cleanup in ``api/providers.delete_provider`` to NULL
+    out ``event_log.provider_id`` rows before deleting the provider row.
+
+    This migration is intentionally a no-op aside from the log line — it
+    exists so the migration order is explicit and operators see in the
+    logs that the FK strategy was reviewed on this boot.
+    """
+    if not str(engine.url).startswith("sqlite"):
+        return
+    async with engine.begin() as conn:
+        if not await _has_table(conn, "event_log"):
+            return
+    # No DDL change. Application code in api/providers.delete_provider
+    # is the source of truth for the SET NULL semantic on existing tables.
+    logger.debug(
+        "event_log.provider_id FK enforcement deferred to application "
+        "code on existing SQLite tables (model declares ondelete=SET NULL "
+        "which applies to fresh schemas only)."
+    )
+
+
 # ---------------------------------------------------------------------------
 # Schema version tracking — lightweight alternative to Alembic while the
 # hand-rolled idempotent migrations remain the source of truth. Gives
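The two SET NULL paths described in the `migrate_eventlog_provider_fk` docstring can be compared side by side. This is a rough sketch, assuming the standard-library `sqlite3` module and cut-down tables. The `delete_provider` function here is a hypothetical stand-in for `api/providers.delete_provider`, whose real body is not shown in this diff. Note that SQLite only enforces foreign keys on connections that run `PRAGMA foreign_keys = ON`.

```python
import sqlite3


def delete_provider(conn: sqlite3.Connection, provider_id: int) -> None:
    """Application-side SET NULL: detach event_log rows, then delete the provider."""
    with conn:  # single transaction: both statements commit or neither does
        conn.execute(
            "UPDATE event_log SET provider_id = NULL WHERE provider_id = ?",
            (provider_id,),
        )
        conn.execute("DELETE FROM service_provider WHERE id = ?", (provider_id,))


# Existing install: event_log.provider_id carries no FK constraint at all.
legacy = sqlite3.connect(":memory:")
legacy.execute("CREATE TABLE service_provider (id INTEGER PRIMARY KEY, name TEXT)")
legacy.execute(
    "CREATE TABLE event_log ("
    "id INTEGER PRIMARY KEY, provider_id INTEGER, provider_name TEXT)"
)
legacy.execute("INSERT INTO service_provider (id, name) VALUES (7, 'Immich')")
legacy.execute("INSERT INTO event_log (provider_id, provider_name) VALUES (7, 'Immich')")
delete_provider(legacy, 7)
legacy_rows = legacy.execute(
    "SELECT provider_id, provider_name FROM event_log"
).fetchall()

# Fresh install: the FK exists, but is only enforced with the pragma enabled.
fresh = sqlite3.connect(":memory:")
fresh.execute("PRAGMA foreign_keys = ON")
fresh.execute("CREATE TABLE service_provider (id INTEGER PRIMARY KEY, name TEXT)")
fresh.execute(
    "CREATE TABLE event_log ("
    "id INTEGER PRIMARY KEY, "
    "provider_id INTEGER REFERENCES service_provider(id) ON DELETE SET NULL, "
    "provider_name TEXT)"
)
fresh.execute("INSERT INTO service_provider (id, name) VALUES (7, 'Immich')")
fresh.execute("INSERT INTO event_log (provider_id, provider_name) VALUES (7, 'Immich')")
fresh.execute("DELETE FROM service_provider WHERE id = 7")
fresh_rows = fresh.execute(
    "SELECT provider_id, provider_name FROM event_log"
).fetchall()
```

In both cases the historical row survives with `provider_id` set to NULL and `provider_name` intact, which is the behaviour the docstring asks for.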
@@ -6,7 +6,7 @@ from datetime import datetime, timezone
 from typing import Any
 from uuid import uuid4

-from sqlalchemy import ForeignKey, UniqueConstraint, Text
+from sqlalchemy import ForeignKey, Index, UniqueConstraint, Text
 from sqlmodel import JSON, Column, Field, SQLModel


@@ -29,12 +29,25 @@ class ServiceProvider(SQLModel, table=True):
     __tablename__ = "service_provider"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     type: str  # ServiceProviderType value ("immich")
     name: str
     icon: str = Field(default="")
     config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
-    webhook_token: str = Field(default_factory=lambda: uuid4().hex)
+    # Webhook token is the shared secret embedded in inbound webhook URLs.
+    # Must be unique so a token uniquely identifies a provider; indexed so
+    # the webhook router does an O(log n) lookup on every inbound request.
+    webhook_token: str = Field(
+        default_factory=lambda: uuid4().hex,
+        unique=True,
+        index=True,
+    )
     created_at: datetime = Field(default_factory=_utcnow)


@@ -42,13 +55,29 @@ class TelegramBot(SQLModel, table=True):
     __tablename__ = "telegram_bot"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     name: str
     token: str
     icon: str = Field(default="")
     bot_username: str = Field(default="")
-    bot_id: int = Field(default=0)
-    webhook_path_id: str = Field(default_factory=lambda: uuid4().hex)
+    # bot_id=0 is a sentinel meaning "Telegram has not yet returned a numeric
+    # ID for this bot" (i.e. token never validated). Multiple unverified bots
+    # may legitimately carry 0, so we only enforce uniqueness for non-sentinel
+    # values via a partial index added in migrate_uniqueness_constraints.
+    bot_id: int = Field(default=0, index=True)
+    # URL-path embedded in Telegram's setWebhook callback URL. Must be unique
+    # so the inbound dispatcher resolves a single bot per incoming request.
+    webhook_path_id: str = Field(
+        default_factory=lambda: uuid4().hex,
+        unique=True,
+        index=True,
+    )
     update_mode: str = Field(default="none")  # "none", "polling", or "webhook"
     # NOTE: commands_config column remains in the DB for backward compat,
     # but is no longer part of the SQLModel class. Data migrated to CommandConfig.
@@ -61,7 +90,13 @@ class MatrixBot(SQLModel, table=True):
     __tablename__ = "matrix_bot"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     name: str
     icon: str = Field(default="")
     homeserver_url: str  # e.g. https://matrix.org
@@ -76,7 +111,13 @@ class EmailBot(SQLModel, table=True):
     __tablename__ = "email_bot"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     name: str
     icon: str = Field(default="")
     email: str  # From address
@@ -90,6 +131,13 @@ class EmailBot(SQLModel, table=True):

 class TelegramChat(SQLModel, table=True):
     __tablename__ = "telegram_chat"
+    # (bot_id, chat_id) uniquely identifies a chat. The composite index is
+    # the access pattern for save_chat_from_webhook ON CONFLICT updates and
+    # for any "lookup by (bot, chat)" callers.
+    __table_args__ = (
+        UniqueConstraint("bot_id", "chat_id", name="uq_telegram_chat_bot_chat"),
+        Index("ix_telegram_chat_bot_chat", "bot_id", "chat_id"),
+    )

     id: int | None = Field(default=None, primary_key=True)
     bot_id: int = Field(foreign_key="telegram_bot.id")
@@ -109,7 +157,13 @@ class TrackingConfig(SQLModel, table=True):
     __tablename__ = "tracking_config"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     provider_type: str  # Must match provider's type
     name: str
     icon: str = Field(default="")
@@ -171,6 +225,12 @@ class TrackingConfig(SQLModel, table=True):
     track_ha_service_called: bool = Field(default=False)
     track_ha_event_fired: bool = Field(default=False)

+    # Bridge self-monitoring event tracking — defaults ON because the whole
+    # point of the provider is to alert on these conditions.
+    track_bridge_self_poll_failures: bool = Field(default=True)
+    track_bridge_self_deferred_backlog: bool = Field(default=True)
+    track_bridge_self_target_failures: bool = Field(default=True)
+
     # Immich asset display
     track_images: bool = Field(default=True)
     track_videos: bool = Field(default=True)
@@ -276,7 +336,13 @@ class NotificationTarget(SQLModel, table=True):
     __tablename__ = "notification_target"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     type: str  # "telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"
     name: str
     icon: str = Field(default="")
@@ -319,7 +385,13 @@ class NotificationTracker(SQLModel, table=True):
     __tablename__ = "notification_tracker"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     provider_id: int = Field(foreign_key="service_provider.id")
     name: str
     icon: str = Field(default="")
@@ -342,6 +414,15 @@ class NotificationTrackerTarget(SQLModel, table=True):
     """Junction between NotificationTracker and NotificationTarget with per-link config."""

     __tablename__ = "notification_tracker_target"
+    # A tracker should never link to the same target twice — duplicate links
+    # would deliver the same notification multiple times. Enforced at the DB
+    # level so concurrent inserts can't bypass an application-side check.
+    __table_args__ = (
+        UniqueConstraint(
+            "notification_tracker_id", "target_id",
+            name="uq_ntt_tracker_target",
+        ),
+    )

     id: int | None = Field(default=None, primary_key=True)
     # Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
@@ -403,7 +484,13 @@ class CommandConfig(SQLModel, table=True):
     __tablename__ = "command_config"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     provider_type: str
     name: str
     icon: str = Field(default="")
@@ -464,7 +551,13 @@ class CommandTracker(SQLModel, table=True):
     __tablename__ = "command_tracker"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     provider_id: int = Field(foreign_key="service_provider.id")
     command_config_id: int = Field(foreign_key="command_config.id")
     name: str
@@ -517,7 +610,15 @@ class DeferredDispatch(SQLModel, table=True):
     __tablename__ = "deferred_dispatch"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int | None = Field(default=None, foreign_key="user.id", index=True)
+    user_id: int | None = Field(
+        default=None,
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=True,
+            index=True,
+        ),
+    )
     tracker_id: int = Field(foreign_key="notification_tracker.id", index=True)
     # The specific link this deferral targets. On drain we re-fetch by ID; if
     # the link was disabled or removed in the meantime we drop with a
@@ -566,8 +667,17 @@ class EventLog(SQLModel, table=True):
     id: int | None = Field(default=None, primary_key=True)
     # Owner. Indexed for the dashboard events query. Nullable only because
     # historical rows (pre-user_id column) may have no owner; new rows always
-    # set this directly.
-    user_id: int | None = Field(default=None, foreign_key="user.id", index=True)
+    # set this directly. SET NULL on user delete preserves the audit trail
+    # while letting the user record itself be removed.
+    user_id: int | None = Field(
+        default=None,
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="SET NULL"),
+            nullable=True,
+            index=True,
+        ),
+    )
     # Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
     tracker_id: int | None = Field(
         default=None,
@@ -594,7 +704,21 @@ class EventLog(SQLModel, table=True):
         default=None, foreign_key="telegram_bot.id", index=True,
     )
     bot_name: str = Field(default="")
-    provider_id: int | None = Field(default=None, index=True)
+    # FK to service_provider with SET NULL so deleting a provider leaves
+    # historical event_log rows intact (provider_name preserves the label
+    # for display). The FK only takes effect on freshly created tables —
+    # SQLite cannot ALTER a constraint into an existing table without a
+    # rebuild, so application code in api/providers.delete_provider also
+    # nulls these explicitly. See migrate_eventlog_provider_fk.
+    provider_id: int | None = Field(
+        default=None,
+        sa_column=Column(
+            "provider_id",
+            ForeignKey("service_provider.id", ondelete="SET NULL"),
+            nullable=True,
+            index=True,
+        ),
+    )
     provider_name: str = Field(default="")
     event_type: str = Field(index=True)
     collection_id: str
@@ -610,7 +734,13 @@ class Action(SQLModel, table=True):
     __tablename__ = "action"

     id: int | None = Field(default=None, primary_key=True)
-    user_id: int = Field(foreign_key="user.id")
+    user_id: int = Field(
+        sa_column=Column(
+            "user_id",
+            ForeignKey("user.id", ondelete="CASCADE"),
+            nullable=False,
+        ),
+    )
     provider_id: int = Field(foreign_key="service_provider.id")
     name: str
     icon: str = Field(default="")
@@ -13,9 +13,11 @@ from .models import (
     CommandConfig,
     CommandTemplateConfig,
     CommandTemplateSlot,
+    ServiceProvider,
     TemplateConfig,
     TemplateSlot,
     TrackingConfig,
+    User,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -159,6 +161,7 @@ async def _seed_default_templates() -> None:
         await _seed_provider_template(session, "google_photos", "Google Photos")
         await _seed_provider_template(session, "webhook", "Generic Webhook")
         await _seed_provider_template(session, "home_assistant", "Home Assistant")
+        await _seed_provider_template(session, "bridge_self", "Bridge Self-Monitoring")
         await session.commit()


@@ -285,6 +288,13 @@ async def _seed_default_tracking_configs() -> None:
             "track_ha_service_called": False,
             "track_ha_event_fired": False,
         },
+        {
+            "provider_type": "bridge_self",
+            "name": "Default Bridge Self-Monitoring",
+            "track_bridge_self_poll_failures": True,
+            "track_bridge_self_deferred_backlog": True,
+            "track_bridge_self_target_failures": True,
+        },
     ]

     for cfg in defaults:
@@ -403,6 +413,67 @@ async def _seed_default_command_configs() -> None:
         await session.commit()


+# ---------------------------------------------------------------------------
+# Bridge self-monitoring per-user provider
+# ---------------------------------------------------------------------------
+
+
+# Default thresholds — duplicated here as constants instead of imported so the
+# seed module stays self-contained and import-cycle-free during boot.
+_BRIDGE_SELF_DEFAULT_CONFIG = {
+    "poll_failure_threshold": 3,
+    "deferred_backlog_threshold": 100,
+    "target_failure_threshold": 5,
+}
+
+
+async def ensure_bridge_self_provider_for_user(
+    session: AsyncSession, user_id: int,
+) -> ServiceProvider | None:
+    """Create the user's bridge_self provider if absent. Returns the provider.
+
+    The bridge_self provider is special — exactly one per user, auto-created
+    so the operator never has to think about wiring it up. Idempotent.
+    Skips ``user_id <= 0`` (the ``__system__`` placeholder) which never
+    receives notifications.
+    """
+    if user_id <= 0:
+        return None
+    result = await session.exec(
+        select(ServiceProvider).where(
+            ServiceProvider.user_id == user_id,
+            ServiceProvider.type == "bridge_self",
+        )
+    )
+    existing = result.first()
+    if existing is not None:
+        return existing
+    provider = ServiceProvider(
+        user_id=user_id,
+        type="bridge_self",
+        name="Bridge Self-Monitoring",
+        config=dict(_BRIDGE_SELF_DEFAULT_CONFIG),
+    )
+    session.add(provider)
+    await session.flush()
+    return provider
+
+
+async def _seed_bridge_self_providers_for_existing_users() -> None:
+    """Backfill bridge_self provider for every existing real user.
+
+    Runs once at boot so deployments upgrading from a pre-bridge_self
+    release pick up the auto-created provider without requiring user
+    action. Skips users that already have one.
+    """
+    engine = get_engine()
+    async with AsyncSession(engine) as session:
+        users = (await session.exec(select(User).where(User.id != 0))).all()
+        for user in users:
+            await ensure_bridge_self_provider_for_user(session, user.id)
+        await session.commit()
+
+
 # ---------------------------------------------------------------------------
 # Public entry point
 # ---------------------------------------------------------------------------
@@ -442,3 +513,4 @@ async def seed_all() -> None:
     await _seed_default_command_templates()
     await _seed_default_tracking_configs()
     await _seed_default_command_configs()
+    await _seed_bridge_self_providers_for_existing_users()
@@ -50,6 +50,7 @@ from .commands.webhook import router as webhook_router, set_webhook_secret
 from .api.webhooks import router as webhooks_router
 from .api.webhook_logs import router as webhook_logs_router
 from .api.backup import router as backup_router
+from .api.metrics import router as metrics_router


 # Readiness flag — flipped to True once the scheduler has started and the
@@ -78,6 +79,8 @@ async def lifespan(app: FastAPI):
         migrate_chat_action_to_column,
         migrate_deferred_dispatch_event_log_fk,
         migrate_deferred_dispatch_unique_pending,
+        migrate_uniqueness_constraints,
+        migrate_eventlog_provider_fk,
         migrate_schema_version,
     )
     from .database.snapshot import snapshot_and_prune
@@ -107,6 +110,13 @@ async def lifespan(app: FastAPI):
     # the partial unique index.
     await migrate_deferred_dispatch_event_log_fk(engine)
     await migrate_deferred_dispatch_unique_pending(engine)
+    # Backfill missing UNIQUE indexes on webhook hot paths (deduping any
+    # existing duplicates). Runs after performance_indexes so non-unique
+    # support indexes are already in place.
+    await migrate_uniqueness_constraints(engine)
+    # Document EventLog.provider_id FK strategy on existing tables (no-op
+    # on SQLite besides the log line; new tables get the FK from create_all).
+    await migrate_eventlog_provider_fk(engine)
     await migrate_schema_version(engine)
     from .database.seeds import seed_all
     await seed_all()
@@ -254,6 +264,7 @@ app.include_router(webhook_router)
 app.include_router(webhooks_router)
 app.include_router(webhook_logs_router)
 app.include_router(backup_router)
+app.include_router(metrics_router)


 @app.get("/api/health")
@@ -265,15 +276,107 @@ async def health():

 @app.get("/api/ready")
 async def ready():
-    """Readiness: migrations and scheduler have started, app can serve traffic.
+    """Readiness: deep dependency check.

-    Returns 503 until the lifespan startup sequence has completed. Use this
-    for orchestrator readiness probes (Docker, Kubernetes).
+    Verifies each critical dependency is actually reachable, not just that
+    the app finished its lifespan startup. Returns 503 if any *required*
+    check fails (db, scheduler). Home Assistant supervisor presence is
+    informational — a degraded HA does not flip readiness off.
+
+    Response shape:
+        {
+            "ready": bool,
+            "checks": {"db": "ok|fail", "scheduler": "ok|fail", "ha": "ok|degraded|na"},
+            "errors": [str, ...]
+        }
     """
+    from starlette.responses import JSONResponse
+    import asyncio as _asyncio
+    from sqlalchemy import text as _text
+
+    checks: dict[str, str] = {}
+    errors: list[str] = []
+
     if not _READY:
-        from starlette.responses import JSONResponse
-        return JSONResponse({"status": "starting"}, status_code=503)
-    return {"status": "ready", "version": _APP_VERSION}
+        # Lifespan still running — short-circuit so we don't poke a half-built engine.
+        return JSONResponse(
+            {
+                "ready": False,
+                "checks": {"db": "fail", "scheduler": "fail", "ha": "na"},
+                "errors": ["startup not complete"],
+                "version": _APP_VERSION,
+            },
+            status_code=503,
+        )
+
+    # --- DB: SELECT 1 with a 2s timeout ---
+    try:
+        from .database.engine import get_engine
+        engine = get_engine()
+
+        async def _ping_db() -> None:
+            async with engine.connect() as conn:
+                await conn.execute(_text("SELECT 1"))
+
+        await _asyncio.wait_for(_ping_db(), timeout=2.0)
+        checks["db"] = "ok"
+    except Exception as exc:  # noqa: BLE001
+        checks["db"] = "fail"
+        errors.append(f"db: {exc!s}")
+
+    # --- Scheduler: APScheduler must be running ---
+    try:
+        from .services.scheduler import get_scheduler
+        scheduler = get_scheduler()
+        if scheduler.running:
+            checks["scheduler"] = "ok"
+        else:
+            checks["scheduler"] = "fail"
+            errors.append("scheduler: not running")
+    except Exception as exc:  # noqa: BLE001
+        checks["scheduler"] = "fail"
+        errors.append(f"scheduler: {exc!s}")
+
+    # --- HA supervisor: informational only ---
+    # If no HA providers are configured, report "na" (not applicable). If any
+    # HA providers exist, ensure at least one supervisor task is alive — a
+    # task being not-yet-connected is fine, we just want it to exist.
+    try:
+        from sqlmodel import select as _select
+        from sqlmodel.ext.asyncio.session import AsyncSession as _AS
+        from .database.models import ServiceProvider
+        from .services.ha_subscription import _running_tasks as _ha_tasks
+
+        from .database.engine import get_engine as _get_engine_ha
+        async with _AS(_get_engine_ha()) as _session:
+            _result = await _session.exec(
+                _select(ServiceProvider).where(
+                    ServiceProvider.type == "home_assistant",
+                )
+            )
+            ha_providers = _result.all()
+        if not ha_providers:
+            checks["ha"] = "na"
+        else:
+            alive = [
+                t for t in _ha_tasks.values() if t is not None and not t.done()
+            ]
+            checks["ha"] = "ok" if alive else "degraded"
+    except Exception as exc:  # noqa: BLE001
+        # Never let the HA probe fail readiness — it's informational.
+        checks["ha"] = "degraded"
+        errors.append(f"ha: {exc!s}")
+
+    required_ok = checks["db"] == "ok" and checks["scheduler"] == "ok"
+    body = {
+        "ready": required_ok,
+        "checks": checks,
+        "errors": errors,
+        "version": _APP_VERSION,
+    }
+    if not required_ok:
+        return JSONResponse(body, status_code=503)
+    return body


 # --- Serve frontend static files (production) ---
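The split between required and informational checks in `/api/ready` can be reduced to a small framework-free helper. The sketch below is only an illustration, assuming plain `asyncio` with no FastAPI, database or scheduler involved. `run_probe`, `readiness` and the three demo probes are hypothetical names that do not exist in the codebase.

```python
from __future__ import annotations

import asyncio
from typing import Awaitable, Callable

Probe = Callable[[], Awaitable[None]]


async def run_probe(probe: Probe, timeout: float) -> str | None:
    """Return None on success, or a short error string on failure or timeout."""
    try:
        await asyncio.wait_for(probe(), timeout=timeout)
        return None
    except Exception as exc:  # a timeout surfaces here as well
        return f"{type(exc).__name__}: {exc}"


async def readiness(
    required: dict[str, Probe],
    informational: dict[str, Probe],
    timeout: float = 2.0,
) -> tuple[int, dict]:
    """Run all probes and return (http_status, response_body)."""
    checks: dict[str, str] = {}
    errors: list[str] = []
    for name, probe in required.items():
        err = await run_probe(probe, timeout)
        checks[name] = "ok" if err is None else "fail"
        if err is not None:
            errors.append(f"{name}: {err}")
    for name, probe in informational.items():
        err = await run_probe(probe, timeout)
        # Informational probes can degrade the report but never fail readiness.
        checks[name] = "ok" if err is None else "degraded"
        if err is not None:
            errors.append(f"{name}: {err}")
    ready = all(checks[name] == "ok" for name in required)
    body = {"ready": ready, "checks": checks, "errors": errors}
    return (200 if ready else 503), body


async def healthy() -> None:
    return None


async def hung() -> None:
    await asyncio.sleep(1.0)  # longer than the timeout used below


async def broken() -> None:
    raise RuntimeError("supervisor task missing")


status_ok, body_ok = asyncio.run(
    readiness({"db": healthy, "scheduler": healthy}, {"ha": broken}, timeout=0.05)
)
status_bad, body_bad = asyncio.run(
    readiness({"db": hung, "scheduler": healthy}, {"ha": healthy}, timeout=0.05)
)
```

In the first call a broken informational probe still yields a 200 with `ha` reported as degraded. In the second, a hung required probe is cut off by the timeout and turns the response into a 503.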
@@ -667,13 +667,19 @@ async def import_backup(
         if name is None:
             continue
         ctc_id = _map_id(id_map, "command_template_configs", cc.command_template_config_id)
+        try:
+            safe_enabled = _sanitize_config(cc.enabled_commands or {})
+            safe_limits = _sanitize_config(cc.rate_limits or {})
+        except ValueError as exc:
+            result.warnings.append(f"Skipped command config '{cc.name}': {exc}")
+            continue
         new_cc = CommandConfig(
             user_id=user_id, provider_type=cc.provider_type,
             name=name, icon=cc.icon,
-            enabled_commands=cc.enabled_commands,
+            enabled_commands=safe_enabled,
             response_mode=cc.response_mode,
             default_count=cc.default_count,
-            rate_limits=cc.rate_limits,
+            rate_limits=safe_limits,
             command_template_config_id=ctc_id,
         )
         session.add(new_cc)
@@ -728,10 +734,16 @@ async def import_backup(
         )
         if name is None:
             continue
+        try:
+            safe_filters = _sanitize_config(nt.filters or {})
+            safe_collection_ids = _sanitize_config(nt.collection_ids or [])
+        except ValueError as exc:
+            result.warnings.append(f"Skipped tracker '{nt.name}': {exc}")
+            continue
         new_nt = NotificationTracker(
             user_id=user_id, provider_id=provider_id,
-            name=name, icon=nt.icon, collection_ids=nt.collection_ids,
-            filters=nt.filters, scan_interval=nt.scan_interval,
+            name=name, icon=nt.icon, collection_ids=safe_collection_ids,
+            filters=safe_filters, scan_interval=nt.scan_interval,
             default_tracking_config_id=_map_id(id_map, "tracking_configs", nt.default_tracking_config_id),
             default_template_config_id=_map_id(id_map, "template_configs", nt.default_template_config_id),
             enabled=nt.enabled,
@@ -810,9 +822,14 @@ async def import_backup(
         )
         if name is None:
             continue
+        try:
+            safe_a_cfg = _sanitize_config(a.config or {})
+        except ValueError as exc:
+            result.warnings.append(f"Skipped action '{a.name}': {exc}")
+            continue
         new_a = Action(
             user_id=user_id, provider_id=provider_id, name=name,
-            icon=a.icon, action_type=a.action_type, config=a.config,
+            icon=a.icon, action_type=a.action_type, config=safe_a_cfg,
             schedule_type=a.schedule_type,
             schedule_interval=a.schedule_interval,
             schedule_cron=a.schedule_cron, enabled=False,  # always import disabled
@@ -820,9 +837,16 @@ async def import_backup(
         session.add(new_a)
         await session.flush()
         for r in a.rules:
+            try:
+                safe_r_cfg = _sanitize_config(r.rule_config or {})
+            except ValueError as exc:
+                result.warnings.append(
+                    f"Skipped rule '{r.name}' in action '{a.name}': {exc}"
+                )
+                continue
             session.add(ActionRule(
                 action_id=new_a.id, name=r.name,
-                rule_config=r.rule_config, enabled=r.enabled,
+                rule_config=safe_r_cfg, enabled=r.enabled,
                 order=r.order,
             ))
         result.created += len(a.rules)
@@ -0,0 +1,432 @@
+"""Bridge self-monitoring service helpers.
+
+Three subsystems feed into ``emit_bridge_self_event``:
+
+1. The watcher's poll loop, when consecutive provider polls fail.
+2. A periodic scheduler job, when the deferred-dispatch backlog crosses
+   the configured threshold.
+3. The notification dispatcher, when consecutive sends to a single target
+   fail with 5xx / network errors.
+
+The helper looks up the user's ``bridge_self`` provider, builds a
+synthetic :class:`ServiceEvent`, and pushes it through the same
+``dispatch_provider_event`` pipeline that every other provider uses.
+That keeps templates, quiet hours, deferral, target gating, and event
+logging consistent with the rest of the system.
+
+We intentionally avoid raising into the caller's flow — a
+self-monitoring failure must never break the subsystem it's monitoring.
+"""
+
+from __future__ import annotations
+
+import logging
+from datetime import datetime, timezone
+from typing import Any
+
+from sqlmodel import select
+from sqlmodel.ext.asyncio.session import AsyncSession
+
+from notify_bridge_core.models.events import ServiceEvent
+from notify_bridge_core.providers.bridge_self import build_event
+from notify_bridge_core.providers.bridge_self.provider import (
+    DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
+    DEFAULT_POLL_FAILURE_THRESHOLD,
+    DEFAULT_TARGET_FAILURE_THRESHOLD,
+)
+
+from ..database.engine import get_engine
+from ..database.models import ServiceProvider, User
+
+_LOGGER = logging.getLogger(__name__)
+
+
+# Detail keys carried into the EventLog.details JSON column. Mirrors the
+# pattern used by the HA subscription and webhook routers for their
+# ``dispatch_provider_event`` calls.
+BRIDGE_SELF_DETAIL_KEYS: tuple[str, ...] = (
+    "failure_type", "subject_id", "subject_name",
+    "count", "threshold", "last_error", "details",
+)
+
+
+async def get_bridge_self_provider(
+    session: AsyncSession, user_id: int,
+) -> ServiceProvider | None:
+    """Return the user's bridge_self provider row (or None if absent)."""
+    result = await session.exec(
+        select(ServiceProvider).where(
+            ServiceProvider.user_id == user_id,
+            ServiceProvider.type == "bridge_self",
+        )
+    )
+    return result.first()
+
+
+async def get_user_thresholds(user_id: int) -> dict[str, int]:
+    """Return the user's bridge_self thresholds, falling back to defaults.
+
+    Reads in a short-lived session — emission sites should NOT hold a
+    transaction across this call.
+    """
+    engine = get_engine()
+    async with AsyncSession(engine) as session:
+        provider = await get_bridge_self_provider(session, user_id)
+        if provider is None:
+            return {
+                "poll_failure_threshold": DEFAULT_POLL_FAILURE_THRESHOLD,
+                "deferred_backlog_threshold": DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
+                "target_failure_threshold": DEFAULT_TARGET_FAILURE_THRESHOLD,
+            }
+        cfg = dict(provider.config or {})
+
+    def _int(key: str, fallback: int) -> int:
+        raw = cfg.get(key, fallback)
+        try:
+            value = int(raw)
+        except (TypeError, ValueError):
+            return fallback
+        return value if value >= 1 else fallback
+
+    return {
+        "poll_failure_threshold": _int("poll_failure_threshold", DEFAULT_POLL_FAILURE_THRESHOLD),
+        "deferred_backlog_threshold": _int(
+            "deferred_backlog_threshold", DEFAULT_DEFERRED_BACKLOG_THRESHOLD,
+        ),
+        "target_failure_threshold": _int(
+            "target_failure_threshold", DEFAULT_TARGET_FAILURE_THRESHOLD,
+        ),
+    }
+
+
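The fallback rule in `get_user_thresholds` (non-numeric or sub-1 values revert to the default) is easy to check without a database. The following is a minimal sketch, assuming the defaults match the values seeded in `_BRIDGE_SELF_DEFAULT_CONFIG` (3, 100, 5). `coerce_threshold` and `resolve_thresholds` are hypothetical names used for illustration only.

```python
from typing import Any

# Assumed to mirror the seeded defaults; the real constants are imported
# from notify_bridge_core and are not shown in this diff.
DEFAULTS = {
    "poll_failure_threshold": 3,
    "deferred_backlog_threshold": 100,
    "target_failure_threshold": 5,
}


def coerce_threshold(raw: Any, fallback: int) -> int:
    """Parse a user-supplied threshold, reverting to the fallback when the
    value is not an integer or is smaller than 1."""
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return fallback
    return value if value >= 1 else fallback


def resolve_thresholds(config: "dict[str, Any] | None") -> "dict[str, int]":
    """Apply coerce_threshold to every known key of a provider config."""
    cfg = dict(config or {})
    return {key: coerce_threshold(cfg.get(key, default), default)
            for key, default in DEFAULTS.items()}


resolved = resolve_thresholds({
    "poll_failure_threshold": "7",    # numeric string is accepted
    "deferred_backlog_threshold": 0,  # below 1, so the default is used
    "target_failure_threshold": "x",  # not a number, so the default is used
})
```

Treating zero and negative values as invalid avoids a threshold that would fire on every single failure or never be reachable.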
+async def emit_bridge_self_event(
+    *,
+    user_id: int,
+    failure_type: str,
+    subject_id: int,
+    subject_name: str,
+    count: int,
+    threshold: int,
+    last_error: str = "",
+    details: dict[str, Any] | None = None,
+    timestamp: datetime | None = None,
+) -> int:
+    """Emit a self-monitoring event for ``user_id``.
+
+    Resolves the user's bridge_self provider and dispatches the event via
+    ``dispatch_provider_event``. Returns the number of dispatched
+    notifications (0 when the user has no bridge_self provider, no
+    matching trackers, or the event was suppressed by quiet hours / event-
+    type gating).
+
+    Always swallows internal exceptions so the calling subsystem keeps
+    running — self-monitoring must never crash the watcher / scheduler /
+    dispatcher.
+    """
+    payload = {
+        "failure_type": failure_type,
+        "subject_id": subject_id,
+        "subject_name": subject_name,
+        "count": count,
+        "threshold": threshold,
+        "last_error": last_error,
+        "details": dict(details or {}),
+    }
+    event = build_event(payload, timestamp=timestamp or datetime.now(timezone.utc))
+    if event is None:
+        _LOGGER.debug("Skipping malformed bridge_self payload: %s", payload)
+        return 0
+
+    engine = get_engine()
+    try:
+        async with AsyncSession(engine) as session:
+            provider = await get_bridge_self_provider(session, user_id)
+            if provider is None:
+                _LOGGER.debug(
+                    "User %s has no bridge_self provider; skipping %s emission",
+                    user_id, failure_type,
+                )
+                return 0
+            provider_id = provider.id
+            provider_name = provider.name
+            provider_config = dict(provider.config or {})
+
+        # Imported here to avoid a top-level cycle: dispatch_helpers imports
+        # several models which transitively touch this module's siblings.
+        from .event_dispatch import dispatch_provider_event
+
+        return await dispatch_provider_event(
+            engine=engine,
+            provider_id=provider_id,
+            provider_name=provider_name,
+            provider_config=provider_config,
+            event=event,
+            detail_keys=BRIDGE_SELF_DETAIL_KEYS,
+            filter_fn=lambda _ev, _filters: True,
+        )
+    except Exception:  # noqa: BLE001
+        _LOGGER.exception(
+            "bridge_self emission failed (user=%s, failure_type=%s)",
+            user_id, failure_type,
+        )
+        return 0
+
+
# ---------------------------------------------------------------------------
# Threshold-crossing trackers (in-memory, per-process).
#
# We track consecutive failure counts in module-level dicts keyed by the
# subject id (tracker_id, target_id). On threshold crossing we emit and
# reset the counter so we don't spam — the next emission only happens after
# another full streak of failures.
# ---------------------------------------------------------------------------


# Tracker poll failures (keyed by tracker_id).
_poll_failure_counts: dict[int, int] = {}
_poll_failure_last_error: dict[int, str] = {}

# Target send failures (keyed by target_id).
_target_failure_counts: dict[int, int] = {}
_target_failure_last_error: dict[int, str] = {}

# Last-known backlog state per user (True = above threshold, False = below).
# We only emit on the False -> True transition so a sustained backlog
# triggers exactly one notification per crossing.
_backlog_above_threshold: dict[int, bool] = {}


def record_poll_success(tracker_id: int) -> None:
    """Reset the failure counter for ``tracker_id`` after a successful poll."""
    _poll_failure_counts.pop(tracker_id, None)
    _poll_failure_last_error.pop(tracker_id, None)


def record_poll_failure(tracker_id: int, error: str = "") -> int:
    """Increment the failure counter for ``tracker_id``; return the new count."""
    _poll_failure_counts[tracker_id] = _poll_failure_counts.get(tracker_id, 0) + 1
    if error:
        _poll_failure_last_error[tracker_id] = error
    return _poll_failure_counts[tracker_id]


def reset_poll_counter(tracker_id: int) -> None:
    """Clear the failure counter for ``tracker_id`` without emitting."""
    _poll_failure_counts.pop(tracker_id, None)
    _poll_failure_last_error.pop(tracker_id, None)


def record_target_success(target_id: int) -> None:
    """Reset the failure counter for ``target_id`` after a successful send."""
    _target_failure_counts.pop(target_id, None)
    _target_failure_last_error.pop(target_id, None)


def record_target_failure(target_id: int, error: str = "") -> int:
    """Increment the failure counter for ``target_id``; return the new count."""
    _target_failure_counts[target_id] = _target_failure_counts.get(target_id, 0) + 1
    if error:
        _target_failure_last_error[target_id] = error
    return _target_failure_counts[target_id]


def reset_target_counter(target_id: int) -> None:
    """Clear the failure counter for ``target_id`` without emitting."""
    _target_failure_counts.pop(target_id, None)
    _target_failure_last_error.pop(target_id, None)


def record_backlog_state(user_id: int, above_threshold: bool) -> bool:
    """Record the new backlog state, returning True iff we just crossed up.

    The first ever observation is treated as "below" so a process that
    starts with a non-empty backlog still emits one notification.
    """
    prior = _backlog_above_threshold.get(user_id, False)
    _backlog_above_threshold[user_id] = above_threshold
    return above_threshold and not prior


def get_poll_failure_count(tracker_id: int) -> int:
    return _poll_failure_counts.get(tracker_id, 0)


def get_target_failure_count(target_id: int) -> int:
    return _target_failure_counts.get(target_id, 0)


def get_poll_last_error(tracker_id: int) -> str:
    return _poll_failure_last_error.get(tracker_id, "")


def get_target_last_error(target_id: int) -> str:
    return _target_failure_last_error.get(target_id, "")


# ---------------------------------------------------------------------------
# User-level helpers
# ---------------------------------------------------------------------------


async def list_user_ids() -> list[int]:
    """Return all real user ids (excluding the __system__ placeholder)."""
    engine = get_engine()
    async with AsyncSession(engine) as session:
        result = await session.exec(select(User.id).where(User.id != 0))
        return [int(uid) for uid in result.all() if uid is not None]


async def find_tracker_owner(tracker_id: int) -> int | None:
    """Return the user_id that owns ``tracker_id`` (or None)."""
    from ..database.models import NotificationTracker

    engine = get_engine()
    async with AsyncSession(engine) as session:
        tracker = await session.get(NotificationTracker, tracker_id)
        if tracker is None:
            return None
        return int(tracker.user_id)


async def find_target_owner(target_id: int) -> int | None:
    """Return the user_id that owns ``target_id`` (or None)."""
    from ..database.models import NotificationTarget

    engine = get_engine()
    async with AsyncSession(engine) as session:
        target = await session.get(NotificationTarget, target_id)
        if target is None:
            return None
        return int(target.user_id)


# ---------------------------------------------------------------------------
# Backlog scan
# ---------------------------------------------------------------------------


async def check_deferred_backlog() -> dict[str, Any]:
    """Scan the deferred_dispatch table and emit a backlog event if needed.

    Counts pending rows per user, compares against each user's configured
    threshold, and emits ``bridge_self_deferred_backlog`` for users that
    just crossed up. Returns a small stats dict for logging.
    """
    from sqlalchemy import func

    from ..database.models import DeferredDispatch

    engine = get_engine()
    crossings = 0
    async with AsyncSession(engine) as session:
        # GROUP BY user_id so we don't have to scan once per user. Skip rows
        # whose user_id is NULL — those are legacy / orphaned and have no
        # bridge_self provider to alert anyway.
        rows = (
            await session.exec(
                select(
                    DeferredDispatch.user_id,
                    func.count(DeferredDispatch.id),
                )
                .where(DeferredDispatch.status == "pending")
                .where(DeferredDispatch.user_id.is_not(None))
                .group_by(DeferredDispatch.user_id)
            )
        ).all()

    counts_by_user: dict[int, int] = {}
    for row in rows:
        # Each row unpacks as a (user_id, count) pair; the previous
        # isinstance(row, tuple) check had identical branches.
        uid, count = row
        if uid is None:
            continue
        counts_by_user[int(uid)] = int(count or 0)

    for user_id, count in counts_by_user.items():
        thresholds = await get_user_thresholds(user_id)
        threshold = thresholds["deferred_backlog_threshold"]
        above = count >= threshold
        if record_backlog_state(user_id, above):
            crossings += 1
            await emit_bridge_self_event(
                user_id=user_id,
                failure_type="deferred_backlog",
                subject_id=0,
                subject_name="Deferred dispatch queue",
                count=count,
                threshold=threshold,
                details={"pending": count},
            )

    # Reset latch for users that recovered (count < threshold or zero rows).
    # Iterate all known users so a user whose backlog drained to 0 (no row in
    # GROUP BY) still flips back to "below".
    for user_id in list(_backlog_above_threshold.keys()):
        if user_id in counts_by_user:
            continue
        # No pending rows for this user — clear the latch.
        _backlog_above_threshold[user_id] = False

    return {"users_scanned": len(counts_by_user), "crossings": crossings}


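The backlog latch used by the scan above is an edge-triggered flag: it reports only the below-to-above transition, and it has to be cleared for users who vanish from the GROUP BY result. A minimal sketch of that behaviour, assuming an in-memory dict as in the module; the `BacklogLatch` class and its method names are hypothetical, introduced only for illustration.

```python
class BacklogLatch:
    """Edge-triggered per-user flag: reports only False -> True transitions."""

    def __init__(self) -> None:
        self._above: dict[int, bool] = {}

    def observe(self, user_id: int, above: bool) -> bool:
        """Store the new state; return True only when the user just crossed up."""
        prior = self._above.get(user_id, False)  # first observation counts as "below"
        self._above[user_id] = above
        return above and not prior

    def clear_missing(self, seen_user_ids: set[int]) -> None:
        """Reset users with no pending rows (they are absent from the scan)."""
        for user_id in list(self._above):
            if user_id not in seen_user_ids:
                self._above[user_id] = False


if __name__ == "__main__":
    latch = BacklogLatch()
    print(latch.observe(1, True), latch.observe(1, True))
```

Without the `clear_missing` step, a user whose queue drained to zero rows would stay latched and the next backlog would go unreported.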
# ---------------------------------------------------------------------------
# Threshold-aware emission wrappers (used by watcher / dispatcher).
# ---------------------------------------------------------------------------


async def maybe_emit_poll_failure(
    *, tracker_id: int, tracker_name: str, error: str = "",
) -> None:
    """Increment poll failure counter; emit + reset if threshold reached."""
    count = record_poll_failure(tracker_id, error)
    user_id = await find_tracker_owner(tracker_id)
    if user_id is None:
        return
    thresholds = await get_user_thresholds(user_id)
    threshold = thresholds["poll_failure_threshold"]
    if count < threshold:
        return
    last_err = get_poll_last_error(tracker_id) or error
    await emit_bridge_self_event(
        user_id=user_id,
        failure_type="poll_failures",
        subject_id=tracker_id,
        subject_name=tracker_name or f"tracker {tracker_id}",
        count=count,
        threshold=threshold,
        last_error=last_err,
        details={"tracker_id": tracker_id},
    )
    # Reset so the next emission requires another full streak. Without this
    # the same tracker would fire on EVERY tick once it crosses the
    # threshold, drowning the operator.
    reset_poll_counter(tracker_id)


async def maybe_emit_target_failure(
    *, target_id: int, target_name: str, target_type: str, error: str = "",
) -> None:
    """Increment target failure counter; emit + reset if threshold reached."""
    count = record_target_failure(target_id, error)
    user_id = await find_target_owner(target_id)
    if user_id is None:
        return
    thresholds = await get_user_thresholds(user_id)
    threshold = thresholds["target_failure_threshold"]
    if count < threshold:
        return
    last_err = get_target_last_error(target_id) or error
    await emit_bridge_self_event(
        user_id=user_id,
        failure_type="target_failures",
        subject_id=target_id,
        subject_name=target_name or f"target {target_id}",
        count=count,
        threshold=threshold,
        last_error=last_err,
        details={"target_id": target_id, "target_type": target_type},
    )
    reset_target_counter(target_id)
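The emit-then-reset pattern in the two wrappers above means one alert per full streak of consecutive failures rather than one per failing tick. A minimal synchronous sketch of that counting logic, with the database lookups and the actual emission left out; `FailureStreak` is a hypothetical name, not part of this change.

```python
class FailureStreak:
    """Count consecutive failures per subject; fire once per full streak."""

    def __init__(self, threshold: int) -> None:
        self.threshold = threshold
        self._counts: dict[int, int] = {}

    def record_success(self, subject_id: int) -> None:
        # Any success breaks the streak.
        self._counts.pop(subject_id, None)

    def record_failure(self, subject_id: int) -> bool:
        """Return True when this failure completes a streak (and reset it)."""
        count = self._counts.get(subject_id, 0) + 1
        if count < self.threshold:
            self._counts[subject_id] = count
            return False
        # Threshold reached: reset so the next alert needs another full streak.
        self._counts.pop(subject_id, None)
        return True


if __name__ == "__main__":
    streak = FailureStreak(threshold=3)
    fired = [streak.record_failure(7) for _ in range(6)]
    print(fired)  # [False, False, True, False, False, True]
```

Because the counters live in process memory, a restart forgets partial streaks; that matches the "in-memory, per-process" note in the module comment.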
@@ -98,9 +98,16 @@ async def _flush_dirty_bots() -> None:
                 bot = await session.get(TelegramBot, bot_id)
                 if not bot:
                     continue
+                # Snapshot every attribute we touch after the session
+                # exits — once detached, lazy attribute access raises
+                # MissingGreenlet under SQLAlchemy async.
+                bot_username = bot.bot_username
+                # Expunge so the detached instance can still read snapshotted
+                # attrs but won't trigger a refresh / re-query downstream.
+                session.expunge(bot)
             success = await register_commands_with_telegram(bot)
             if success:
-                _LOGGER.info("Auto-synced commands for bot %d (@%s)", bot_id, bot.bot_username)
+                _LOGGER.info("Auto-synced commands for bot %d (@%s)", bot_id, bot_username)
             else:
                 _LOGGER.warning("Auto-sync failed for bot %d", bot_id)
         except Exception:

@@ -29,6 +29,7 @@ import logging
 from datetime import datetime, timezone
 from typing import Any

+from sqlalchemy.orm.attributes import flag_modified
 from sqlmodel import select
 from sqlmodel.ext.asyncio.session import AsyncSession

@@ -53,6 +54,7 @@ from .dispatch_helpers import (
     evaluate_event_gate,
     get_app_timezone,
     load_link_data,
+    resolve_provider_credential,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -88,6 +90,10 @@ _DEFERRABLE_EVENT_TYPES: frozenset[str] = frozenset({
     "ups_online", "ups_on_battery", "ups_low_battery",
     "ups_battery_restored", "ups_comms_lost", "ups_comms_restored",
     "ups_replace_battery", "ups_overload",
+    # Home Assistant — state changes & automations are change-driven; the
+    # underlying state remains relevant after the quiet window.
+    "ha_state_changed", "ha_automation_triggered",
+    "ha_service_called", "ha_event_fired",
 })

 # Per-tracker cap on the pending queue. A misconfigured short quiet window
@@ -206,6 +212,11 @@ def _coalesce_assets_added(
             payload["removed_asset_ids"] = kept
             payload["removed_count"] = len(kept)
             existing_removed_row.event_payload = payload
+            # Belt-and-braces: SQLAlchemy's mutation tracker sometimes
+            # misses JSON-typed reassignments depending on dialect / column
+            # config. Explicit flag_modified guarantees the dirty bit is
+            # set for the upcoming flush.
+            flag_modified(existing_removed_row, "event_payload")
             if not kept:
                 # All previously-removed IDs are being re-added → entire
                 # removal is cancelled. Mark for caller to delete.
@@ -235,6 +246,7 @@ def _coalesce_assets_added(
         payload["added_assets"] = existing_assets
         payload["added_count"] = len(existing_assets)
         existing_added_row.event_payload = payload
+        flag_modified(existing_added_row, "event_payload")
     return ("merge", existing_added_row, existing_removed_row)


@@ -257,6 +269,7 @@ def _coalesce_assets_removed(
             payload["added_assets"] = kept_assets
             payload["added_count"] = len(kept_assets)
             existing_added_row.event_payload = payload
+            flag_modified(existing_added_row, "event_payload")
             if not kept_assets:
                 existing_added_row.status = "cancelled"
             # IDs that were just added during the window don't need to flow
@@ -282,6 +295,7 @@ def _coalesce_assets_removed(
         payload["removed_asset_ids"] = existing_ids
         payload["removed_count"] = len(existing_ids)
         existing_removed_row.event_payload = payload
+        flag_modified(existing_removed_row, "event_payload")
     return ("merge", existing_added_row, existing_removed_row)


@@ -695,7 +709,7 @@ async def _process_row(
                 template_slots=ld.get("template_slots"),
                 date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
                 date_only_format=(tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y"),
-                provider_api_key=provider_config.get("api_key") or provider_config.get("api_token"),
+                provider_api_key=resolve_provider_credential(provider_config),
                 provider_internal_url=provider_config.get("url", ""),
                 provider_external_url=provider_config.get("external_domain", "") or provider_config.get("url", ""),
                 receivers=ld["receivers"],

@@ -210,6 +210,13 @@ def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
         "ha_automation_triggered": getattr(tc, "track_ha_automation_triggered", False),
         "ha_service_called": getattr(tc, "track_ha_service_called", False),
         "ha_event_fired": getattr(tc, "track_ha_event_fired", False),
+        # Bridge self-monitoring — defaults True so a tracker created before the
+        # columns existed still surfaces the alerts. Legacy rows are extremely
+        # unlikely here (the columns ship in the same release as the provider),
+        # but the safer default matches the rest of this map.
+        "bridge_self_poll_failures": getattr(tc, "track_bridge_self_poll_failures", True),
+        "bridge_self_deferred_backlog": getattr(tc, "track_bridge_self_deferred_backlog", True),
+        "bridge_self_target_failures": getattr(tc, "track_bridge_self_target_failures", True),
     }
     return flag_map.get(event_type, True)

@@ -225,11 +232,15 @@ def evaluate_event_gate(
     by quiet hours — the UTC datetime at which the window ends so the caller
     can schedule a deferred dispatch.

-    Order of checks: quiet hours first, then per-event-type flag. Quiet hours
-    is the "louder" gate (it applies to every type), so reporting it first
-    avoids the surprising case of "you disabled this event type" showing up
-    when the user really just opened the quiet window.
+    Order of checks: per-event-type flag FIRST, then quiet hours. Otherwise a
+    disabled event type would get deferred during quiet hours and then
+    silently dropped at drain time — wasted work and a confusing "deferred
+    then dropped" trail in the dashboard. The user already said "don't tell
+    me about this kind of event"; honour that immediately.
     """
+    if not _event_type_enabled(event, tc):
+        return GateOutcome(reason=GateReason.EVENT_TYPE_DISABLED)
+
     if tc.quiet_hours_enabled:
         end_at = quiet_hours_status(
             tc.quiet_hours_start, tc.quiet_hours_end, tz_name,
@@ -240,9 +251,6 @@ def evaluate_event_gate(
                 quiet_hours_end_at=end_at,
             )

-    if not _event_type_enabled(event, tc):
-        return GateOutcome(reason=GateReason.EVENT_TYPE_DISABLED)
-
     return GateOutcome(reason=GateReason.ALLOWED)


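The reordering in this hunk is easiest to see on a truth table. The sketch below is an illustration only, assuming the two gate inputs have already been reduced to booleans; `Reason` and `evaluate` are hypothetical stand-ins for `GateReason` and `evaluate_event_gate`, whose real signatures take an event and a tracking config.

```python
from enum import Enum, auto


class Reason(Enum):
    ALLOWED = auto()
    EVENT_TYPE_DISABLED = auto()
    QUIET_HOURS = auto()


def evaluate(event_type_enabled: bool, in_quiet_hours: bool) -> Reason:
    """Check the per-event-type flag first, then quiet hours."""
    if not event_type_enabled:
        # A disabled type is dropped immediately and never reaches deferral.
        return Reason.EVENT_TYPE_DISABLED
    if in_quiet_hours:
        return Reason.QUIET_HOURS
    return Reason.ALLOWED


if __name__ == "__main__":
    for enabled in (True, False):
        for quiet in (True, False):
            print(enabled, quiet, evaluate(enabled, quiet).name)
```

The only input whose result changes with the new order is a disabled event type arriving during quiet hours: it now yields `EVENT_TYPE_DISABLED` instead of being deferred and dropped later at drain time.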
@@ -397,23 +405,49 @@ def apply_tracking_display_filters(
     )


+def resolve_provider_credential(cfg: dict[str, Any] | None) -> str | None:
+    """Pick the first non-empty provider credential field.
+
+    Provider configs use different field names (Immich → ``api_key``,
+    Gitea → ``api_token``, HA → ``access_token``). All four dispatch
+    sites used to pick one field by hand; centralising here keeps the
+    fallback order consistent so a config edit on one provider type can't
+    silently break dispatch for another.
+    """
+    if not cfg:
+        return None
+    return cfg.get("api_key") or cfg.get("api_token") or cfg.get("access_token")
+
+
 async def _resolve_target(
     session: AsyncSession,
     target: NotificationTarget,
+    *,
+    receivers_by_target: dict[int, list[TargetReceiver]] | None = None,
+    telegram_chats_by_bot: dict[int, dict[str, TelegramChat]] | None = None,
+    email_bots_by_id: dict[int, EmailBot] | None = None,
+    matrix_bots_by_id: dict[int, MatrixBot] | None = None,
 ) -> dict[str, Any]:
     """Resolve a single target into dispatch-ready data (config + receivers + credentials).

     Returns a dict with target_type, target_config, and receivers.
     Does NOT include tracking_config or template_slots — those come from the tracker link.
+
+    Optional ``*_by_*`` maps short-circuit per-target DB queries when the
+    caller has batch-prefetched the data. When omitted, we fall back to the
+    original single-query path so direct callers (manual_dispatch) still work.
     """
-    # Load receivers as typed Receiver objects
-    recv_result = await session.exec(
-        select(TargetReceiver).where(
-            TargetReceiver.target_id == target.id,
-            TargetReceiver.enabled == True,
+    # Receivers — prefer pre-fetched map.
+    if receivers_by_target is not None:
+        recv_rows = [r for r in receivers_by_target.get(target.id, []) if r.enabled]
+    else:
+        recv_result = await session.exec(
+            select(TargetReceiver).where(
+                TargetReceiver.target_id == target.id,
+                TargetReceiver.enabled == True,
+            )
         )
-    )
-    recv_rows = recv_result.all()
+        recv_rows = recv_result.all()

     # For Telegram targets, resolve locale from TelegramChat
     chat_locale_map: dict[str, str] = {}
@@ -422,13 +456,24 @@ async def _resolve_target(
         if bot_id:
             chat_ids = [str(r.config.get("chat_id", "")) for r in recv_rows if r.config.get("chat_id")]
             if chat_ids:
-                chat_result = await session.exec(
-                    select(TelegramChat).where(
-                        TelegramChat.bot_id == bot_id,
-                        TelegramChat.chat_id.in_(chat_ids),
-                    )
+                chats_for_bot = (
+                    telegram_chats_by_bot.get(bot_id, {})
+                    if telegram_chats_by_bot is not None else None
                 )
-                for chat in chat_result.all():
+                if chats_for_bot is not None:
+                    rows = [
+                        chats_for_bot[cid] for cid in chat_ids
+                        if cid in chats_for_bot
+                    ]
+                else:
+                    chat_result = await session.exec(
+                        select(TelegramChat).where(
+                            TelegramChat.bot_id == bot_id,
+                            TelegramChat.chat_id.in_(chat_ids),
+                        )
+                    )
+                    rows = chat_result.all()
+                for chat in rows:
                     resolved = (
                         getattr(chat, 'language_override', '') or
                         getattr(chat, 'language_code', '') or ''
@@ -457,7 +502,10 @@ async def _resolve_target(
     if target.type == "email":
         email_bot_id = target.config.get("email_bot_id")
         if email_bot_id:
-            email_bot = await session.get(EmailBot, email_bot_id)
+            if email_bots_by_id is not None:
+                email_bot = email_bots_by_id.get(email_bot_id)
+            else:
+                email_bot = await session.get(EmailBot, email_bot_id)
             if email_bot:
                 target_config["smtp"] = {
                     "host": email_bot.smtp_host,
@@ -471,12 +519,17 @@ async def _resolve_target(
     elif target.type == "matrix":
         matrix_bot_id = target.config.get("matrix_bot_id")
         if matrix_bot_id:
-            matrix_bot = await session.get(MatrixBot, matrix_bot_id)
+            if matrix_bots_by_id is not None:
+                matrix_bot = matrix_bots_by_id.get(matrix_bot_id)
+            else:
+                matrix_bot = await session.get(MatrixBot, matrix_bot_id)
             if matrix_bot:
                 target_config["homeserver_url"] = matrix_bot.homeserver_url
                 target_config["access_token"] = matrix_bot.access_token

     return {
+        "target_id": target.id,
+        "target_name": target.name,
         "target_type": target.type,
         "target_config": target_config,
         "receivers": receivers,
@@ -567,6 +620,76 @@ async def load_link_data(
     )
     child_target_map = {t.id: t for t in child_rows.all()}

+    # ---- Batch pre-fetch for _resolve_target ----
+    # Build the universe of target IDs (regular + expanded broadcast children)
+    # so a single query per relation type covers every call to _resolve_target
+    # below — no per-target follow-up SELECTs.
+    all_target_ids: set[int] = set(target_map.keys()) | set(child_target_map.keys())
+
+    receivers_by_target: dict[int, list[TargetReceiver]] = {}
+    if all_target_ids:
+        recv_result = await session.exec(
+            select(TargetReceiver).where(TargetReceiver.target_id.in_(all_target_ids))
+        )
+        for r in recv_result.all():
+            receivers_by_target.setdefault(r.target_id, []).append(r)
+
+    # Telegram chats keyed by (bot_id, chat_id) — collect all (bot_id, chat_id)
+    # pairs referenced by enabled telegram-target receivers, then one query.
+    tg_pairs: dict[int, set[str]] = {}  # bot_id -> {chat_id}
+    for tid in all_target_ids:
+        tgt = target_map.get(tid) or child_target_map.get(tid)
+        if not tgt or tgt.type != "telegram":
+            continue
+        bot_id = tgt.config.get("bot_id")
+        if not bot_id:
+            continue
+        for r in receivers_by_target.get(tid, []):
+            cid = str(r.config.get("chat_id", ""))
+            if cid:
+                tg_pairs.setdefault(bot_id, set()).add(cid)
+    telegram_chats_by_bot: dict[int, dict[str, TelegramChat]] = {}
+    for bot_id, chat_ids in tg_pairs.items():
+        if not chat_ids:
+            continue
+        chat_rows = await session.exec(
+            select(TelegramChat).where(
+                TelegramChat.bot_id == bot_id,
+                TelegramChat.chat_id.in_(chat_ids),
+            )
+        )
+        telegram_chats_by_bot[bot_id] = {c.chat_id: c for c in chat_rows.all()}
+
+    # Email + Matrix bots
+    email_bot_ids: set[int] = set()
+    matrix_bot_ids: set[int] = set()
+    for tid in all_target_ids:
+        tgt = target_map.get(tid) or child_target_map.get(tid)
+        if not tgt:
+            continue
+        if tgt.type == "email":
+            bid = tgt.config.get("email_bot_id")
+            if bid:
+                email_bot_ids.add(bid)
+        elif tgt.type == "matrix":
+            bid = tgt.config.get("matrix_bot_id")
+            if bid:
+                matrix_bot_ids.add(bid)
+
+    email_bots_by_id: dict[int, EmailBot] = {}
+    if email_bot_ids:
+        rows = await session.exec(
+            select(EmailBot).where(EmailBot.id.in_(email_bot_ids))
+        )
+        email_bots_by_id = {b.id: b for b in rows.all()}
+
+    matrix_bots_by_id: dict[int, MatrixBot] = {}
+    if matrix_bot_ids:
+        rows = await session.exec(
+            select(MatrixBot).where(MatrixBot.id.in_(matrix_bot_ids))
+        )
+        matrix_bots_by_id = {b.id: b for b in rows.all()}
+
     link_data: list[dict[str, Any]] = []
     for tt in active_links:
         target = target_map.get(tt.target_id)
@@ -589,7 +712,13 @@ async def load_link_data(
                 child_target = child_target_map.get(child_id)
                 if not child_target or child_target.type == "broadcast":
                     continue
-                resolved = await _resolve_target(session, child_target)
+                resolved = await _resolve_target(
+                    session, child_target,
+                    receivers_by_target=receivers_by_target,
+                    telegram_chats_by_bot=telegram_chats_by_bot,
+                    email_bots_by_id=email_bots_by_id,
+                    matrix_bots_by_id=matrix_bots_by_id,
+                )
                 link_data.append({
                     **resolved,
                     "link_id": tt.id,
@@ -600,7 +729,13 @@ async def load_link_data(
             continue

         # Regular target
-        resolved = await _resolve_target(session, target)
+        resolved = await _resolve_target(
+            session, target,
+            receivers_by_target=receivers_by_target,
+            telegram_chats_by_bot=telegram_chats_by_bot,
+            email_bots_by_id=email_bots_by_id,
+            matrix_bots_by_id=matrix_bots_by_id,
+        )
         link_data.append({
             **resolved,
             "link_id": tt.id,

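The batching in `load_link_data` replaces one receiver query per target with a single fetch that is grouped in memory and then filtered per target. A minimal database-free sketch of that grouping step, assuming a hypothetical `Receiver` dataclass in place of the `TargetReceiver` model:

```python
from dataclasses import dataclass


@dataclass
class Receiver:
    """Hypothetical stand-in for a TargetReceiver row."""
    target_id: int
    chat_id: str
    enabled: bool = True


def group_by_target(rows: list[Receiver]) -> dict[int, list[Receiver]]:
    """Group rows fetched in one query by their target_id."""
    grouped: dict[int, list[Receiver]] = {}
    for row in rows:
        grouped.setdefault(row.target_id, []).append(row)
    return grouped


def enabled_receivers(grouped: dict[int, list[Receiver]], target_id: int) -> list[Receiver]:
    """Per-target lookup against the prefetched map; no further queries."""
    return [r for r in grouped.get(target_id, []) if r.enabled]


if __name__ == "__main__":
    rows = [Receiver(1, "a"), Receiver(1, "b", enabled=False), Receiver(2, "c")]
    grouped = group_by_target(rows)
    print([r.chat_id for r in enabled_receivers(grouped, 1)])
```

Filtering on `enabled` at lookup time, as `_resolve_target` does with the prefetched map, keeps the batch query simple while preserving the behaviour of the old per-target query.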
@@ -14,6 +14,7 @@ services -> api cycle).
 from __future__ import annotations

 import logging
+import time
 from typing import Any, Awaitable, Callable

 from sqlmodel import select
@@ -33,6 +34,7 @@ from .dispatch_helpers import (
     evaluate_event_gate,
     get_app_timezone,
     load_link_data,
+    resolve_provider_credential,
 )

 _LOGGER = logging.getLogger(__name__)
@@ -44,6 +46,67 @@ _LOGGER = logging.getLogger(__name__)
 FilterFn = Callable[[ServiceEvent, dict[str, Any]], bool]


+# ---------------------------------------------------------------------------
+# Tracker cache (per-provider, TTL-bounded)
+# ---------------------------------------------------------------------------
+#
+# HA's event bus emits dozens of events per second; the per-event SELECT for
+# enabled trackers becomes the bottleneck on busy installs. This 5-second
+# TTL cache short-circuits the lookup when the same provider is hot. Cache
+# entries are invalidated explicitly by tracker CRUD endpoints; the TTL is
+# the safety net for missed invalidations.
+
+_TRACKER_CACHE_TTL_SECONDS = 5.0
+_trackers_cache: dict[int, tuple[float, list[NotificationTracker]]] = {}
+
+
+def invalidate_tracker_cache(provider_id: int | None = None) -> None:
+    """Drop cached trackers for one provider (or all if ``provider_id`` is None).
+
+    Call from tracker create / update / delete endpoints so the next
+    inbound event sees the change without waiting out the TTL.
+    """
+    if provider_id is None:
+        _trackers_cache.clear()
+    else:
+        _trackers_cache.pop(provider_id, None)
+
+
+async def _load_trackers_cached(
+    session: AsyncSession, provider_id: int,
+) -> list[NotificationTracker]:
+    """Return enabled trackers for ``provider_id``, with a short TTL cache.
+
+    Caches the ``NotificationTracker`` rows themselves — NOT the per-tracker
+    ``NotificationTrackerTarget`` link rows. ``load_link_data`` always re-reads
+    links from the DB on every dispatch, so adding/removing/toggling a link
+    does NOT require invalidating this cache. Only call ``invalidate_tracker_cache``
+    when a tracker row is created/updated/deleted.
+    """
+    now = time.monotonic()
+    cached = _trackers_cache.get(provider_id)
+    if cached is not None:
+        ts, trackers = cached
+        if now - ts < _TRACKER_CACHE_TTL_SECONDS:
+            return trackers
+    result = await session.exec(
+        select(NotificationTracker).where(
+            NotificationTracker.provider_id == provider_id,
+            NotificationTracker.enabled == True,  # noqa: E712
+        )
+    )
+    trackers = list(result.all())
+    # Detach cached instances so consumers don't accidentally use a stale
+    # session — re-fetch by id when mutating.
+    for t in trackers:
+        try:
+            session.expunge(t)
+        except Exception:  # noqa: BLE001
+            pass
+    _trackers_cache[provider_id] = (now, trackers)
+    return trackers
+
+
 async def dispatch_provider_event(
     engine: Any,
     provider_id: int,
@@ -82,18 +145,25 @@ async def dispatch_provider_event(
     # Drain-scheduling is best-effort: a scheduling failure must not roll
     # back the persisted defer rows (startup catch-up re-establishes them).
     defers_to_schedule: set[Any] = set()
+    # Build the dispatcher once per inbound event — its only state is the
+    # shared aiohttp session and Telegram caches, both of which are reused
+    # across all trackers. Re-creating per-tracker meant a fresh dispatcher
+    # for every notification, paying the construction cost on every HA
+    # state_changed (ha provider can fire dozens per second).
+    from .http_session import get_http_session
+    from .watcher import _get_telegram_caches
+    url_cache, asset_cache = await _get_telegram_caches()
+    dispatcher = NotificationDispatcher(
+        url_cache=url_cache,
+        asset_cache=asset_cache,
+        session=await get_http_session(),
+    )
     async with AsyncSession(engine) as session:
         # App timezone is identical across trackers in one inbound event;
         # pull it once.
         app_tz = await get_app_timezone(session)

-        tracker_result = await session.exec(
-            select(NotificationTracker).where(
-                NotificationTracker.provider_id == provider_id,
-                NotificationTracker.enabled == True,  # noqa: E712
-            )
-        )
-        trackers = tracker_result.all()
+        trackers = await _load_trackers_cached(session, provider_id)

         for tracker in trackers:
             filters = tracker.filters or {}
@@ -147,6 +217,17 @@ async def dispatch_provider_event(
                             prior = defers_for_event.get(link_id)
                             if prior is None or outcome.quiet_hours_end_at < prior:
                                 defers_for_event[link_id] = outcome.quiet_hours_end_at
+                        else:
+                            # Non-deferrable event hit quiet hours — stamp
+                            # the event_log so the dashboard surfaces *why*
+                            # the notification never went out.
+                            details = dict(event_log_row.details or {})
+                            if not details.get("dispatch_status"):
+                                details["dispatch_status"] = (
+                                    "dropped_quiet_hours_nondeferrable"
+                                )
+                                event_log_row.details = details
+                                session.add(event_log_row)
                         continue
                     if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
                         continue
@@ -162,7 +243,7 @@ async def dispatch_provider_event(
                             if tmpl and tmpl.date_only_format
                             else "%d.%m.%Y"
                         ),
-                        provider_api_key=provider_config.get("api_token"),
+                        provider_api_key=resolve_provider_credential(provider_config),
                         provider_internal_url=provider_config.get("url", ""),
                         provider_external_url=provider_config.get("url", ""),
                         receivers=ld["receivers"],
@@ -170,7 +251,10 @@ async def dispatch_provider_event(
|
||||
key = id(tc) if tc is not None else 0
|
||||
if key not in groups:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
# Thread per-target metadata alongside the TargetConfig so the
|
||||
# bridge_self failure counters can attribute results to a
|
||||
# specific target_id after dispatch.
|
||||
groups[key][1].append((target_cfg, ld.get("target_id"), ld.get("target_name", "")))
|
||||
|
||||
# Persist defers + stamp event_log dispatch_status in the same
|
||||
# session that holds the EventLog row, so the "deferred" badge
|
||||
@@ -198,14 +282,25 @@ async def dispatch_provider_event(
|
||||
# Dispatch to targets. Isolate dispatcher exceptions per group so
|
||||
# a failed remote call doesn't bubble out, abort the surrounding
|
||||
# transaction, and roll back the just-written defers / event_log.
|
||||
from .http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
for tc, target_configs in groups.values():
|
||||
if not target_configs:
|
||||
from .bridge_self import (
|
||||
maybe_emit_target_failure,
|
||||
record_target_success,
|
||||
)
|
||||
|
||||
# Skip target-failure tracking when we're already dispatching a
|
||||
# bridge_self event — otherwise a failing alert target would
|
||||
# endlessly re-emit alerts about itself.
|
||||
track_target_failures = (
|
||||
event.provider_type.value != "bridge_self"
|
||||
)
|
||||
|
||||
for tc, target_entries in groups.values():
|
||||
if not target_entries:
|
||||
continue
|
||||
shaped_event = apply_tracking_display_filters(event, tc)
|
||||
if shaped_event is None:
|
||||
continue
|
||||
target_configs = [entry[0] for entry in target_entries]
|
||||
try:
|
||||
results = await dispatcher.dispatch(shaped_event, target_configs)
|
||||
except Exception as err: # noqa: BLE001
|
||||
@@ -213,14 +308,29 @@ async def dispatch_provider_event(
|
||||
"Dispatcher raised for tracker %d: %s", tracker.id, err,
|
||||
)
|
||||
continue
|
||||
for r in results:
|
||||
for entry, r in zip(target_entries, results):
|
||||
_, target_id, target_name = entry
|
||||
if r.get("success"):
|
||||
dispatched += 1
|
||||
if track_target_failures and target_id is not None:
|
||||
record_target_success(int(target_id))
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Notification failed for tracker %d: %s",
|
||||
tracker.id, r.get("error", "unknown"),
|
||||
)
|
||||
if track_target_failures and target_id is not None:
|
||||
try:
|
||||
await maybe_emit_target_failure(
|
||||
target_id=int(target_id),
|
||||
target_name=target_name or "",
|
||||
target_type=entry[0].type,
|
||||
error=str(r.get("error") or ""),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"bridge_self target-failure emission failed",
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from notify_bridge_core.providers.home_assistant import (
|
||||
)
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import ServiceProvider
|
||||
from ..database.models import EventLog, ServiceProvider
|
||||
from .event_dispatch import dispatch_provider_event
|
||||
from .http_session import get_http_session
|
||||
|
||||
@@ -104,6 +104,46 @@ def _ha_passes_filters(event: ServiceEvent, filters: dict[str, Any]) -> bool:
     return False
 
 
+async def _record_ha_status(
+    *,
+    provider_id: int,
+    provider_name: str,
+    state: str,
+    detail: str | None,
+) -> None:
+    """Persist an HA connection-status transition as an EventLog row.
+
+    Used by the supervisor's ``on_status_change`` callback so the
+    dashboard surfaces "HA disconnected" / "HA reconnected" events
+    alongside normal HA state changes. Best-effort: any DB failure is
+    logged and swallowed so the WS reader path remains untouched.
+    """
+    engine = get_engine()
+    try:
+        async with AsyncSession(engine) as session:
+            session.add(EventLog(
+                user_id=None,  # provider-level event, no per-tracker owner
+                tracker_id=None,
+                tracker_name="",
+                provider_id=provider_id,
+                provider_name=provider_name,
+                event_type=f"ha_status_{state}",
+                collection_id="",
+                collection_name="",
+                assets_count=0,
+                details={
+                    "provider_type": "home_assistant",
+                    "ha_status": state,
+                    "ha_status_detail": detail or "",
+                },
+            ))
+            await session.commit()
+    except Exception:  # noqa: BLE001
+        _LOGGER.exception(
+            "Failed to persist HA status row for provider %s", provider_id,
+        )
+
+
 async def _run_provider(provider_id: int) -> None:
     """One per-provider supervisor loop.
 
@@ -158,29 +198,38 @@ async def _run_provider(provider_id: int) -> None:
 
         async def _emit(event: ServiceEvent) -> None:
             # Shield the DB-writing dispatch from external cancellation
-            # (shutdown, supervisor restart). The shield ensures that
-            # once a transaction is mid-flight, it commits or rolls back
-            # cleanly instead of being torn down with the asyncio task
-            # at a write boundary. Worst case: shutdown waits up to one
-            # dispatch latency longer.
+            # (shutdown, supervisor restart). Without an inner Task,
+            # ``asyncio.shield(coro)`` cancels the underlying coroutine
+            # when the outer awaiter is cancelled — defeating the point
+            # of the shield. We wrap explicitly and *drain* the inner
+            # task on cancellation so the transaction completes.
             #
             # Perf note (Phase 2 follow-up): dispatch_provider_event
             # opens a fresh AsyncSession per call. For HA's chatty
             # state_changed bus this hammers the pool; batch in a
             # follow-up.
+            inner = asyncio.create_task(dispatch_provider_event(
+                engine=engine,
+                provider_id=provider_id,
+                provider_name=provider_name,
+                provider_config=config,
+                event=event,
+                detail_keys=_HA_DETAIL_KEYS,
+                filter_fn=_ha_passes_filters,
+            ))
             try:
-                await asyncio.shield(dispatch_provider_event(
-                    engine=engine,
-                    provider_id=provider_id,
-                    provider_name=provider_name,
-                    provider_config=config,
-                    event=event,
-                    detail_keys=_HA_DETAIL_KEYS,
-                    filter_fn=_ha_passes_filters,
-                ))
+                await asyncio.shield(inner)
             except asyncio.CancelledError:
-                # Shield re-raises CancelledError to the caller; let it
-                # propagate so the drain task exits cleanly.
+                # Drain the in-flight write before re-raising so the DB
+                # row commits cleanly (or fails cleanly) instead of
+                # being torn down at an arbitrary await point.
+                try:
+                    await inner
+                except Exception:  # noqa: BLE001
+                    _LOGGER.exception(
+                        "HA dispatch raised while draining shielded "
+                        "task for provider %s", provider_id,
+                    )
                 raise
             except Exception:  # noqa: BLE001
                 _LOGGER.exception(
@@ -188,11 +237,33 @@ async def _run_provider(provider_id: int) -> None:
                     provider_id,
                 )
 
+        def _on_status_change(state: str, detail: str | None) -> None:
+            """Persist HA WS connect/disconnect transitions as event_log rows.
+
+            The client invokes this synchronously from inside the WS
+            run loop, so we can't ``await`` here. Schedule a fire-and-
+            forget task on the same loop instead — log failures, never
+            propagate them back into the WS reader.
+            """
+            try:
+                asyncio.create_task(_record_ha_status(
+                    provider_id=provider_id,
+                    provider_name=provider_name,
+                    state=state,
+                    detail=detail,
+                ))
+            except RuntimeError:
+                # No running loop (shouldn't happen in normal operation).
+                _LOGGER.debug(
+                    "Skipped HA status row for provider %s: no event loop",
+                    provider_id,
+                )
+
         _LOGGER.info(
             "Starting HA subscription for provider %s (%s)",
             provider_id, provider_name,
         )
-        await ha_provider.subscribe(_emit)
+        await ha_provider.subscribe(_emit, on_status_change=_on_status_change)
     except asyncio.CancelledError:
         raise
     except HomeAssistantAuthError as err:
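The shield-and-drain change to `_emit` can be tried in isolation. The following is a minimal sketch, not the bridge's code: `run_shielded`, `slow_write`, and `demo` are hypothetical names, and a short sleep stands in for the database transaction. It keeps a reference to the inner task so the cancelled caller can wait for it before re-raising.

```python
import asyncio


async def run_shielded(coro):
    """Run ``coro`` as its own task and drain it if the caller is cancelled."""
    inner = asyncio.ensure_future(coro)
    try:
        return await asyncio.shield(inner)
    except asyncio.CancelledError:
        # The caller was cancelled, but ``inner`` is still running.
        # Wait for it so the write finishes, then propagate the cancel.
        try:
            await inner
        except Exception:
            pass  # a real caller would log the failure here
        raise


async def demo():
    committed = []

    async def slow_write():
        await asyncio.sleep(0.05)  # stands in for a DB transaction
        committed.append("row")

    task = asyncio.create_task(run_shielded(slow_write()))
    await asyncio.sleep(0.01)  # let the write start
    task.cancel()  # simulate shutdown / supervisor restart
    try:
        await task
    except asyncio.CancelledError:
        pass
    return committed
```

Calling `asyncio.run(demo())` shows the write completing even though the outer task was cancelled. One limitation of this sketch: a second cancellation that arrives while the caller is draining is delivered to the inner task directly.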
@@ -6,6 +6,21 @@ per-request ``aiohttp.ClientSession`` instances. This keeps a single
 TCP connection pool alive for the lifetime of the process, avoiding
 the overhead of pool creation/teardown on every request.
 
+DNS-rebinding mitigation
+~~~~~~~~~~~~~~~~~~~~~~~~
+The session is wired with a :class:`PinnedResolver` from
+``notify_bridge_core.notifications.ssrf`` so the IP that passed the
+SSRF block-range check during URL validation is the one aiohttp
+actually connects to. Without this pinning a malicious DNS server
+could swap a public IP for ``127.0.0.1`` between validation and
+connect, defeating the SSRF guard.
+
+Callers that do their own ``avalidate_outbound_url_full`` should also
+call :func:`pin_validated` to register the resolved host->IP mapping
+on the shared resolver before issuing the request. Callers that just
+use the session opportunistically still benefit from scheme + range
+checks at validation sites, plus the fallback resolver here.
+
 Call ``close_http_session()`` once during application shutdown.
 """
 
@@ -15,32 +30,71 @@ import asyncio
 
 import aiohttp
 
+from notify_bridge_core.notifications.ssrf import PinnedResolver, ValidatedURL
+
 _DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30, connect=10)
 
 _session: aiohttp.ClientSession | None = None
-_lock = asyncio.Lock()
+_resolver: PinnedResolver | None = None
+# Lazy init: ``asyncio.Lock()`` at module import time binds to whichever
+# event loop happens to be running (or none). Tests that spin up multiple
+# loops (or subprocesses with their own loop) would otherwise hit
+# "RuntimeError: ... attached to a different loop". Defer creation to
+# first use so the lock binds to the loop that actually calls us.
+_lock: asyncio.Lock | None = None
+
+
+def _get_lock() -> asyncio.Lock:
+    """Return the module lock, creating it on first call from this loop."""
+    global _lock
+    if _lock is None:
+        _lock = asyncio.Lock()
+    return _lock
 
 
 async def get_http_session() -> aiohttp.ClientSession:
     """Get or create the shared HTTP session.
 
-    Concurrent first-callers are serialized through ``_lock`` so we never
-    leak a second ClientSession / connector pair. Once established, hot
-    callers skip the lock via the fast-path check.
+    Concurrent first-callers are serialized through the lazy lock so we
+    never leak a second ClientSession / connector pair. Once established,
+    hot callers skip the lock via the fast-path check.
+
+    The session uses a :class:`PinnedResolver` connector so callers that
+    register validated host->IP mappings via :func:`pin_validated` defeat
+    DNS rebinding between validation and connect.
     """
-    global _session
+    global _session, _resolver
     if _session is not None and not _session.closed:
         return _session
-    async with _lock:
+    async with _get_lock():
         if _session is None or _session.closed:
-            _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
+            _resolver = PinnedResolver()
+            connector = aiohttp.TCPConnector(resolver=_resolver)
+            _session = aiohttp.ClientSession(
+                timeout=_DEFAULT_TIMEOUT, connector=connector,
+            )
     return _session
 
 
+def pin_validated(validated: ValidatedURL) -> None:
+    """Register a validated (host, ip) mapping on the shared resolver.
+
+    Best-effort: if the resolver has not been created yet (no session
+    initialised), the call is a no-op. Once the session exists, every
+    aiohttp connect for ``validated.host`` will use ``validated.ip``.
+    """
+    if _resolver is None:
+        return
+    _resolver.pin(validated.host, validated.ip)
+
+
 async def close_http_session() -> None:
     """Close the shared HTTP session (call on app shutdown)."""
-    global _session
-    async with _lock:
+    global _session, _resolver
+    async with _get_lock():
         if _session is not None and not _session.closed:
             await _session.close()
+        if _resolver is not None:
+            await _resolver.close()
         _session = None
+        _resolver = None
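The lazy lock plus fast-path check used by `get_http_session` is a general pattern for one-time async initialisation. Below is a small sketch of it with a plain `object()` standing in for the `aiohttp.ClientSession`; the names `get_resource`, `_created`, and `demo` are illustrative only and do not exist in the bridge.

```python
import asyncio

_lock = None      # created on first use, not at import time
_resource = None  # stands in for the shared ClientSession
_created = 0      # counts how many times the resource was built


def _get_lock():
    """Return the module lock, creating it on first call."""
    global _lock
    if _lock is None:
        _lock = asyncio.Lock()
    return _lock


async def get_resource():
    """Return the shared resource, building it at most once."""
    global _resource, _created
    if _resource is not None:  # fast path: no lock once initialised
        return _resource
    async with _get_lock():
        if _resource is None:  # re-check after acquiring the lock
            await asyncio.sleep(0)  # simulate async setup work
            _created += 1
            _resource = object()
    return _resource


async def demo():
    # Ten concurrent first-callers should still build one resource.
    results = await asyncio.gather(*(get_resource() for _ in range(10)))
    return _created, len({id(r) for r in results})
```

Running `asyncio.run(demo())` reports how many resources were built and how many distinct objects the ten callers received. The second `is None` check inside the lock is what stops callers that queued behind the first one from building a duplicate.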
@@ -24,7 +24,7 @@ from ..database.models import (
|
||||
TemplateSlot,
|
||||
TrackingConfig,
|
||||
)
|
||||
from .dispatch_helpers import _resolve_target
|
||||
from .dispatch_helpers import _resolve_target, resolve_provider_credential
|
||||
from .watcher import _get_telegram_caches
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -94,7 +94,7 @@ async def dispatch_test_notification(
         locale=locale,
         date_format=template_config.date_format if template_config else "%d.%m.%Y, %H:%M UTC",
         date_only_format=template_config.date_only_format if template_config and template_config.date_only_format else "%d.%m.%Y",
-        provider_api_key=provider_config.get("api_key"),
+        provider_api_key=resolve_provider_credential(provider_config),
         provider_internal_url=provider_config.get("url", ""),
         provider_external_url=provider_config.get("external_domain", ""),
         receivers=resolved["receivers"],
@@ -442,8 +442,20 @@ async def _send_telegram_test_per_receiver(
             disable_web_page_preview=bool(disable_preview),
         )
 
-    raw = await asyncio.gather(*(_send_one(r) for r in recv_rows))
-    results = [r for r in raw if r is not None]
+    # ``return_exceptions=True`` so a single send raising (e.g. transient
+    # network error to one chat) doesn't abort the entire fan-out and lose
+    # the successful sibling sends from the aggregate count.
+    raw = await asyncio.gather(
+        *(_send_one(r) for r in recv_rows), return_exceptions=True,
+    )
+    results: list[dict] = []
+    for r in raw:
+        if isinstance(r, BaseException):
+            _LOGGER.warning("Test send to receiver raised: %s", r)
+            continue
+        if r is None:
+            continue
+        results.append(r)
     return _aggregate(results)
 
 
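The effect of `return_exceptions=True` on a fan-out is easy to see in a standalone sketch. The `send` coroutine and the chat ids below are made up for illustration; only `asyncio.gather` and its `return_exceptions` flag are taken from the change above.

```python
import asyncio


async def send(chat_id):
    """Hypothetical send: fails for one chat, succeeds for the rest."""
    if chat_id == "bad":
        raise ConnectionError("chat unreachable")
    return {"chat_id": chat_id, "success": True}


async def fan_out(chat_ids):
    # Exceptions come back as values in ``raw`` instead of propagating,
    # so the successful sibling results are still available.
    raw = await asyncio.gather(
        *(send(c) for c in chat_ids), return_exceptions=True,
    )
    results, errors = [], []
    for r in raw:
        if isinstance(r, BaseException):
            errors.append(r)
        else:
            results.append(r)
    return results, errors
```

With `fan_out(["a", "bad", "b"])` the two good sends are kept and the one failure is collected separately. Without the flag, the first exception would propagate out of `gather` and the results of the other sends would never reach the aggregate.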
@@ -234,4 +234,12 @@ _SAMPLE_CONTEXT = {
|
||||
"target_entity": "light.kitchen",
|
||||
"ha_event_type": "state_changed",
|
||||
"event_data": {"foo": "bar"},
|
||||
# Bridge self-monitoring variables (for bridge_self provider templates)
|
||||
"failure_type": "poll_failures",
|
||||
"subject_id": 42,
|
||||
"subject_name": "My Immich Tracker",
|
||||
"count": 3,
|
||||
"threshold": 3,
|
||||
"last_error": "Connection refused",
|
||||
"details": {"provider_id": 7, "provider_type": "immich"},
|
||||
}
|
||||
|
||||
@@ -49,6 +49,7 @@ from .dispatch_helpers import (
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
resolve_provider_credential,
|
||||
)
|
||||
from .manual_dispatch import build_immich_dispatch_events
|
||||
|
||||
@@ -352,7 +353,7 @@ async def dispatch_scheduled_for_tracker(
                             date_only_format=(
                                 tmpl.date_only_format or "%d.%m.%Y"
                             ),
-                            provider_api_key=provider_config.get("api_key"),
+                            provider_api_key=resolve_provider_credential(provider_config),
                             provider_internal_url=provider_config.get("url", ""),
                             provider_external_url=provider_config.get("external_domain", ""),
                             receivers=ld["receivers"],
@@ -163,6 +163,9 @@ async def start_scheduler() -> None:
|
||||
# Schedule the upstream release-check probe.
|
||||
await _schedule_release_check()
|
||||
|
||||
# Schedule the bridge_self deferred-backlog scan (every 5 min).
|
||||
_schedule_bridge_self_backlog_scan()
|
||||
|
||||
|
||||
def _schedule_event_cleanup() -> None:
|
||||
"""Schedule a daily job to delete EventLog entries older than 90 days."""
|
||||
@@ -1122,7 +1125,11 @@ _DRAIN_CATCHUP_INTERVAL_SECONDS = 300
 
 
 def _drain_job_id_for(fire_at_utc: datetime) -> str:
-    return f"{_DEFERRED_DRAIN_PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M')}"
+    # Include seconds — two trackers with quiet windows that end at the same
+    # minute but different seconds (e.g. user-set 06:00:00 vs 06:00:30) would
+    # otherwise collide on a single APScheduler job id, and ``replace_existing``
+    # would silently drop the second one.
+    return f"{_DEFERRED_DRAIN_PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M%S')}"
 
 
 def schedule_deferred_drain(fire_at_utc: datetime) -> None:
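The job-id collision that the `_drain_job_id_for` change addresses can be reproduced with `strftime` alone. In this sketch the prefix value and both helper names are assumptions for illustration; the real `_DEFERRED_DRAIN_PREFIX` is not shown in this diff.

```python
from datetime import datetime, timezone

PREFIX = "deferred_drain_"  # assumed value, for illustration only


def job_id_minutes(fire_at_utc: datetime) -> str:
    """Old scheme: minute resolution."""
    return f"{PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M')}"


def job_id_seconds(fire_at_utc: datetime) -> str:
    """New scheme: second resolution."""
    return f"{PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M%S')}"


first = datetime(2024, 1, 1, 6, 0, 0, tzinfo=timezone.utc)
second = first.replace(second=30)

same_minute_id = job_id_minutes(first) == job_id_minutes(second)
distinct_second_id = job_id_seconds(first) != job_id_seconds(second)
```

Two fire times that differ only in seconds map to one id under the minute format and to two ids under the second format, so each gets its own scheduler job.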
@@ -1298,6 +1305,50 @@ async def _schedule_release_check() -> None:
|
||||
interval_hours, _RELEASE_CHECK_ONESHOT_DELAY_SECONDS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bridge self-monitoring — deferred-backlog scan
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BRIDGE_SELF_BACKLOG_JOB_ID = "bridge_self_deferred_backlog_scan"
|
||||
# 5 min trade-off between "operator notices the backlog quickly" and "extra
|
||||
# DB churn on a quiet system". The scan is one indexed GROUP BY query.
|
||||
_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _schedule_bridge_self_backlog_scan() -> None:
|
||||
"""Install the periodic deferred-backlog scan for bridge_self."""
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.get_job(_BRIDGE_SELF_BACKLOG_JOB_ID):
|
||||
return
|
||||
scheduler.add_job(
|
||||
_run_bridge_self_backlog_scan,
|
||||
IntervalTrigger(seconds=_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS),
|
||||
id=_BRIDGE_SELF_BACKLOG_JOB_ID,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Scheduled bridge_self deferred-backlog scan every %ds",
|
||||
_BRIDGE_SELF_BACKLOG_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def _run_bridge_self_backlog_scan() -> None:
|
||||
"""APScheduler entry point — scan deferred backlog and emit if needed."""
|
||||
from .bridge_self import check_deferred_backlog
|
||||
try:
|
||||
stats = await check_deferred_backlog()
|
||||
if stats.get("crossings"):
|
||||
_LOGGER.info("bridge_self backlog scan stats: %s", stats)
|
||||
else:
|
||||
_LOGGER.debug("bridge_self backlog scan stats: %s", stats)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception("bridge_self backlog scan failed: %s", err)
|
||||
|
||||
|
||||
async def reschedule_release_check() -> None:
|
||||
"""Re-arm the release-check job after settings changed.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
 """Telegram service utilities — chat persistence helpers."""
 
-from sqlmodel import select
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlmodel.ext.asyncio.session import AsyncSession
 
 from ..database.models import TelegramChat
@@ -12,36 +12,48 @@ async def save_chat_from_webhook(
 ) -> None:
     """Save or update a chat entry from an incoming webhook message.
 
-    Called by the webhook handler to auto-persist chats.
+    Called by the webhook handler to auto-persist chats. Uses a single
+    ``INSERT ... ON CONFLICT DO UPDATE`` keyed on the
+    ``uq_telegram_chat_bot_chat`` unique constraint so two concurrent
+    webhook deliveries cannot race a check-then-insert and produce
+    duplicate rows. Only mutable display/identity fields are updated on
+    conflict — ``commands_enabled``, ``language_override``, and
+    ``discovered_at`` belong to the user / first discovery and stay sticky.
     """
     chat_id = str(chat_data.get("id", ""))
     if not chat_id:
         return
 
-    result = await session.exec(
-        select(TelegramChat).where(
-            TelegramChat.bot_id == bot_id,
-            TelegramChat.chat_id == chat_id,
-        )
-    )
-    existing = result.first()
-
     title = chat_data.get("title") or (
         chat_data.get("first_name", "") + (" " + chat_data.get("last_name", "")).strip()
     )
+    chat_type = chat_data.get("type", "private")
+    username = chat_data.get("username", "")
 
-    if existing:
-        existing.title = title
-        existing.username = chat_data.get("username", existing.username)
-        if language_code:
-            existing.language_code = language_code
-        session.add(existing)
-    else:
-        session.add(TelegramChat(
-            bot_id=bot_id,
-            chat_id=chat_id,
-            title=title,
-            chat_type=chat_data.get("type", "private"),
-            username=chat_data.get("username", ""),
-            language_code=language_code,
-        ))
+    # Only the SQLite dialect path is wired up — the deployed default. A
+    # future Postgres backend would need pg_insert here; the unique
+    # constraint name is dialect-portable so the same conflict_target works.
+    stmt = sqlite_insert(TelegramChat).values(
+        bot_id=bot_id,
+        chat_id=chat_id,
+        title=title,
+        chat_type=chat_type,
+        username=username,
+        language_code=language_code,
+    )
+    update_cols: dict = {
+        "title": title,
+        "chat_type": chat_type,
+        "username": username,
+    }
+    # Only overwrite language_code when the inbound payload carries one,
+    # otherwise we'd clobber a previously-detected locale with empty.
+    if language_code:
+        update_cols["language_code"] = language_code
+    stmt = stmt.on_conflict_do_update(
+        index_elements=["bot_id", "chat_id"],
+        set_=update_cols,
+    )
+    # session.execute (not exec) — exec is the SQLModel/Select wrapper that
+    # rejects raw Core Insert statements.
+    await session.execute(stmt)
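The upsert semantics of `save_chat_from_webhook` can be illustrated with the standard-library `sqlite3` module. This is a sketch, not the bridge's SQLAlchemy code: it uses a simplified, assumed `telegram_chat` schema and needs SQLite 3.24 or newer for `ON CONFLICT ... DO UPDATE`.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE telegram_chat (
        bot_id INTEGER NOT NULL,
        chat_id TEXT NOT NULL,
        title TEXT,
        language_code TEXT,
        UNIQUE (bot_id, chat_id)
    )
    """
)

# One statement inserts or updates, so two deliveries for the same chat
# cannot both pass a "does it exist?" check and insert twice.
UPSERT = """
INSERT INTO telegram_chat (bot_id, chat_id, title, language_code)
VALUES (?, ?, ?, ?)
ON CONFLICT (bot_id, chat_id) DO UPDATE SET
    title = excluded.title,
    language_code = COALESCE(NULLIF(excluded.language_code, ''),
                             telegram_chat.language_code)
"""


def save_chat(bot_id, chat_id, title, language_code=None):
    """Insert the chat, or refresh its title if the row already exists."""
    conn.execute(UPSERT, (bot_id, chat_id, title, language_code))
    conn.commit()


save_chat(1, "42", "Old title", "en")
save_chat(1, "42", "New title")  # no language in this payload
rows = conn.execute(
    "SELECT bot_id, chat_id, title, language_code FROM telegram_chat"
).fetchall()
```

After both calls the table holds a single row with the new title and the previously detected language, which mirrors the "only overwrite `language_code` when the payload carries one" rule in the diff.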
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -12,6 +12,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher, TargetConfig
|
||||
from notify_bridge_core.notifications.telegram.cache import TelegramFileCache
|
||||
from notify_bridge_core.providers.capabilities import get_capabilities
|
||||
from notify_bridge_core.storage import JsonFileBackend
|
||||
|
||||
from ..database.engine import get_engine
|
||||
@@ -27,6 +28,7 @@ from .dispatch_helpers import (
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
resolve_provider_credential,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -34,7 +36,18 @@ _LOGGER = logging.getLogger(__name__)
|
||||
# Module-level Telegram file caches — shared across dispatches for reuse
|
||||
_url_cache: TelegramFileCache | None = None
|
||||
_asset_cache: TelegramFileCache | None = None
|
||||
_cache_lock = asyncio.Lock()
|
||||
# Lazy init: creating ``asyncio.Lock()`` at module import time binds the
|
||||
# lock to whichever event loop is current at import (often none / the wrong
|
||||
# one when tests fire up dedicated loops). Defer until first use.
|
||||
_cache_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_cache_lock() -> asyncio.Lock:
|
||||
"""Return the module cache lock, creating it on first call."""
|
||||
global _cache_lock
|
||||
if _cache_lock is None:
|
||||
_cache_lock = asyncio.Lock()
|
||||
return _cache_lock
|
||||
|
||||
|
||||
async def _load_cache_settings() -> tuple[int, int]:
|
||||
@@ -68,7 +81,7 @@ async def _get_telegram_caches() -> tuple[TelegramFileCache | None, TelegramFile
|
||||
global _url_cache, _asset_cache
|
||||
if _url_cache is not None:
|
||||
return _url_cache, _asset_cache
|
||||
async with _cache_lock:
|
||||
async with _get_cache_lock():
|
||||
# Double-check after acquiring lock
|
||||
if _url_cache is not None:
|
||||
return _url_cache, _asset_cache
|
||||
@@ -108,7 +121,7 @@ async def reset_telegram_caches_in_memory() -> None:
|
||||
deletes cached file_ids.
|
||||
"""
|
||||
global _url_cache, _asset_cache
|
||||
async with _cache_lock:
|
||||
async with _get_cache_lock():
|
||||
_url_cache = None
|
||||
_asset_cache = None
|
||||
_LOGGER.info("Reset Telegram cache refs in memory (files preserved)")
|
||||
@@ -135,7 +148,7 @@ async def clear_telegram_caches() -> dict[str, Any]:
|
||||
Returns a summary with the paths that were removed.
|
||||
"""
|
||||
global _url_cache, _asset_cache
|
||||
async with _cache_lock:
|
||||
async with _get_cache_lock():
|
||||
removed: list[str] = []
|
||||
for cache, label in ((_url_cache, "url"), (_asset_cache, "asset")):
|
||||
if cache is not None:
|
||||
@@ -163,6 +176,90 @@ async def clear_telegram_caches() -> dict[str, Any]:
         return {"cleared": True, "removed": removed}
 
 
+# ---------------------------------------------------------------------------
+# Provider polling registry
+# ---------------------------------------------------------------------------
+#
+# Each registered factory returns (events, new_state). Replaces the long
+# ``if provider_type == ...`` chain in ``check_tracker``. New pollable
+# providers register here; webhook-only providers are short-circuited above
+# via ``capabilities.webhook_based``.
+
+class _PollerConnectError(Exception):
+    """Raised by a poller factory when initial provider connection fails."""
+
+    def __init__(self, reason: str) -> None:
+        super().__init__(reason)
+        self.reason = reason
+
+
+PollResult = tuple[list[ServiceEvent], dict[str, Any]]
+PollerFactory = Callable[..., Awaitable[PollResult]]
+
+
+async def _poll_immich(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
+    from notify_bridge_core.providers.immich import ImmichServiceProvider
+    from .http_session import get_http_session
+    http_session = await get_http_session()
+    immich = ImmichServiceProvider(
+        http_session,
+        provider_config.get("url", ""),
+        provider_config.get("api_key", ""),
+        provider_config.get("external_domain"),
+        provider_name,
+    )
+    if not await immich.connect():
+        raise _PollerConnectError("failed to connect to provider")
+    return await immich.poll(collection_ids, state_dict)
+
+
+async def _poll_scheduler(*, provider_name, tracker_name, tracker_filters, collection_ids, state_dict, app_tz, **_kw) -> PollResult:
+    from notify_bridge_core.providers.scheduler import SchedulerServiceProvider
+    sched = SchedulerServiceProvider(
+        name=provider_name,
+        tracker_name=tracker_name,
+        custom_variables=tracker_filters.get("custom_variables", {}),
+        timezone_name=app_tz,
+    )
+    return await sched.poll(collection_ids, state_dict)
+
+
+async def _poll_nut(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
+    from notify_bridge_core.providers.nut import NutServiceProvider
+    nut = NutServiceProvider(
+        host=provider_config.get("host", "localhost"),
+        port=provider_config.get("port", 3493),
+        username=provider_config.get("username"),
+        password=provider_config.get("password"),
+        name=provider_name,
+    )
+    return await nut.poll(collection_ids, state_dict)
+
+
+async def _poll_google_photos(*, provider_config, provider_name, collection_ids, state_dict, **_kw) -> PollResult:
+    from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
+    from .http_session import get_http_session
+    http_session = await get_http_session()
+    gp = GooglePhotosServiceProvider(
+        http_session,
+        provider_config.get("client_id", ""),
+        provider_config.get("client_secret", ""),
+        provider_config.get("refresh_token", ""),
+        provider_name,
+    )
+    if not await gp.connect():
+        raise _PollerConnectError("failed to connect to Google Photos")
+    return await gp.poll(collection_ids, state_dict)
+
+
+_POLL_FACTORIES: dict[str, PollerFactory] = {
+    "immich": _poll_immich,
+    "scheduler": _poll_scheduler,
+    "nut": _poll_nut,
+    "google_photos": _poll_google_photos,
+}
+
+
 async def check_tracker(tracker_id: int) -> dict[str, Any]:
     """Poll a tracker's provider for changes and dispatch notifications."""
     engine = get_engine()
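The registry lookup that replaces the `if`/`elif` ladder comes down to a dict of async callables that each accept keyword arguments and ignore the ones they do not need. A minimal sketch follows, with a made-up `counter` poller and string events in place of `ServiceEvent`; none of these names exist in the bridge.

```python
import asyncio
from typing import Any, Awaitable, Callable

PollResult = tuple[list[str], dict[str, Any]]
PollerFactory = Callable[..., Awaitable[PollResult]]


async def _poll_counter(*, state_dict: dict[str, Any], **_kw: Any) -> PollResult:
    # Hypothetical poller: emits one event per call and bumps a counter.
    # ``**_kw`` swallows arguments meant for other pollers.
    n = state_dict.get("n", 0) + 1
    return [f"tick-{n}"], {"n": n}


_FACTORIES: dict[str, PollerFactory] = {"counter": _poll_counter}


async def poll(provider_type: str, **kwargs: Any) -> dict[str, Any]:
    """Dispatch to the registered poller, or report an unsupported type."""
    poller = _FACTORIES.get(provider_type)
    if poller is None:
        return {
            "status": "error",
            "reason": f"unsupported provider type: {provider_type}",
        }
    events, new_state = await poller(**kwargs)
    return {"status": "ok", "events": events, "state": new_state}
```

The caller passes the same superset of keyword arguments to every poller, so adding a provider means adding one function and one dict entry rather than another branch.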
@@ -223,70 +320,61 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 events: list[ServiceEvent] = []
 new_state: dict[str, Any] = {}

-if provider_type == "immich":
-from notify_bridge_core.providers.immich import ImmichServiceProvider
-from .http_session import get_http_session
-http_session = await get_http_session()
-immich = ImmichServiceProvider(
-http_session,
-provider_config.get("url", ""),
-provider_config.get("api_key", ""),
-provider_config.get("external_domain"),
-provider_name,
-)
-connected = await immich.connect()
-if not connected:
-return {"status": "error", "reason": "failed to connect to provider"}
+# Webhook-only providers: capabilities.webhook_based short-circuits the
+# poll path. Inbound events arrive via the /api/webhooks/* endpoints.
+caps = get_capabilities(provider_type)
+if caps is not None and caps.webhook_based:
+return {"status": "ok", "events_detected": 0, "collections_checked": 0}

-events, new_state = await immich.poll(collection_ids, state_dict)
-elif provider_type == "gitea":
-# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
-# The scheduler still calls check_tracker but there's nothing to poll.
-return {"status": "ok", "events_detected": 0, "collections_checked": 0}
-elif provider_type == "planka":
-# Planka is webhook-based — events arrive via /api/webhooks/planka endpoint.
-return {"status": "ok", "events_detected": 0, "collections_checked": 0}
-elif provider_type == "scheduler":
-from notify_bridge_core.providers.scheduler import SchedulerServiceProvider
-custom_vars = tracker_filters.get("custom_variables", {})
-sched = SchedulerServiceProvider(
-name=provider_name,
-tracker_name=tracker_name,
-custom_variables=custom_vars,
-timezone_name=app_tz,
-)
-events, new_state = await sched.poll(collection_ids, state_dict)
-elif provider_type == "nut":
-from notify_bridge_core.providers.nut import NutServiceProvider
-nut = NutServiceProvider(
-host=provider_config.get("host", "localhost"),
-port=provider_config.get("port", 3493),
-username=provider_config.get("username"),
-password=provider_config.get("password"),
-name=provider_name,
-)
-events, new_state = await nut.poll(collection_ids, state_dict)
-elif provider_type == "google_photos":
-from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
-from .http_session import get_http_session
-http_session = await get_http_session()
-gp = GooglePhotosServiceProvider(
-http_session,
-provider_config.get("client_id", ""),
-provider_config.get("client_secret", ""),
-provider_config.get("refresh_token", ""),
-provider_name,
-)
-connected = await gp.connect()
-if not connected:
-return {"status": "error", "reason": "failed to connect to Google Photos"}
-events, new_state = await gp.poll(collection_ids, state_dict)
-elif provider_type == "webhook":
-# Webhook providers receive events via inbound HTTP; no polling needed.
-return {"status": "ok", "events_detected": 0, "collections_checked": 0}
-else:
+poller = _POLL_FACTORIES.get(provider_type)
+if poller is None:
 return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}
+
+try:
+events, new_state = await poller(
+provider_config=provider_config,
+provider_name=provider_name,
+tracker_name=tracker_name,
+tracker_filters=tracker_filters,
+collection_ids=collection_ids,
+state_dict=state_dict,
+app_tz=app_tz,
+)
+except _PollerConnectError as exc:
+# Track consecutive poll failures so the bridge_self provider can
+# alert when a tracker stops responding. The emission is async
+# but cheap; we await it inline so its DB writes happen before
+# check_tracker returns to the scheduler.
+from .bridge_self import maybe_emit_poll_failure
+try:
+await maybe_emit_poll_failure(
+tracker_id=tracker_id,
+tracker_name=tracker_name,
+error=exc.reason,
+)
+except Exception: # noqa: BLE001
+_LOGGER.exception("bridge_self poll-failure emission failed")
+return {"status": "error", "reason": exc.reason}
+except Exception as exc: # noqa: BLE001
+# Catch broader poll exceptions (e.g. a provider-side bug, transient
+# network error inside the poller after connect) so the same
+# streak-tracking logic applies. Re-raised after the bookkeeping so
+# the existing error path keeps logging at the caller.
+from .bridge_self import maybe_emit_poll_failure
+try:
+await maybe_emit_poll_failure(
+tracker_id=tracker_id,
+tracker_name=tracker_name,
+error=str(exc),
+)
+except Exception: # noqa: BLE001
+_LOGGER.exception("bridge_self poll-failure emission failed")
+raise
+
+# Successful poll — clear the consecutive-failure counter for this tracker.
+from .bridge_self import record_poll_success
+record_poll_success(tracker_id)

 # Save updated state and log events
 async with AsyncSession(engine) as session:
 for cid, cstate in new_state.items():
@@ -328,6 +416,16 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 # row if quiet hours suppresses it.
 event_log_id_by_event: dict[int, int] = {}
 for event in events:
+# Skip persistence for events the dispatch loop will filter
+# anyway (assets_added with 0 added, assets_removed with 0
+# removed). Without this we wrote a "noise" row for every
+# tracker tick that detected nothing. The dispatch-time filter
+# below still runs as a safety net.
+etype = event.event_type.value
+if etype == "assets_added" and event.added_count == 0:
+continue
+if etype == "assets_removed" and event.removed_count == 0:
+continue
 assets_count = event.added_count or event.removed_count or 0
 details: dict[str, Any] = {
 "added_count": event.added_count,
@@ -445,7 +543,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 template_slots=ld["template_slots"],
 date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
 date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
-provider_api_key=provider_config.get("api_key"),
+provider_api_key=resolve_provider_credential(provider_config),
 provider_internal_url=provider_config.get("url", ""),
 provider_external_url=provider_config.get("external_domain", ""),
 receivers=ld["receivers"],
@@ -453,7 +551,9 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 key = id(tc) if tc is not None else 0
 if key not in groups:
 groups[key] = (tc, [])
-groups[key][1].append(target_cfg)
+# Threaded with target_id/target_name so per-target failure
+# counters can attribute the dispatch result correctly.
+groups[key][1].append((target_cfg, ld.get("target_id"), ld.get("target_name", "")))

 # Persist defers + stamp the event_log row + schedule drains in a
 # single transaction. This keeps the "deferred" pill on the
@@ -496,8 +596,17 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 "Failed to schedule deferred drain for %s", fire_at,
 )

-for tc, target_configs in groups.values():
-if not target_configs:
+from .bridge_self import (
+maybe_emit_target_failure,
+record_target_success,
+)
+
+track_target_failures = (
+event.provider_type.value != "bridge_self"
+)
+
+for tc, target_entries in groups.values():
+if not target_entries:
 continue
 shaped_event = apply_tracking_display_filters(event, tc)
 if shaped_event is None:
@@ -505,12 +614,28 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
 " Event suppressed by display filters (favorites_only)",
 )
 continue
+target_configs = [entry[0] for entry in target_entries]
 results = await dispatcher.dispatch(shaped_event, target_configs)
-for r in results:
+for entry, r in zip(target_entries, results):
+_, target_id, target_name = entry
 if r.get("success"):
 _LOGGER.info(" Notification sent successfully")
+if track_target_failures and target_id is not None:
+record_target_success(int(target_id))
 else:
 _LOGGER.error(" Notification failed: %s", r.get("error", "unknown"))
+if track_target_failures and target_id is not None:
+try:
+await maybe_emit_target_failure(
+target_id=int(target_id),
+target_name=target_name or "",
+target_type=entry[0].type,
+error=str(r.get("error") or ""),
+)
+except Exception: # noqa: BLE001
+_LOGGER.exception(
+"bridge_self target-failure emission failed",
+)

 return {
 "status": "ok",
|
||||
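The check_tracker change above replaces the per-provider if/elif ladder with a lookup in `_POLL_FACTORIES`. The registry's contents are not part of this diff, so the following is only a minimal sketch of the pattern. The names `_REGISTRY`, `register_poller`, `_poll_demo` and `poll` are hypothetical, and the keyword-only poller signature is an assumption modelled on the call site.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Assumed shape: each poller accepts keyword arguments and returns
# (events, new_state). The real _POLL_FACTORIES entries are not shown here.
Poller = Callable[..., Awaitable[tuple[list[Any], dict[str, Any]]]]

_REGISTRY: dict[str, Poller] = {}  # hypothetical stand-in for _POLL_FACTORIES


def register_poller(provider_type: str) -> Callable[[Poller], Poller]:
    """Register a poller under a provider type string."""
    def decorator(fn: Poller) -> Poller:
        _REGISTRY[provider_type] = fn
        return fn
    return decorator


@register_poller("demo")
async def _poll_demo(
    *, collection_ids: list[str], state_dict: dict[str, Any], **_: Any
) -> tuple[list[Any], dict[str, Any]]:
    # A toy poller: one "event" per collection, state passed through unchanged.
    events = [f"checked:{cid}" for cid in collection_ids]
    return events, dict(state_dict)


async def poll(provider_type: str, **kwargs: Any) -> dict[str, Any]:
    """Look up the poller; unknown types get an error dict, not an exception."""
    poller = _REGISTRY.get(provider_type)
    if poller is None:
        return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}
    events, _new_state = await poller(**kwargs)
    return {"status": "ok", "events_detected": len(events)}
```

With this layout, adding a provider means registering one function, and the dispatch code does not change.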
@@ -0,0 +1,268 @@
|
||||
"""End-to-end backup roundtrip: seed -> export -> wipe -> import -> verify.
|
||||
|
||||
Drives the backup service module directly (no HTTP layer) against a fresh
|
||||
SQLite DB built in the conftest temp data dir. Verifies entity counts and
|
||||
key fields survive a full round-trip.
|
||||
|
||||
Kept under 5s by avoiding the lifespan startup — we build a private engine
|
||||
in an isolated DB file so we don't share state with other tests in the
|
||||
session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def isolated_engine(tmp_path: Path):
|
||||
"""A throwaway SQLite engine + freshly created schema for one test.
|
||||
|
||||
Avoids the global engine in ``database.engine`` — tests in the same
|
||||
session share that singleton, and recreating tables on it would corrupt
|
||||
parallel tests' state.
|
||||
"""
|
||||
# Importing the module registers all SQLModel tables on the metadata.
|
||||
from notify_bridge_server.database import models # noqa: F401
|
||||
|
||||
db_path = tmp_path / "roundtrip.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def _seed(session: AsyncSession, user_id: int) -> dict[str, int]:
|
||||
"""Insert enough rows to exercise the major code paths in import/export."""
|
||||
from notify_bridge_server.database.models import (
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
ServiceProvider,
|
||||
TargetReceiver,
|
||||
TelegramBot,
|
||||
TrackingConfig,
|
||||
User,
|
||||
)
|
||||
|
||||
user = User(
|
||||
id=user_id,
|
||||
username="roundtrip-user",
|
||||
hashed_password="hash",
|
||||
role="user",
|
||||
)
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
bot = TelegramBot(
|
||||
user_id=user_id, name="Test bot", token="123456:fake-token-value",
|
||||
bot_username="testbot", bot_id=1,
|
||||
)
|
||||
session.add(bot)
|
||||
await session.flush()
|
||||
|
||||
provider = ServiceProvider(
|
||||
user_id=user_id, type="immich", name="Immich prod",
|
||||
config={"base_url": "https://immich.example.com", "api_key": "secret"},
|
||||
)
|
||||
session.add(provider)
|
||||
await session.flush()
|
||||
|
||||
target = NotificationTarget(
|
||||
user_id=user_id, type="telegram", name="My channel",
|
||||
config={"bot_token_id": bot.id, "disable_url_preview": True},
|
||||
)
|
||||
session.add(target)
|
||||
await session.flush()
|
||||
|
||||
receiver = TargetReceiver(
|
||||
target_id=target.id, name="Channel A",
|
||||
config={"chat_id": "-100123"}, receiver_key="-100123", locale="en",
|
||||
)
|
||||
session.add(receiver)
|
||||
|
||||
tc = TrackingConfig(
|
||||
user_id=user_id, provider_type="immich",
|
||||
name="Default Immich tracking", track_assets_added=True,
|
||||
)
|
||||
session.add(tc)
|
||||
await session.flush()
|
||||
|
||||
tracker = NotificationTracker(
|
||||
user_id=user_id, provider_id=provider.id,
|
||||
name="Family album tracker", scan_interval=120,
|
||||
collection_ids=["album-uuid-1"],
|
||||
)
|
||||
session.add(tracker)
|
||||
await session.flush()
|
||||
|
||||
# Capture IDs before commit — accessing attributes after commit
|
||||
# triggers a refresh that needs an async-IO context the test caller
|
||||
# may not be inside. Better to snapshot now and use plain ints later.
|
||||
ids = {
|
||||
"provider_id": provider.id,
|
||||
"target_id": target.id,
|
||||
"bot_id": bot.id,
|
||||
"tracker_id": tracker.id,
|
||||
"tracking_config_id": tc.id,
|
||||
"tracker_name": tracker.name,
|
||||
"provider_name": provider.name,
|
||||
}
|
||||
|
||||
# EventLog rows are NOT in the backup schema — they're operational data,
|
||||
# not configuration. Insert a few anyway so we can verify they survive
|
||||
# the export step (since export only reads, never writes/wipes them).
|
||||
for i in range(3):
|
||||
session.add(EventLog(
|
||||
user_id=user_id, tracker_id=ids["tracker_id"], tracker_name=ids["tracker_name"],
|
||||
provider_id=ids["provider_id"], provider_name=ids["provider_name"],
|
||||
event_type="assets_added", collection_id="album-uuid-1",
|
||||
collection_name="Family", assets_count=i,
|
||||
))
|
||||
|
||||
await session.commit()
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
async def _wipe_user_owned_rows(engine, user_id: int) -> None:
|
||||
"""Delete every backup-able row for the user via raw SQL.
|
||||
|
||||
Using ORM-level deletes triggers SQLAlchemy's cascade machinery, which
|
||||
lazy-loads relationships in a sync context that the async driver cannot
|
||||
serve (MissingGreenlet). Raw DELETE statements skip cascades and let
|
||||
SQLite's FKs enforce ordering naturally.
|
||||
|
||||
Order matters: child rows first, then parents.
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
|
||||
statements = (
|
||||
"DELETE FROM event_log",
|
||||
"DELETE FROM notification_tracker_target",
|
||||
"DELETE FROM notification_tracker",
|
||||
"DELETE FROM target_receiver",
|
||||
"DELETE FROM notification_target",
|
||||
"DELETE FROM tracking_config",
|
||||
"DELETE FROM service_provider",
|
||||
"DELETE FROM template_slot",
|
||||
"DELETE FROM template_config",
|
||||
"DELETE FROM telegram_bot",
|
||||
"DELETE FROM appsetting",
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
for stmt in statements:
|
||||
try:
|
||||
await conn.execute(text(stmt))
|
||||
except Exception: # noqa: BLE001 — table may not exist in test schema
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_wipe_import_roundtrip(isolated_engine, tmp_data_dir) -> None: # noqa: ARG001
|
||||
"""A full round-trip preserves entity counts and the key fields the
|
||||
UI relies on — names, configs (with secrets included), provider
|
||||
references via id_map.
|
||||
"""
|
||||
from notify_bridge_server.database.models import (
|
||||
NotificationTarget, NotificationTracker, ServiceProvider,
|
||||
TargetReceiver, TelegramBot, TrackingConfig,
|
||||
)
|
||||
from notify_bridge_server.services.backup_schema import (
|
||||
ConflictMode, SecretsMode,
|
||||
)
|
||||
from notify_bridge_server.services.backup_service import (
|
||||
export_backup, import_backup,
|
||||
)
|
||||
|
||||
user_id = 1
|
||||
|
||||
# ---- Seed ----
|
||||
async with AsyncSession(isolated_engine) as session:
|
||||
ids = await _seed(session, user_id)
|
||||
|
||||
# ---- Export with secrets included so import sees real values ----
|
||||
async with AsyncSession(isolated_engine) as session:
|
||||
backup = await export_backup(
|
||||
session, user_id, secrets_mode=SecretsMode.INCLUDE,
|
||||
)
|
||||
|
||||
assert len(backup.data.providers) == 1
|
||||
assert len(backup.data.telegram_bots) == 1
|
||||
assert len(backup.data.targets) == 1
|
||||
assert len(backup.data.targets[0].receivers) == 1
|
||||
assert len(backup.data.tracking_configs) == 1
|
||||
assert len(backup.data.notification_trackers) == 1
|
||||
assert backup.data.providers[0].config["api_key"] == "secret"
|
||||
|
||||
# ---- Wipe ----
|
||||
await _wipe_user_owned_rows(isolated_engine, user_id)
|
||||
|
||||
async with AsyncSession(isolated_engine) as session:
|
||||
result = await session.exec(
|
||||
select(ServiceProvider).where(ServiceProvider.user_id == user_id)
|
||||
)
|
||||
assert result.all() == []
|
||||
|
||||
# ---- Import ----
|
||||
async with AsyncSession(isolated_engine) as session:
|
||||
result = await import_backup(
|
||||
session, user_id, backup, conflict_mode=ConflictMode.SKIP,
|
||||
)
|
||||
|
||||
assert result.errors == [], f"Import errors: {result.errors}"
|
||||
assert result.created > 0
|
||||
|
||||
# ---- Verify the entities are back ----
|
||||
async with AsyncSession(isolated_engine) as session:
|
||||
providers = (await session.exec(
|
||||
select(ServiceProvider).where(ServiceProvider.user_id == user_id)
|
||||
)).all()
|
||||
assert len(providers) == 1
|
||||
prov = providers[0]
|
||||
assert prov.name == "Immich prod"
|
||||
assert prov.config["base_url"] == "https://immich.example.com"
|
||||
# Secrets imported intact when SecretsMode.INCLUDE was used at export.
|
||||
assert prov.config["api_key"] == "secret"
|
||||
|
||||
bots = (await session.exec(
|
||||
select(TelegramBot).where(TelegramBot.user_id == user_id)
|
||||
)).all()
|
||||
assert len(bots) == 1
|
||||
assert bots[0].name == "Test bot"
|
||||
|
||||
targets = (await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.user_id == user_id)
|
||||
)).all()
|
||||
assert len(targets) == 1
|
||||
receivers = (await session.exec(
|
||||
select(TargetReceiver).where(TargetReceiver.target_id == targets[0].id)
|
||||
)).all()
|
||||
assert len(receivers) == 1
|
||||
assert receivers[0].config["chat_id"] == "-100123"
|
||||
|
||||
tcs = (await session.exec(
|
||||
select(TrackingConfig).where(TrackingConfig.user_id == user_id)
|
||||
)).all()
|
||||
assert len(tcs) == 1
|
||||
assert tcs[0].name == "Default Immich tracking"
|
||||
|
||||
trackers = (await session.exec(
|
||||
select(NotificationTracker).where(NotificationTracker.user_id == user_id)
|
||||
)).all()
|
||||
assert len(trackers) == 1
|
||||
# provider_id was remapped via id_map — original provider id may have
|
||||
# changed across the wipe, so just check it links to a real row.
|
||||
assert trackers[0].provider_id == prov.id
|
||||
assert trackers[0].scan_interval == 120
|
||||
assert trackers[0].collection_ids == ["album-uuid-1"]
|
||||
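The roundtrip test above relies on import remapping foreign keys through an id_map, because re-inserted rows may receive new primary keys. The backup service itself is not shown in this diff, so this is only a rough sketch of the idea using plain dicts. `import_with_id_map` and its arguments are hypothetical names, not the project's API.

```python
from itertools import count
from typing import Any


def import_with_id_map(
    providers: list[dict[str, Any]],
    trackers: list[dict[str, Any]],
    start_id: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Re-create providers with fresh ids, then rewrite tracker references."""
    fresh_ids = count(start_id)
    id_map: dict[int, int] = {}  # exported provider id -> newly assigned id

    new_providers = []
    for row in providers:
        new_id = next(fresh_ids)
        id_map[row["id"]] = new_id
        new_providers.append({**row, "id": new_id})

    # Trackers are imported after providers so every reference can be resolved.
    new_trackers = [
        {**row, "provider_id": id_map[row["provider_id"]]} for row in trackers
    ]
    return new_providers, new_trackers
```

This is why the test asserts `trackers[0].provider_id == prov.id` and does not compare against the pre-wipe id.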
@@ -0,0 +1,265 @@
|
||||
"""Tests for the bridge self-monitoring provider.
|
||||
|
||||
Covers:
|
||||
1. ``build_event`` parses a well-formed payload and rejects malformed ones.
|
||||
2. The threshold-crossing helpers in ``services.bridge_self`` only emit on
|
||||
the actual crossing, not on every increment afterwards (anti-spam).
|
||||
3. ``ensure_bridge_self_provider_for_user`` creates exactly one provider
|
||||
per user and is idempotent on re-run.
|
||||
4. The capability registry exposes the new event/slot definitions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_event_well_formed_payload() -> None:
|
||||
from notify_bridge_core.providers.bridge_self.event_parser import build_event
|
||||
from notify_bridge_core.models.events import EventType
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
|
||||
payload = {
|
||||
"failure_type": "poll_failures",
|
||||
"subject_id": 7,
|
||||
"subject_name": "My Tracker",
|
||||
"count": 3,
|
||||
"threshold": 3,
|
||||
"last_error": "Timeout",
|
||||
"details": {"tracker_id": 7},
|
||||
}
|
||||
when = datetime(2026, 5, 16, 10, 0, tzinfo=timezone.utc)
|
||||
event = build_event(payload, timestamp=when)
|
||||
|
||||
assert event is not None
|
||||
assert event.event_type == EventType.BRIDGE_SELF_POLL_FAILURES
|
||||
assert event.provider_type == ServiceProviderType.BRIDGE_SELF
|
||||
assert event.collection_id == "7"
|
||||
assert event.collection_name == "My Tracker"
|
||||
assert event.timestamp == when
|
||||
assert event.extra["count"] == 3
|
||||
assert event.extra["threshold"] == 3
|
||||
assert event.extra["last_error"] == "Timeout"
|
||||
assert event.extra["failure_type"] == "poll_failures"
|
||||
assert event.extra["details"] == {"tracker_id": 7}
|
||||
|
||||
|
||||
def test_build_event_unknown_failure_type_returns_none() -> None:
|
||||
from notify_bridge_core.providers.bridge_self.event_parser import build_event
|
||||
|
||||
assert build_event({"failure_type": "rocket_launch"}) is None
|
||||
|
||||
|
||||
def test_build_event_non_dict_payload_returns_none() -> None:
|
||||
from notify_bridge_core.providers.bridge_self.event_parser import build_event
|
||||
|
||||
assert build_event("not a dict") is None # type: ignore[arg-type]
|
||||
assert build_event(None) is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_build_event_clamps_long_error_messages() -> None:
|
||||
from notify_bridge_core.providers.bridge_self.event_parser import (
|
||||
build_event, _MAX_ERROR_LEN,
|
||||
)
|
||||
|
||||
huge = "X" * (_MAX_ERROR_LEN * 5)
|
||||
event = build_event({
|
||||
"failure_type": "target_failures",
|
||||
"subject_id": 1,
|
||||
"subject_name": "t",
|
||||
"count": 5,
|
||||
"threshold": 5,
|
||||
"last_error": huge,
|
||||
})
|
||||
assert event is not None
|
||||
assert len(event.extra["last_error"]) <= _MAX_ERROR_LEN
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Threshold-crossing counters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_record_poll_failure_increments_then_success_resets() -> None:
|
||||
from notify_bridge_server.services import bridge_self as bs
|
||||
|
||||
# Use a tracker_id we know is unique to this test to avoid pollution
|
||||
# across tests sharing the module-level dicts.
|
||||
tid = 9_001
|
||||
bs.reset_poll_counter(tid)
|
||||
|
||||
assert bs.record_poll_failure(tid, "boom") == 1
|
||||
assert bs.record_poll_failure(tid, "boom") == 2
|
||||
assert bs.record_poll_failure(tid, "boom") == 3
|
||||
assert bs.get_poll_failure_count(tid) == 3
|
||||
assert bs.get_poll_last_error(tid) == "boom"
|
||||
|
||||
bs.record_poll_success(tid)
|
||||
assert bs.get_poll_failure_count(tid) == 0
|
||||
assert bs.get_poll_last_error(tid) == ""
|
||||
|
||||
|
||||
def test_record_target_failure_increments_then_success_resets() -> None:
|
||||
from notify_bridge_server.services import bridge_self as bs
|
||||
|
||||
tid = 9_101
|
||||
bs.reset_target_counter(tid)
|
||||
|
||||
assert bs.record_target_failure(tid, "503") == 1
|
||||
assert bs.record_target_failure(tid, "503") == 2
|
||||
assert bs.get_target_failure_count(tid) == 2
|
||||
|
||||
bs.record_target_success(tid)
|
||||
assert bs.get_target_failure_count(tid) == 0
|
||||
|
||||
|
||||
def test_backlog_state_only_emits_on_crossing() -> None:
|
||||
"""Only the False -> True transition should report a crossing.
|
||||
|
||||
A sustained backlog must not re-fire on every scan, and a recovered
|
||||
backlog re-arms the latch so the next crossing is reported again.
|
||||
"""
|
||||
from notify_bridge_server.services import bridge_self as bs
|
||||
|
||||
user_id = 9_201
|
||||
# Reset the latch by dropping any stored state for this user.
|
||||
bs._backlog_above_threshold.pop(user_id, None)
|
||||
|
||||
# Initial above-threshold reading IS a crossing (None -> True latch).
|
||||
assert bs.record_backlog_state(user_id, True) is True
|
||||
# Sustained above — no second alert.
|
||||
assert bs.record_backlog_state(user_id, True) is False
|
||||
assert bs.record_backlog_state(user_id, True) is False
|
||||
# Drop below — no alert (we don't notify on recovery).
|
||||
assert bs.record_backlog_state(user_id, False) is False
|
||||
# Cross again — alert.
|
||||
assert bs.record_backlog_state(user_id, True) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ensure_bridge_self_provider_for_user — DB roundtrip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session() -> AsyncSession:
|
||||
"""Fresh in-memory DB with the SQLModel schema applied."""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_bridge_self_provider_creates_once(session: AsyncSession) -> None:
|
||||
from notify_bridge_server.database.models import ServiceProvider, User
|
||||
from notify_bridge_server.database.seeds import (
|
||||
ensure_bridge_self_provider_for_user,
|
||||
)
|
||||
|
||||
# Create a real user.
|
||||
user = User(username="alice", hashed_password="x", role="user")
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
user_id = user.id
|
||||
|
||||
p1 = await ensure_bridge_self_provider_for_user(session, user_id)
|
||||
assert p1 is not None
|
||||
p1_id = p1.id
|
||||
assert p1.type == "bridge_self"
|
||||
assert p1.user_id == user_id
|
||||
assert p1.config["poll_failure_threshold"] == 3
|
||||
assert p1.config["deferred_backlog_threshold"] == 100
|
||||
assert p1.config["target_failure_threshold"] == 5
|
||||
await session.commit()
|
||||
|
||||
# Idempotent: second call returns the same row, no duplicates.
|
||||
p2 = await ensure_bridge_self_provider_for_user(session, user_id)
|
||||
assert p2 is not None
|
||||
assert p2.id == p1_id
|
||||
await session.commit()
|
||||
|
||||
rows = (
|
||||
await session.exec(
|
||||
select(ServiceProvider).where(
|
||||
ServiceProvider.user_id == user_id,
|
||||
ServiceProvider.type == "bridge_self",
|
||||
)
|
||||
)
|
||||
).all()
|
||||
assert len(rows) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_bridge_self_provider_skips_system_user(session: AsyncSession) -> None:
|
||||
"""user_id <= 0 is the __system__ placeholder — never gets a provider."""
|
||||
from notify_bridge_server.database.seeds import (
|
||||
ensure_bridge_self_provider_for_user,
|
||||
)
|
||||
|
||||
result = await ensure_bridge_self_provider_for_user(session, 0)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capability registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_capability_registry_lists_bridge_self() -> None:
|
||||
from notify_bridge_core.providers.capabilities import (
|
||||
get_capabilities, get_all_capabilities,
|
||||
)
|
||||
|
||||
caps = get_capabilities("bridge_self")
|
||||
assert caps is not None
|
||||
assert caps.provider_type == "bridge_self"
|
||||
assert caps.webhook_based is False
|
||||
|
||||
event_names = {e["name"] for e in caps.events}
|
||||
assert event_names == {
|
||||
"bridge_self_poll_failures",
|
||||
"bridge_self_deferred_backlog",
|
||||
"bridge_self_target_failures",
|
||||
}
|
||||
|
||||
slot_names = {s["name"] for s in caps.notification_slots}
|
||||
assert slot_names == {
|
||||
"message_bridge_self_poll_failures",
|
||||
"message_bridge_self_deferred_backlog",
|
||||
"message_bridge_self_target_failures",
|
||||
}
|
||||
|
||||
# And it shows up in the global registry.
|
||||
assert "bridge_self" in get_all_capabilities()
|
||||
|
||||
|
||||
def test_default_template_loader_returns_bridge_self_slots() -> None:
|
||||
"""All three bridge_self slots have shipped Jinja2 default templates."""
|
||||
from notify_bridge_core.templates.defaults.loader import load_default_templates
|
||||
|
||||
en = load_default_templates("en", "bridge_self")
|
||||
ru = load_default_templates("ru", "bridge_self")
|
||||
expected = {
|
||||
"message_bridge_self_poll_failures",
|
||||
"message_bridge_self_deferred_backlog",
|
||||
"message_bridge_self_target_failures",
|
||||
}
|
||||
assert set(en.keys()) == expected
|
||||
assert set(ru.keys()) == expected
|
||||
# Sanity: each template contains at least one Jinja2 expression.
|
||||
for tpl in list(en.values()) + list(ru.values()):
|
||||
assert "{{" in tpl
|
||||
@@ -0,0 +1,249 @@
|
||||
"""Unit tests for the Gitea webhook parser.
|
||||
|
||||
Pure-function tests against ``parse_webhook`` using realistic Gitea
|
||||
payloads (trimmed to the fields the parser actually consumes). No DB or
|
||||
HTTP fixtures needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from notify_bridge_core.models.events import EventType
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
from notify_bridge_core.providers.gitea.event_parser import parse_webhook
|
||||
|
||||
|
||||
def _repo() -> dict:
|
||||
return {
|
||||
"id": 42,
|
||||
"name": "demo",
|
||||
"full_name": "alexei/demo",
|
||||
"html_url": "https://git.example.com/alexei/demo",
|
||||
"description": "Demo repo",
|
||||
"private": False,
|
||||
"owner": {
|
||||
"id": 1,
|
||||
"login": "alexei",
|
||||
"full_name": "Alexei",
|
||||
"email": "alexei@example.com",
|
||||
"avatar_url": "https://git.example.com/avatars/1",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _sender() -> dict:
|
||||
return {
|
||||
"id": 1,
|
||||
"login": "alexei",
|
||||
"full_name": "Alexei",
|
||||
"avatar_url": "https://git.example.com/avatars/1",
|
||||
}
|
||||
|
||||
|
||||
def test_push_event() -> None:
|
||||
payload = {
|
||||
"ref": "refs/heads/master",
|
||||
"before": "0000000000000000000000000000000000000000",
|
||||
"after": "abcdef0123456789abcdef0123456789abcdef01",
|
||||
"compare_url": "https://git.example.com/alexei/demo/compare/000...abc",
|
||||
"commits": [
|
||||
{
|
||||
"id": "abcdef0123456789abcdef0123456789abcdef01",
|
||||
"message": "feat: initial commit\n\nMore detail.",
|
||||
"url": "https://git.example.com/alexei/demo/commit/abcdef0",
|
||||
"author": {
|
||||
"name": "Alexei",
|
||||
"email": "alexei@example.com",
|
||||
"username": "alexei",
|
||||
},
|
||||
"timestamp": "2026-05-16T10:00:00Z",
|
||||
},
|
||||
{
|
||||
"id": "1234567890123456789012345678901234567890",
|
||||
"message": "chore: tweak",
|
||||
"url": "https://git.example.com/alexei/demo/commit/1234567",
|
||||
"author": {"name": "Alexei", "email": "alexei@example.com"},
|
||||
"timestamp": "2026-05-16T10:05:00Z",
|
||||
},
|
||||
],
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
|
||||
evt = parse_webhook("push", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.PUSH
|
||||
assert evt.provider_type is ServiceProviderType.GITEA
|
||||
assert evt.collection_id == "alexei/demo"
|
||||
assert evt.collection_name == "alexei/demo"
|
||||
assert evt.extra["ref"] == "refs/heads/master"
|
||||
assert evt.extra["branch"] == "master"
|
||||
assert evt.extra["commit_count"] == 2
|
||||
assert evt.extra["commits"][0]["short_id"] == "abcdef0"
|
||||
# The first commit's multi-line body must be preserved (.strip handles
|
||||
# trailing newlines but should keep the inner '\n').
|
||||
assert "feat: initial commit" in evt.extra["commits"][0]["message"]
|
||||
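The push test above expects `branch == "master"` for `refs/heads/master` and a seven-character `short_id`. The parser's implementation is not part of this excerpt, so the helpers below are only a sketch of one way to derive those values. `branch_from_ref` and `short_id` are hypothetical names.

```python
def branch_from_ref(ref: str) -> str:
    """Strip the refs/heads/ prefix; leave other refs (e.g. tags) untouched."""
    prefix = "refs/heads/"
    return ref[len(prefix):] if ref.startswith(prefix) else ref


def short_id(commit_id: str, length: int = 7) -> str:
    """Abbreviate a full commit hash to its first `length` characters."""
    return commit_id[:length]
```

Stripping only the known prefix keeps branch names that contain slashes, such as `feat/metrics`, intact.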
|
||||
|
||||
def test_issue_opened() -> None:
|
||||
payload = {
|
||||
"action": "opened",
|
||||
"issue": {
|
||||
"id": 100,
|
||||
"number": 7,
|
||||
"title": "Bug: thing broken",
|
||||
"html_url": "https://git.example.com/alexei/demo/issues/7",
|
||||
"state": "open",
|
||||
"body": "Steps to reproduce...",
|
||||
"labels": [{"name": "bug"}, {"name": "p1"}],
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("issues", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.ISSUE_OPENED
|
||||
assert evt.collection_id == "alexei/demo"
|
||||
assert evt.extra["issue_number"] == 7
|
||||
assert evt.extra["issue_title"] == "Bug: thing broken"
|
||||
assert evt.extra["issue_labels"] == ["bug", "p1"]
|
||||
|
||||
|
||||
def test_issue_closed() -> None:
|
||||
payload = {
|
||||
"action": "closed",
|
||||
"issue": {
|
||||
"id": 100,
|
||||
"number": 7,
|
||||
"title": "Bug: thing broken",
|
||||
"html_url": "https://git.example.com/alexei/demo/issues/7",
|
||||
"state": "closed",
|
||||
"body": "",
|
||||
"labels": [],
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("issues", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.ISSUE_CLOSED
|
||||
assert evt.extra["issue_state"] == "closed"
|
||||
|
||||
|
||||
def test_pr_opened() -> None:
|
||||
payload = {
|
||||
"action": "opened",
|
||||
"pull_request": {
|
||||
"id": 200,
|
||||
"number": 12,
|
||||
"title": "Add metrics endpoint",
|
||||
"html_url": "https://git.example.com/alexei/demo/pulls/12",
|
||||
"state": "open",
|
||||
"body": "PR body",
|
||||
"merged": False,
|
||||
"base": {"ref": "master", "label": "alexei:master"},
|
||||
"head": {"ref": "feat/metrics", "label": "alexei:feat/metrics"},
|
||||
"labels": [{"name": "enhancement"}],
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.PR_OPENED
|
||||
assert evt.extra["pr_number"] == 12
|
||||
assert evt.extra["pr_merged"] is False
|
||||
assert evt.extra["pr_base"] == "alexei:master"
|
||||
assert evt.extra["pr_head"] == "alexei:feat/metrics"
|
||||
|
||||
|
||||
def test_pr_merged_resolves_from_closed_with_merged_flag() -> None:
|
||||
"""A 'closed' action with merged=True is the merge signal — Gitea does
|
||||
not send a distinct event header for it, so the parser must promote
|
||||
PR_CLOSED -> PR_MERGED on its own."""
|
||||
payload = {
|
||||
"action": "closed",
|
||||
"pull_request": {
|
||||
"id": 200,
|
||||
"number": 12,
|
||||
"title": "Add metrics endpoint",
|
||||
"html_url": "https://git.example.com/alexei/demo/pulls/12",
|
||||
"state": "closed",
|
||||
"body": "",
|
||||
"merged": True,
|
||||
"base": {"ref": "master"},
|
||||
"head": {"ref": "feat/metrics"},
|
||||
"labels": [],
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.PR_MERGED
|
||||
assert evt.extra["pr_merged"] is True
|
||||
|
||||
|
||||
def test_pr_closed_without_merge() -> None:
|
||||
payload = {
|
||||
"action": "closed",
|
||||
"pull_request": {
|
||||
"id": 200,
|
||||
"number": 12,
|
||||
"title": "Abandoned PR",
|
||||
"html_url": "https://git.example.com/alexei/demo/pulls/12",
|
||||
"state": "closed",
|
||||
"body": "",
|
||||
"merged": False,
|
||||
"base": {"ref": "master"},
|
||||
"head": {"ref": "feat/x"},
|
||||
"labels": [],
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("pull_request", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.PR_CLOSED
|
||||
|
||||
|
||||
def test_release_published() -> None:
|
||||
payload = {
|
||||
"action": "published",
|
||||
"release": {
|
||||
"id": 9,
|
||||
"tag_name": "v1.2.3",
|
||||
"name": "Release v1.2.3",
|
||||
"html_url": "https://git.example.com/alexei/demo/releases/tag/v1.2.3",
|
||||
"body": "Bug fixes and improvements",
|
||||
"draft": False,
|
||||
"prerelease": False,
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
evt = parse_webhook("release", payload, provider_name="gitea-prod")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.RELEASE_PUBLISHED
|
||||
assert evt.extra["release_tag"] == "v1.2.3"
|
||||
assert evt.extra["release_prerelease"] is False
|
||||
|
||||
|
||||
def test_release_non_published_is_ignored() -> None:
|
||||
"""Only ``published`` releases should produce events — drafts and edits
|
||||
are noise and would spam any tracker subscribed to release notifications."""
|
||||
payload = {
|
||||
"action": "edited",
|
||||
"release": {
|
||||
"id": 9, "tag_name": "v1.2.3", "name": "x",
|
||||
"html_url": "", "body": "",
|
||||
"draft": True, "prerelease": False,
|
||||
},
|
||||
"repository": _repo(),
|
||||
"sender": _sender(),
|
||||
}
|
||||
assert parse_webhook("release", payload, provider_name="g") is None
|
||||
|
||||
|
||||
def test_unknown_event_header_returns_none() -> None:
|
||||
payload = {"repository": _repo(), "sender": _sender()}
|
||||
assert parse_webhook("unknown_event", payload, provider_name="g") is None
|
||||
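The merge-promotion rule that the pull-request tests above pin down can be sketched in isolation. This is a minimal illustration, not the project's parser: `PrEvent` and `classify_pull_request` are hypothetical names, and only the `action` and `pull_request.merged` fields seen in the test payloads are assumed.

```python
from __future__ import annotations

from enum import Enum


class PrEvent(Enum):
    """Hypothetical stand-in for the PR members of the project's EventType."""

    OPENED = "pr_opened"
    CLOSED = "pr_closed"
    MERGED = "pr_merged"


def classify_pull_request(payload: dict) -> PrEvent | None:
    """Map a pull_request webhook payload to an event kind.

    A 'closed' action is promoted to MERGED when the pull request
    carries a truthy ``merged`` flag; otherwise it stays CLOSED.
    Actions not covered by this sketch yield None.
    """
    action = payload.get("action")
    pr = payload.get("pull_request") or {}
    if action == "opened":
        return PrEvent.OPENED
    if action == "closed":
        return PrEvent.MERGED if pr.get("merged") else PrEvent.CLOSED
    return None
```

Keeping the promotion inside one function means a tracker subscribed to merges never has to inspect the raw payload itself.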
@@ -27,7 +27,13 @@ def test_ready_endpoint(tmp_data_dir) -> None: # noqa: ARG001
     resp = client.get("/api/ready")
     # By the time TestClient yields, lifespan startup has completed.
     assert resp.status_code == 200
-    assert resp.json()["status"] == "ready"
+    body = resp.json()
+    assert body["ready"] is True
+    assert body["checks"]["db"] == "ok"
+    assert body["checks"]["scheduler"] == "ok"
+    # No HA providers configured by default in the test fixture.
+    assert body["checks"]["ha"] == "na"
+    assert body["errors"] == []


 def test_health_is_anonymous(tmp_data_dir) -> None:  # noqa: ARG001
@@ -0,0 +1,159 @@
"""Unit tests for Immich album change detection.

Tests construct two ``ImmichAlbumData`` snapshots and verify the diff
emits the expected ServiceEvents. No HTTP, no DB. Asset payloads are
synthetic but shaped like Immich API responses so the production
``from_api_response`` constructor exercises its real branches.
"""

from __future__ import annotations

from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.immich.change_detector import (
    detect_album_changes,
)
from notify_bridge_core.providers.immich.models import ImmichAlbumData


_EXTERNAL = "https://immich.example.com"


def _asset(asset_id: str, *, processed: bool = True, type_: str = "IMAGE") -> dict:
    """Build an Immich asset payload that ``from_api_response`` accepts."""
    return {
        "id": asset_id,
        "type": type_,
        "originalFileName": f"{asset_id}.jpg",
        "fileCreatedAt": "2026-05-15T12:00:00.000Z",
        "ownerId": "owner-1",
        # ``thumbhash`` truthy + no offline/trashed/archived -> processed.
        # Skipped when caller asks for an unprocessed asset.
        "thumbhash": "abc" if processed else None,
        "isOffline": False,
        "isTrashed": False,
        "isArchived": False,
        "isFavorite": False,
        "exifInfo": {},
    }


def _album(asset_dicts: list[dict], *, name: str = "Trip", album_id: str = "a1",
           shared: bool = False) -> ImmichAlbumData:
    return ImmichAlbumData.from_api_response(
        {
            "id": album_id,
            "albumName": name,
            "assets": asset_dicts,
            "assetCount": len(asset_dicts),
            "createdAt": "2026-05-01T00:00:00Z",
            "updatedAt": "2026-05-15T12:00:00Z",
            "shared": shared,
            "owner": {"name": "alexei"},
            "albumThumbnailAssetId": asset_dicts[0]["id"] if asset_dicts else None,
        }
    )


def test_added_asset_emits_assets_added_event() -> None:
    old = _album([_asset("a"), _asset("b")])
    new = _album([_asset("a"), _asset("b"), _asset("c")])

    events, pending = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    assert len(events) == 1
    evt = events[0]
    assert evt.event_type is EventType.ASSETS_ADDED
    assert evt.provider_type is ServiceProviderType.IMMICH
    assert evt.collection_id == "a1"
    assert evt.collection_name == "Trip"
    assert evt.added_count == 1
    assert len(evt.added_assets) == 1
    assert pending == set()


def test_removed_asset_emits_assets_removed_event() -> None:
    old = _album([_asset("a"), _asset("b"), _asset("c")])
    new = _album([_asset("a")])

    events, _ = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    by_type = {e.event_type: e for e in events}
    assert EventType.ASSETS_REMOVED in by_type
    removed = by_type[EventType.ASSETS_REMOVED]
    assert removed.removed_count == 2
    assert set(removed.removed_asset_ids) == {"b", "c"}


def test_no_changes_returns_no_events() -> None:
    old = _album([_asset("a"), _asset("b")])
    new = _album([_asset("a"), _asset("b")])

    events, pending = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    assert events == []
    assert pending == set()


def test_unprocessed_asset_is_held_in_pending() -> None:
    """Assets without a thumbhash haven't finished server-side processing.
    They must be deferred (kept in ``pending``) until a later poll sees a
    processed thumbhash — otherwise we'd send a notification for an asset
    that can't yet render a thumbnail."""
    old = _album([_asset("a")])
    new = _album([_asset("a"), _asset("b", processed=False)])

    events, pending = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    # ``b`` is not processed, so no event for it AND nothing else changed,
    # so we get an empty event list.
    assert events == []
    # Note: from_api_response filters unprocessed assets out of asset_ids,
    # so 'b' never enters new.asset_ids and pending stays empty in this path.
    # The pending mechanism kicks in once 'b' lands in asset_ids on a later
    # tick; that branch is not exercised by the tests in this file.
    assert pending == set()


def test_collection_renamed_emits_renamed_event() -> None:
    old = _album([_asset("a")], name="Trip")
    new = _album([_asset("a")], name="Vacation")

    events, _ = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    by_type = {e.event_type: e for e in events}
    assert EventType.COLLECTION_RENAMED in by_type
    rename = by_type[EventType.COLLECTION_RENAMED]
    assert rename.old_name == "Trip"
    assert rename.new_name == "Vacation"


def test_sharing_change_emits_sharing_event() -> None:
    old = _album([_asset("a")], shared=False)
    new = _album([_asset("a")], shared=True)

    events, _ = detect_album_changes(
        old, new, pending_asset_ids=set(),
        provider_name="immich-prod", external_url=_EXTERNAL,
    )

    by_type = {e.event_type: e for e in events}
    assert EventType.SHARING_CHANGED in by_type
    sharing = by_type[EventType.SHARING_CHANGED]
    assert sharing.old_shared is False
    assert sharing.new_shared is True
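The added/removed expectations in the Immich tests above are consistent with a plain set difference over asset IDs. The sketch below shows only that idea; `diff_asset_ids` is a hypothetical helper, and the real `detect_album_changes` may be structured quite differently (it also handles pending assets, renames and sharing changes, which are left out here).

```python
from __future__ import annotations


def diff_asset_ids(old_ids: set[str], new_ids: set[str]) -> tuple[set[str], set[str]]:
    """Return (added, removed) asset IDs between two album snapshots.

    added   = IDs present in the new snapshot but not in the old one
    removed = IDs present in the old snapshot but not in the new one
    """
    added = new_ids - old_ids
    removed = old_ids - new_ids
    return added, removed


# Mirrors the shapes used in the tests: three assets shrinking to one.
added, removed = diff_asset_ids({"a", "b", "c"}, {"a"})
```

Because both results are sets, an unchanged album yields two empty sets, which matches the "no changes, no events" case.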
@@ -0,0 +1,147 @@
"""Unit tests for the Planka webhook parser.

Pure-function tests against ``parse_webhook`` using realistic Planka
webhook payload shapes. The parser is forgiving about missing ``included``
data (older Planka builds), so we mix payloads with and without it to
catch regressions in the fallback paths.
"""

from __future__ import annotations

from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.planka.event_parser import parse_webhook


_BASE_URL = "https://planka.example.com"


def _user() -> dict:
    return {"id": "u1", "username": "alexei", "name": "Alexei"}


def test_card_created() -> None:
    payload = {
        "user": _user(),
        "item": {
            "id": "c1",
            "name": "Implement metrics",
            "description": "Wire prometheus client.",
            "boardId": "b1",
            "listId": "l1",
            "position": 1,
        },
        "included": {
            "board": {"id": "b1", "name": "Roadmap"},
            "lists": [{"id": "l1", "name": "Todo"}],
        },
    }
    evt = parse_webhook("cardCreate", payload, provider_name="planka", base_url=_BASE_URL)
    assert evt is not None
    assert evt.event_type is EventType.CARD_CREATED
    assert evt.provider_type is ServiceProviderType.PLANKA
    assert evt.collection_id == "b1"
    assert evt.collection_name == "Roadmap"
    assert evt.extra["card_name"] == "Implement metrics"
    assert evt.extra["card_url"] == f"{_BASE_URL}/cards/c1"
    assert evt.extra["list_name"] == "Todo"
    assert evt.extra["sender"] == "alexei"


def test_card_moved_when_list_changes() -> None:
    """beforeUpdate.listId != item.listId is the signal Planka uses for a
    card move; the parser must promote the generic cardUpdate event into
    CARD_MOVED so trackers can subscribe to moves specifically."""
    payload = {
        "user": _user(),
        "beforeUpdate": {"listId": "l1"},
        "item": {
            "id": "c1",
            "name": "Implement metrics",
            "description": "",
            "boardId": "b1",
            "listId": "l2",
        },
        "included": {
            "board": {"id": "b1", "name": "Roadmap"},
            "lists": [
                {"id": "l1", "name": "Todo"},
                {"id": "l2", "name": "In progress"},
            ],
        },
    }
    evt = parse_webhook("cardUpdate", payload, provider_name="planka", base_url=_BASE_URL)
    assert evt is not None
    assert evt.event_type is EventType.CARD_MOVED
    assert evt.extra["old_list_id"] == "l1"
    assert evt.extra["new_list_id"] == "l2"
    assert evt.extra["old_list_name"] == "Todo"
    assert evt.extra["new_list_name"] == "In progress"


def test_card_update_without_list_change_is_card_updated() -> None:
    payload = {
        "user": _user(),
        "beforeUpdate": {"name": "Old name"},
        "item": {
            "id": "c1", "name": "New name", "description": "", "boardId": "b1", "listId": "l1",
        },
    }
    evt = parse_webhook("cardUpdate", payload, provider_name="planka", base_url=_BASE_URL)
    assert evt is not None
    assert evt.event_type is EventType.CARD_UPDATED


def test_comment_created() -> None:
    payload = {
        "user": _user(),
        "item": {
            "id": "cm1",
            "text": "LGTM, ship it.",
            "cardId": "c1",
            "userId": "u1",
        },
        "included": {
            "card": {"id": "c1", "name": "Implement metrics", "boardId": "b1"},
            "board": {"id": "b1", "name": "Roadmap"},
        },
    }
    evt = parse_webhook(
        "commentCreate", payload, provider_name="planka", base_url=_BASE_URL,
    )
    assert evt is not None
    assert evt.event_type is EventType.CARD_COMMENTED
    assert evt.collection_id == "b1"
    assert evt.extra["comment_text"] == "LGTM, ship it."
    assert evt.extra["card_id"] == "c1"
    assert evt.extra["card_url"] == f"{_BASE_URL}/cards/c1"


def test_task_completion_emits_only_on_transition() -> None:
    """Task updates should only produce TASK_COMPLETED when the task flips
    from incomplete to complete — toggling the description or other fields
    on a task that was already complete must NOT spam notifications."""
    completing = {
        "user": _user(),
        "beforeUpdate": {"isCompleted": False},
        "item": {"id": "t1", "name": "Step 1", "isCompleted": True, "cardId": "c1"},
        "included": {
            "card": {"id": "c1", "name": "Implement metrics", "boardId": "b1"},
            "board": {"id": "b1", "name": "Roadmap"},
        },
    }
    evt = parse_webhook("taskUpdate", completing, provider_name="planka", base_url=_BASE_URL)
    assert evt is not None
    assert evt.event_type is EventType.TASK_COMPLETED

    # Editing a task that was already completed -> no event.
    re_edit = {
        "user": _user(),
        "beforeUpdate": {"isCompleted": True},
        "item": {"id": "t1", "name": "Step 1 v2", "isCompleted": True, "cardId": "c1"},
    }
    assert parse_webhook("taskUpdate", re_edit, provider_name="planka", base_url=_BASE_URL) is None


def test_unknown_event_returns_none() -> None:
    assert parse_webhook("nonexistent", {"item": {}}, provider_name="planka", base_url="") is None
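The two transition rules described in the Planka docstrings above (a list change means a move; a task completes only on a False-to-True flip) can be expressed as small predicates. This is a sketch under the assumption that payloads carry `beforeUpdate` and `item` dicts as in the tests; `is_card_move` and `is_task_completion` are hypothetical names, not functions from the project.

```python
from __future__ import annotations


def is_card_move(payload: dict) -> bool:
    """True when beforeUpdate.listId is present and differs from item.listId."""
    before = payload.get("beforeUpdate") or {}
    item = payload.get("item") or {}
    old_list = before.get("listId")
    # A cardUpdate that did not touch listId has no old list to compare.
    return old_list is not None and old_list != item.get("listId")


def is_task_completion(payload: dict) -> bool:
    """True only when isCompleted flips from False to True."""
    before = payload.get("beforeUpdate") or {}
    item = payload.get("item") or {}
    return before.get("isCompleted") is False and item.get("isCompleted") is True


# Payload fragments shaped like the ones in the tests above.
move = {"beforeUpdate": {"listId": "l1"}, "item": {"listId": "l2"}}
rename = {"beforeUpdate": {"name": "Old name"}, "item": {"name": "New name", "listId": "l1"}}
completing = {"beforeUpdate": {"isCompleted": False}, "item": {"isCompleted": True}}
re_edit = {"beforeUpdate": {"isCompleted": True}, "item": {"isCompleted": True}}
```

Comparing against the previous state, rather than looking only at the current item, is what keeps edits to an already completed task from producing repeat notifications.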