feat: deferred dispatch, release-check provider, settings polish
- Defer quiet-hours dispatches into new deferred_dispatch table; drain job + periodic catch-up scan re-fire at window end with coalescing on (link, event_type, collection_id). - Add ON DELETE SET NULL migration on event_log_id and partial unique index on (link_id, collection_id, event_type) WHERE status='pending'. - Add release-check provider abstraction (Gitea/GitHub) with SSRF-safe URL validation, settings UI cassette, and scheduled polling. - Replace importlib-only version lookup with version.py helper that prefers the higher of installed metadata vs source pyproject so stale editable dev installs stop misreporting. - Aurora frontend polish: MetaStrip component, ReleaseCassette, EventDetailModal expansion, and i18n additions.
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""Upstream release-check providers.
|
||||
|
||||
This package is intentionally separate from :mod:`notify_bridge_core.providers`:
|
||||
|
||||
* service providers are user-configured entities persisted per-tenant in the DB;
|
||||
* release providers are admin-level upstream-version probes selected by setting,
|
||||
with at most one active provider per installation.
|
||||
|
||||
Mixing them in one enum/factory bled responsibilities and complicated future
|
||||
additions (e.g. a GitHub release provider that has nothing to do with Gitea
|
||||
service integrations).
|
||||
"""
|
||||
|
||||
from .base import (
|
||||
ReleaseErrorCode,
|
||||
ReleaseInfo,
|
||||
ReleaseProvider,
|
||||
ReleaseProviderKind,
|
||||
ReleaseTestResult,
|
||||
is_valid_repo,
|
||||
)
|
||||
from .registry import build_release_provider
|
||||
|
||||
__all__ = [
|
||||
"ReleaseErrorCode",
|
||||
"ReleaseInfo",
|
||||
"ReleaseProvider",
|
||||
"ReleaseProviderKind",
|
||||
"ReleaseTestResult",
|
||||
"build_release_provider",
|
||||
"is_valid_repo",
|
||||
]
|
||||
@@ -0,0 +1,156 @@
|
||||
"""ReleaseProvider abstraction and shared tag/version utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Protocol, TypedDict, runtime_checkable
|
||||
|
||||
|
||||
class ReleaseProviderKind(str, Enum):
|
||||
"""Supported upstream release-check providers."""
|
||||
|
||||
DISABLED = "disabled"
|
||||
GITEA = "gitea"
|
||||
GITHUB = "github"
|
||||
|
||||
|
||||
# Single source of truth for `release_error` taxonomy. Surfaced into the cached
|
||||
# `AppSetting`, returned via the API, and translated by the frontend.
|
||||
class ReleaseErrorCode(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
MISCONFIGURED = "misconfigured"
|
||||
PROVIDER_CHANGED = "provider_changed"
|
||||
NO_RELEASE_FOUND = "no_release_found"
|
||||
NETWORK_ERROR = "network_error"
|
||||
HTTP_ERROR = "http_error"
|
||||
PARSE_ERROR = "parse_error"
|
||||
UNSAFE_URL = "unsafe_url"
|
||||
NOT_IMPLEMENTED = "not_implemented"
|
||||
UNKNOWN_ERROR = "unknown_error"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReleaseInfo:
|
||||
"""Normalised release metadata returned by a provider."""
|
||||
|
||||
tag: str
|
||||
version: str
|
||||
name: str | None = None
|
||||
body: str | None = None
|
||||
url: str | None = None
|
||||
published_at: str | None = None
|
||||
prerelease: bool = False
|
||||
draft: bool = False
|
||||
|
||||
|
||||
class ReleaseTestResult(TypedDict):
|
||||
"""Structured shape returned by :meth:`ReleaseProvider.test`."""
|
||||
|
||||
ok: bool
|
||||
info: ReleaseInfo | None
|
||||
error: str | None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ReleaseProvider(Protocol):
|
||||
"""Protocol implemented by every release provider.
|
||||
|
||||
Implementations are expected to be safe to instantiate without external
|
||||
side effects — connectivity is deferred until :meth:`fetch_latest` or
|
||||
:meth:`test` is awaited.
|
||||
"""
|
||||
|
||||
kind: ClassVar[ReleaseProviderKind]
|
||||
|
||||
async def fetch_latest(self, *, include_prereleases: bool = False) -> ReleaseInfo | None:
|
||||
"""Return the latest release, or ``None`` if there is nothing to report."""
|
||||
|
||||
async def test(self) -> ReleaseTestResult:
|
||||
"""Probe the upstream and return a structured status payload."""
|
||||
|
||||
|
||||
# Owner/name validation — matches Gitea/GitHub's allowed identifier chars.
|
||||
_REPO_RE = re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$")
|
||||
|
||||
|
||||
def is_valid_repo(repo: str) -> bool:
|
||||
"""``True`` when ``repo`` is a safe ``owner/name`` string (no path traversal)."""
|
||||
|
||||
return bool(repo) and _REPO_RE.match(repo) is not None
|
||||
|
||||
|
||||
_TAG_NUMERIC = re.compile(r"\d+")
|
||||
# Stop reading numeric segments at the first non-digit-non-dot character so
|
||||
# ``1.0a2`` doesn't get parsed as ``(1, 0, 2)``.
|
||||
_HEAD_SPLIT = re.compile(r"[^0-9.]")
|
||||
|
||||
|
||||
def normalise_version(tag: str) -> str:
|
||||
"""Strip a leading ``v`` from a tag (``"v1.2.3"`` → ``"1.2.3"``)."""
|
||||
|
||||
if not tag:
|
||||
return ""
|
||||
cleaned = tag.strip()
|
||||
if cleaned.startswith(("v", "V")) and len(cleaned) > 1 and cleaned[1].isdigit():
|
||||
cleaned = cleaned[1:]
|
||||
return cleaned
|
||||
|
||||
|
||||
def _split_version(version: str) -> tuple[tuple[int, ...], str]:
|
||||
"""Split a version into (numeric segments, prerelease suffix).
|
||||
|
||||
A non-empty prerelease suffix marks the version as pre-stable. We use it
|
||||
as a tie-break only — when numeric segments are equal a stable build
|
||||
sorts strictly newer than its pre-release counterpart (``0.7.2`` >
|
||||
``0.7.2-rc1``), which prevents the badge from flickering between
|
||||
"up to date" and "downgrade available" on installs that ship the GA.
|
||||
"""
|
||||
|
||||
if not version:
|
||||
return (), ""
|
||||
work = version.split("+", 1)[0]
|
||||
if "-" in work:
|
||||
head, _, suffix = work.partition("-")
|
||||
else:
|
||||
# Implicit prerelease form: ``1.0a2`` / ``1.0rc1``. Anything after the
|
||||
# first non-digit-non-dot is treated as the suffix.
|
||||
m = _HEAD_SPLIT.search(work)
|
||||
if m and m.start() > 0:
|
||||
head, suffix = work[: m.start()], work[m.start():]
|
||||
else:
|
||||
head, suffix = work, ""
|
||||
segments = tuple(int(n) for n in _TAG_NUMERIC.findall(head))
|
||||
return segments, suffix.strip()
|
||||
|
||||
|
||||
def compare_versions(a: str, b: str) -> int:
|
||||
"""Return ``1`` if ``a > b``, ``-1`` if ``a < b``, ``0`` if equal.
|
||||
|
||||
Numeric segments win. When numerically equal, *stable* (no suffix) beats
|
||||
*prerelease* (any non-empty suffix); two equally-prereleased versions
|
||||
compare equal — we deliberately do not order ``rc2`` over ``rc1`` because
|
||||
that requires real semver parsing and would only matter for downgrades.
|
||||
"""
|
||||
|
||||
sa, suffix_a = _split_version(normalise_version(a))
|
||||
sb, suffix_b = _split_version(normalise_version(b))
|
||||
length = max(len(sa), len(sb))
|
||||
for i in range(length):
|
||||
x = sa[i] if i < len(sa) else 0
|
||||
y = sb[i] if i < len(sb) else 0
|
||||
if x != y:
|
||||
return 1 if x > y else -1
|
||||
# Equal numerics — stable beats prerelease.
|
||||
if not suffix_a and suffix_b:
|
||||
return 1
|
||||
if suffix_a and not suffix_b:
|
||||
return -1
|
||||
return 0
|
||||
|
||||
|
||||
def is_newer(candidate: str, baseline: str) -> bool:
|
||||
"""``True`` when ``candidate`` is strictly newer than ``baseline``."""
|
||||
|
||||
return compare_versions(candidate, baseline) > 0
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Gitea release provider — queries ``/api/v1/repos/{owner}/{repo}/releases``."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import ClassVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..notifications.ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
from .base import (
|
||||
ReleaseErrorCode,
|
||||
ReleaseInfo,
|
||||
ReleaseProviderKind,
|
||||
ReleaseTestResult,
|
||||
is_valid_repo,
|
||||
normalise_version,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Cap upstream response body — release lists are normally a few KB; anything
|
||||
# beyond this is either a misconfigured target or a malicious payload.
|
||||
_MAX_BODY_BYTES = 1_000_000
|
||||
|
||||
|
||||
class GiteaReleaseProvider:
|
||||
"""Anonymous Gitea release probe.
|
||||
|
||||
Hits the ``releases`` endpoint (not ``releases/latest``) because the latter
|
||||
skips pre-releases unconditionally — we want to honour the caller's
|
||||
``include_prereleases`` flag instead of relying on Gitea's filtering.
|
||||
"""
|
||||
|
||||
kind: ClassVar[ReleaseProviderKind] = ReleaseProviderKind.GITEA
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, url: str, repo: str) -> None:
|
||||
if not url:
|
||||
raise ValueError("Gitea release provider requires a base URL")
|
||||
if not is_valid_repo(repo):
|
||||
raise ValueError(
|
||||
"Gitea release provider requires repo as 'owner/name' "
|
||||
"(alphanumerics, dot, dash, underscore only)"
|
||||
)
|
||||
self._session = session
|
||||
self._url = url.rstrip("/")
|
||||
self._repo = repo.strip("/")
|
||||
|
||||
@property
|
||||
def _endpoint(self) -> str:
|
||||
return f"{self._url}/api/v1/repos/{self._repo}/releases"
|
||||
|
||||
async def fetch_latest(self, *, include_prereleases: bool = False) -> ReleaseInfo | None:
|
||||
try:
|
||||
await avalidate_outbound_url(self._endpoint)
|
||||
except UnsafeURLError as err:
|
||||
_LOGGER.warning("Gitea release URL rejected by SSRF guard: %s", err)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with self._session.get(
|
||||
self._endpoint,
|
||||
params={"limit": "20", "page": "1", "draft": "false"},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
_LOGGER.warning(
|
||||
"Gitea releases fetch failed: HTTP %s for %s",
|
||||
response.status, self._endpoint,
|
||||
)
|
||||
return None
|
||||
# Enforce a size cap without trusting chunked encoding: read
|
||||
# the whole body (aiohttp buffers it) but reject anything that
|
||||
# advertised more than the cap up front, and bail if it grew
|
||||
# past the cap after the fact.
|
||||
if response.content_length is not None and response.content_length > _MAX_BODY_BYTES:
|
||||
_LOGGER.warning(
|
||||
"Gitea releases response advertised %d bytes — refusing",
|
||||
response.content_length,
|
||||
)
|
||||
return None
|
||||
raw = await response.read()
|
||||
if len(raw) > _MAX_BODY_BYTES:
|
||||
_LOGGER.warning(
|
||||
"Gitea releases response exceeded %d bytes — refusing to parse",
|
||||
_MAX_BODY_BYTES,
|
||||
)
|
||||
return None
|
||||
import json
|
||||
|
||||
payload = json.loads(raw.decode("utf-8"))
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
|
||||
_LOGGER.warning("Gitea releases fetch error: %s", err)
|
||||
return None
|
||||
except (ValueError, UnicodeDecodeError) as err:
|
||||
_LOGGER.warning("Gitea releases parse error: %s", err)
|
||||
return None
|
||||
|
||||
if not isinstance(payload, list):
|
||||
return None
|
||||
|
||||
for entry in payload:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("draft"):
|
||||
continue
|
||||
if entry.get("prerelease") and not include_prereleases:
|
||||
continue
|
||||
return _to_release_info(entry)
|
||||
return None
|
||||
|
||||
async def test(self) -> ReleaseTestResult:
|
||||
# Validate URL first so the "test" button surfaces an SSRF rejection
|
||||
# to the operator rather than silently returning "unreachable".
|
||||
try:
|
||||
await avalidate_outbound_url(self._endpoint)
|
||||
except UnsafeURLError:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.UNSAFE_URL.value}
|
||||
|
||||
try:
|
||||
async with self._session.get(
|
||||
self._endpoint,
|
||||
params={"limit": "1", "page": "1", "draft": "false"},
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.HTTP_ERROR.value}
|
||||
# Enforce a size cap without trusting chunked encoding: read
|
||||
# the whole body (aiohttp buffers it) but reject anything that
|
||||
# advertised more than the cap up front, and bail if it grew
|
||||
# past the cap after the fact.
|
||||
if response.content_length is not None and response.content_length > _MAX_BODY_BYTES:
|
||||
_LOGGER.warning(
|
||||
"Gitea releases response advertised %d bytes — refusing",
|
||||
response.content_length,
|
||||
)
|
||||
return None
|
||||
raw = await response.read()
|
||||
if len(raw) > _MAX_BODY_BYTES:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.PARSE_ERROR.value}
|
||||
import json
|
||||
|
||||
payload = json.loads(raw.decode("utf-8"))
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError):
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.NETWORK_ERROR.value}
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.PARSE_ERROR.value}
|
||||
|
||||
if not isinstance(payload, list) or not payload:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.NO_RELEASE_FOUND.value}
|
||||
first = payload[0]
|
||||
if not isinstance(first, dict):
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.PARSE_ERROR.value}
|
||||
return {"ok": True, "info": _to_release_info(first), "error": None}
|
||||
|
||||
|
||||
def _to_release_info(entry: dict) -> ReleaseInfo:
|
||||
tag = str(entry.get("tag_name") or "").strip()
|
||||
return ReleaseInfo(
|
||||
tag=tag,
|
||||
version=normalise_version(tag),
|
||||
name=entry.get("name") or None,
|
||||
body=entry.get("body") or None,
|
||||
url=entry.get("html_url") or None,
|
||||
published_at=entry.get("published_at") or entry.get("created_at") or None,
|
||||
prerelease=bool(entry.get("prerelease", False)),
|
||||
draft=bool(entry.get("draft", False)),
|
||||
)
|
||||
@@ -0,0 +1,34 @@
|
||||
"""GitHub release provider stub.
|
||||
|
||||
Reserved so the registry advertises the option and the frontend can render the
|
||||
provider toggle without a follow-up backend release. The full implementation
|
||||
will mirror :class:`GiteaReleaseProvider` against
|
||||
``api.github.com/repos/{owner}/{repo}/releases``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .base import ReleaseErrorCode, ReleaseInfo, ReleaseProviderKind, ReleaseTestResult
|
||||
|
||||
|
||||
class GitHubReleaseProvider:
|
||||
"""Not yet implemented — placeholder so the registry is forward-compatible."""
|
||||
|
||||
kind: ClassVar[ReleaseProviderKind] = ReleaseProviderKind.GITHUB
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, repo: str) -> None:
|
||||
self._session = session
|
||||
self._repo = repo
|
||||
|
||||
async def fetch_latest(self, *, include_prereleases: bool = False) -> ReleaseInfo | None:
|
||||
# Soft-fail rather than raise — `run_check` already catches
|
||||
# NotImplementedError but a None return keeps the persisted
|
||||
# `release_error` taxonomy clean (NOT_IMPLEMENTED, not "not impl…").
|
||||
return None
|
||||
|
||||
async def test(self) -> ReleaseTestResult:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.NOT_IMPLEMENTED.value}
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Factory for release providers — single entry point for callers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .base import ReleaseProvider, ReleaseProviderKind, is_valid_repo
|
||||
from .gitea import GiteaReleaseProvider
|
||||
from .github import GitHubReleaseProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
|
||||
def build_release_provider(
|
||||
kind: str | ReleaseProviderKind,
|
||||
*,
|
||||
session: aiohttp.ClientSession,
|
||||
url: str = "",
|
||||
repo: str = "",
|
||||
) -> ReleaseProvider | None:
|
||||
"""Build a release provider for the given kind.
|
||||
|
||||
Returns ``None`` when disabled or when required configuration is missing
|
||||
or unsafe (invalid repo format, empty URL) — callers treat the absence as
|
||||
"no checks performed" without branching on the kind string everywhere.
|
||||
"""
|
||||
|
||||
try:
|
||||
normalised = (
|
||||
ReleaseProviderKind(kind)
|
||||
if not isinstance(kind, ReleaseProviderKind)
|
||||
else kind
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if normalised is ReleaseProviderKind.DISABLED:
|
||||
return None
|
||||
if normalised is ReleaseProviderKind.GITEA:
|
||||
if not url or not is_valid_repo(repo):
|
||||
return None
|
||||
try:
|
||||
return GiteaReleaseProvider(session=session, url=url, repo=repo)
|
||||
except ValueError:
|
||||
return None
|
||||
if normalised is ReleaseProviderKind.GITHUB:
|
||||
if not is_valid_repo(repo):
|
||||
return None
|
||||
return GitHubReleaseProvider(session=session, repo=repo)
|
||||
return None
|
||||
@@ -2,13 +2,18 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.notifications.ssrf import UnsafeURLError, avalidate_outbound_url
|
||||
from notify_bridge_core.release import ReleaseProviderKind, is_valid_repo
|
||||
|
||||
from ..auth.dependencies import get_current_user, require_admin
|
||||
from ..auth.routes import limiter # shared SlowAPI instance (app.state.limiter)
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import AppSetting, TelegramBot, User
|
||||
|
||||
@@ -28,6 +33,12 @@ _SETTING_KEYS = {
|
||||
"log_level": "NOTIFY_BRIDGE_LOG_LEVEL", # DEBUG/INFO/WARNING/ERROR
|
||||
"log_format": "NOTIFY_BRIDGE_LOG_FORMAT", # text|json (requires restart to switch)
|
||||
"log_levels": "NOTIFY_BRIDGE_LOG_LEVELS", # module=LEVEL,module2=LEVEL
|
||||
# Release-check — see services/release_check.py for the cached-state keys.
|
||||
"release_provider_kind": "NOTIFY_BRIDGE_RELEASE_PROVIDER", # disabled|gitea|github
|
||||
"release_provider_url": "NOTIFY_BRIDGE_RELEASE_PROVIDER_URL",
|
||||
"release_provider_repo": "NOTIFY_BRIDGE_RELEASE_PROVIDER_REPO",
|
||||
"release_include_prereleases": None, # "0"|"1"
|
||||
"release_check_interval_hours": None, # 1..168
|
||||
}
|
||||
|
||||
_DEFAULTS = {
|
||||
@@ -42,6 +53,13 @@ _DEFAULTS = {
|
||||
"log_level": "INFO",
|
||||
"log_format": "text",
|
||||
"log_levels": "",
|
||||
# Pre-seed Gitea release checks against this repo's own upstream so a fresh
|
||||
# install knows where to look without operator intervention.
|
||||
"release_provider_kind": "gitea",
|
||||
"release_provider_url": "https://git.dolgolyov-family.by",
|
||||
"release_provider_repo": "alexei.dolgolyov/notify-bridge",
|
||||
"release_include_prereleases": "0",
|
||||
"release_check_interval_hours": "12",
|
||||
}
|
||||
|
||||
# Settings whose changes require dropping in-memory Telegram caches so the
|
||||
@@ -53,6 +71,17 @@ _CACHE_SETTING_KEYS = {"telegram_cache_ttl_hours", "telegram_asset_cache_max_ent
|
||||
# changing it means swapping the handler formatter entirely.
|
||||
_LOG_SETTING_KEYS = {"log_level", "log_levels", "log_format"}
|
||||
|
||||
# Release-check settings whose change must trigger cache invalidation (so a
|
||||
# stale "latest version" doesn't linger after pointing at a new repo) and a
|
||||
# scheduler re-arm so the new interval/provider takes effect immediately.
|
||||
_RELEASE_PROVIDER_KEYS = {
|
||||
"release_provider_kind",
|
||||
"release_provider_url",
|
||||
"release_provider_repo",
|
||||
"release_include_prereleases",
|
||||
}
|
||||
_RELEASE_INTERVAL_KEY = "release_check_interval_hours"
|
||||
|
||||
|
||||
async def get_setting(session: AsyncSession, key: str) -> str:
|
||||
"""Read a setting from DB, falling back to env var then default."""
|
||||
@@ -81,6 +110,11 @@ class SettingsUpdate(BaseModel):
|
||||
log_level: str | None = None
|
||||
log_format: str | None = None
|
||||
log_levels: str | None = None
|
||||
release_provider_kind: str | None = None
|
||||
release_provider_url: str | None = None
|
||||
release_provider_repo: str | None = None
|
||||
release_include_prereleases: bool | int | str | None = None
|
||||
release_check_interval_hours: int | str | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
@@ -111,12 +145,65 @@ async def update_settings(
|
||||
old_cache_values = {k: await get_setting(session, k) for k in _CACHE_SETTING_KEYS}
|
||||
old_timezone = await get_setting(session, "timezone")
|
||||
old_log_values = {k: await get_setting(session, k) for k in _LOG_SETTING_KEYS}
|
||||
old_release_values = {k: await get_setting(session, k) for k in _RELEASE_PROVIDER_KEYS}
|
||||
old_release_interval = await get_setting(session, _RELEASE_INTERVAL_KEY)
|
||||
|
||||
for key in _SETTING_KEYS:
|
||||
value = getattr(body, key, None)
|
||||
if value is None:
|
||||
continue
|
||||
value_str = str(value)
|
||||
# Normalise per-key before storing so the cache keys always hold the
|
||||
# canonical wire format ("0"/"1" for bool flags, clamped int for the
|
||||
# release interval). Without this, str(True) would leak "True" into the
|
||||
# release_include_prereleases cell and silently disable filtering.
|
||||
if key == "release_include_prereleases":
|
||||
if isinstance(value, bool):
|
||||
value_str = "1" if value else "0"
|
||||
else:
|
||||
value_str = "1" if str(value).strip().lower() in ("1", "true", "yes", "on") else "0"
|
||||
elif key == "release_check_interval_hours":
|
||||
from ..services.release_check import parse_interval_hours
|
||||
value_str = str(parse_interval_hours(str(value)))
|
||||
elif key == "release_provider_kind":
|
||||
# Reject anything outside the enum so a typo doesn't leave the DB
|
||||
# in a state the service can't interpret.
|
||||
value_str = str(value).strip().lower()
|
||||
try:
|
||||
value_str = ReleaseProviderKind(value_str).value
|
||||
except ValueError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid release_provider_kind: {value_str!r}",
|
||||
) from err
|
||||
elif key == "release_provider_url":
|
||||
value_str = str(value).strip()
|
||||
if value_str:
|
||||
# Reject embedded userinfo (http://user:pass@host) so the
|
||||
# GET /settings response can never echo credentials back, and
|
||||
# block private/loopback/metadata targets via the SSRF guard.
|
||||
parsed = urlparse(value_str)
|
||||
if parsed.username or parsed.password:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="release_provider_url must not contain credentials",
|
||||
)
|
||||
try:
|
||||
await avalidate_outbound_url(value_str)
|
||||
except UnsafeURLError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid release_provider_url: {err}",
|
||||
) from err
|
||||
elif key == "release_provider_repo":
|
||||
value_str = str(value).strip()
|
||||
if value_str and not is_valid_repo(value_str):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="release_provider_repo must match 'owner/name' "
|
||||
"(alphanumerics, dot, dash, underscore only)",
|
||||
)
|
||||
else:
|
||||
value_str = str(value)
|
||||
# GET masks the webhook secret as "***<last4>" so the real value is
|
||||
# never exposed to the frontend. If the client sends the mask back
|
||||
# (which happens on every save, since bind:value holds whatever GET
|
||||
@@ -182,6 +269,27 @@ async def update_settings(
|
||||
if new_base_url and (new_base_url != old_base_url or new_secret != old_secret):
|
||||
await _reregister_webhooks(session, new_base_url, new_secret)
|
||||
|
||||
# Release-check: clear stale cache when the provider repo/url/kind changes,
|
||||
# and re-arm the periodic job whenever the interval or provider changes.
|
||||
new_release_values = {k: await get_setting(session, k) for k in _RELEASE_PROVIDER_KEYS}
|
||||
new_release_interval = await get_setting(session, _RELEASE_INTERVAL_KEY)
|
||||
release_provider_changed = new_release_values != old_release_values
|
||||
release_interval_changed = new_release_interval != old_release_interval
|
||||
if release_provider_changed:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from notify_bridge_core.release import ReleaseErrorCode
|
||||
|
||||
from ..services.release_check import persist_release_state
|
||||
await persist_release_state(
|
||||
checked_at=datetime.now(timezone.utc).isoformat(),
|
||||
error=ReleaseErrorCode.PROVIDER_CHANGED.value,
|
||||
info=None,
|
||||
)
|
||||
if release_provider_changed or release_interval_changed:
|
||||
from ..services.scheduler import reschedule_release_check
|
||||
await reschedule_release_check()
|
||||
|
||||
result = {}
|
||||
for key in _SETTING_KEYS:
|
||||
result[key] = await get_setting(session, key)
|
||||
@@ -231,6 +339,122 @@ async def get_external_url(
|
||||
return {"external_url": (await get_setting(session, "external_url")).rstrip("/")}
|
||||
|
||||
|
||||
def _status_payload(status, *, is_admin: bool) -> dict:
|
||||
"""Serialise a :class:`ReleaseStatus` for the API.
|
||||
|
||||
Non-admin payloads strip the upstream release body (an XSS landmine —
|
||||
arbitrary attacker-controlled markdown should never reach a non-admin
|
||||
UI unless we explicitly sanitise it for display) and replace the raw
|
||||
error string with a coarse ``error`` / ``ok`` marker so internal
|
||||
hostnames from probe failures can't leak via the badge.
|
||||
"""
|
||||
payload = {
|
||||
"provider": status.provider,
|
||||
"current": status.current,
|
||||
"latest": status.latest,
|
||||
"latest_tag": status.latest_tag,
|
||||
"latest_url": status.latest_url,
|
||||
"latest_name": status.latest_name,
|
||||
"latest_published_at": status.latest_published_at,
|
||||
"latest_prerelease": status.latest_prerelease,
|
||||
"checked_at": status.checked_at,
|
||||
"update_available": status.update_available,
|
||||
}
|
||||
if is_admin:
|
||||
payload["latest_body"] = status.latest_body
|
||||
payload["error"] = status.error
|
||||
else:
|
||||
payload["latest_body"] = None
|
||||
payload["error"] = None if not status.error else "error"
|
||||
return payload
|
||||
|
||||
|
||||
@router.get("/release")
|
||||
async def get_release_status(
|
||||
user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return the cached upstream release status (no network call).
|
||||
|
||||
Available to all authenticated users so the sidebar badge can render for
|
||||
everyone — admins manage the configuration but the awareness is global.
|
||||
"""
|
||||
from ..services.release_check import load_status
|
||||
return _status_payload(await load_status(), is_admin=(user.role == "admin"))
|
||||
|
||||
|
||||
@router.post("/release/check")
|
||||
@limiter.limit("6/minute")
|
||||
async def force_release_check(
|
||||
request: Request,
|
||||
user: User = Depends(require_admin),
|
||||
):
|
||||
"""Force an immediate upstream check and return the refreshed status."""
|
||||
from ..services.release_check import run_check
|
||||
status = await run_check(force=True)
|
||||
return _status_payload(status, is_admin=True)
|
||||
|
||||
|
||||
class ReleaseTestRequest(BaseModel):
|
||||
provider_kind: str
|
||||
provider_url: str | None = None
|
||||
provider_repo: str | None = None
|
||||
include_prereleases: bool | None = False
|
||||
|
||||
|
||||
@router.post("/release/test")
|
||||
@limiter.limit("12/minute")
|
||||
async def test_release_provider(
|
||||
request: Request,
|
||||
body: ReleaseTestRequest,
|
||||
user: User = Depends(require_admin),
|
||||
):
|
||||
"""Dry-run an arbitrary provider config — used by the cassette's Test button.
|
||||
|
||||
Validates the provider URL on the spot (SSRF + userinfo) so the operator
|
||||
sees an actionable error before any outbound request fires.
|
||||
"""
|
||||
from notify_bridge_core.release import ReleaseErrorCode, build_release_provider
|
||||
|
||||
from ..services.http_session import get_http_session
|
||||
|
||||
test_url = (body.provider_url or "").strip()
|
||||
test_repo = (body.provider_repo or "").strip()
|
||||
|
||||
if test_repo and not is_valid_repo(test_repo):
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.MISCONFIGURED.value}
|
||||
if test_url:
|
||||
parsed = urlparse(test_url)
|
||||
if parsed.username or parsed.password:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.UNSAFE_URL.value}
|
||||
try:
|
||||
await avalidate_outbound_url(test_url)
|
||||
except UnsafeURLError:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.UNSAFE_URL.value}
|
||||
|
||||
http = await get_http_session()
|
||||
provider = build_release_provider(
|
||||
body.provider_kind,
|
||||
session=http,
|
||||
url=test_url,
|
||||
repo=test_repo,
|
||||
)
|
||||
if provider is None:
|
||||
return {"ok": False, "info": None, "error": ReleaseErrorCode.MISCONFIGURED.value}
|
||||
result = await provider.test()
|
||||
info = result.get("info")
|
||||
info_dict = None
|
||||
if info is not None:
|
||||
info_dict = {
|
||||
"tag": info.tag,
|
||||
"version": info.version,
|
||||
"name": info.name,
|
||||
"url": info.url,
|
||||
"published_at": info.published_at,
|
||||
"prerelease": info.prerelease,
|
||||
}
|
||||
return {"ok": result["ok"], "info": info_dict, "error": result.get("error")}
|
||||
|
||||
|
||||
async def _reregister_webhooks(
|
||||
session: AsyncSession, base_url: str, secret: str
|
||||
) -> None:
|
||||
|
||||
@@ -28,8 +28,9 @@ from ..database.models import (
|
||||
WebhookPayloadLog,
|
||||
)
|
||||
from ..services.dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
event_allowed_by_config,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
@@ -164,7 +165,16 @@ async def _dispatch_webhook_event(
|
||||
Number of successfully dispatched notifications.
|
||||
"""
|
||||
dispatched = 0
|
||||
# ``defers_to_schedule`` is collected during the loop and flushed AFTER the
|
||||
# main session commits — the only side-effect of failing to schedule is a
|
||||
# delayed delivery (the startup loader / catch-up scan will reschedule),
|
||||
# so this is best-effort and must not roll back the DB writes.
|
||||
defers_to_schedule: set[Any] = set()
|
||||
async with AsyncSession(engine) as session:
|
||||
# App timezone is identical across trackers within one webhook request;
|
||||
# pull it once.
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
tracker_result = await session.exec(
|
||||
select(NotificationTracker).where(
|
||||
NotificationTracker.provider_id == provider_id,
|
||||
@@ -173,6 +183,8 @@ async def _dispatch_webhook_event(
|
||||
)
|
||||
trackers = tracker_result.all()
|
||||
|
||||
from ..services.deferred_dispatch import defer_event, is_deferrable
|
||||
|
||||
for tracker in trackers:
|
||||
filters = tracker.filters or {}
|
||||
if not _passes_filters(event, filters):
|
||||
@@ -185,11 +197,9 @@ async def _dispatch_webhook_event(
|
||||
if not link_data:
|
||||
continue
|
||||
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
# Log event
|
||||
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
|
||||
session.add(EventLog(
|
||||
event_log_row = EventLog(
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
tracker_name=tracker.name,
|
||||
@@ -203,18 +213,90 @@ async def _dispatch_webhook_event(
|
||||
"provider_type": event.provider_type.value,
|
||||
**extra_details,
|
||||
},
|
||||
))
|
||||
)
|
||||
session.add(event_log_row)
|
||||
await session.flush()
|
||||
event_log_id = event_log_row.id
|
||||
|
||||
# Dispatch to targets
|
||||
# Dedupe defers by parent ``link_id``: broadcast links emit one
|
||||
# ``link_data`` entry per child, all sharing the same parent id —
|
||||
# the deferred row is one-per-link, so we only call ``defer_event``
|
||||
# once per distinct id (earliest fire_at wins on ties).
|
||||
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
|
||||
defers_for_event: dict[int, Any] = {}
|
||||
for ld in link_data:
|
||||
tc = ld["tracking_config"]
|
||||
if tc is not None:
|
||||
outcome = evaluate_event_gate(event, tc, app_tz)
|
||||
if outcome.reason is GateReason.QUIET_HOURS:
|
||||
if is_deferrable(event.event_type.value) and outcome.quiet_hours_end_at is not None:
|
||||
link_id = ld.get("link_id")
|
||||
if link_id is not None:
|
||||
prior = defers_for_event.get(link_id)
|
||||
if prior is None or outcome.quiet_hours_end_at < prior:
|
||||
defers_for_event[link_id] = outcome.quiet_hours_end_at
|
||||
continue
|
||||
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
|
||||
continue
|
||||
|
||||
tmpl = ld["template_config"]
|
||||
target_cfg = TargetConfig(
|
||||
type=ld["target_type"],
|
||||
config=ld["target_config"],
|
||||
template_slots=ld["template_slots"],
|
||||
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
|
||||
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
|
||||
provider_api_key=provider_config.get("api_token"),
|
||||
provider_internal_url=provider_config.get("url", ""),
|
||||
provider_external_url=provider_config.get("url", ""),
|
||||
receivers=ld["receivers"],
|
||||
)
|
||||
key = id(tc) if tc is not None else 0
|
||||
if key not in groups:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
|
||||
# Persist defers + stamp event_log dispatch_status in the same
|
||||
# session that holds the EventLog row, so the "deferred" badge
|
||||
# only appears if the underlying queue rows actually exist.
|
||||
if defers_for_event:
|
||||
earliest = min(defers_for_event.values())
|
||||
for link_id, fire_at in defers_for_event.items():
|
||||
await defer_event(
|
||||
session,
|
||||
event=event,
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
fire_at=fire_at,
|
||||
)
|
||||
details = dict(event_log_row.details or {})
|
||||
if not details.get("dispatch_status"):
|
||||
details["dispatch_status"] = "deferred"
|
||||
details["deferred_until"] = earliest.isoformat()
|
||||
event_log_row.details = details
|
||||
session.add(event_log_row)
|
||||
defers_to_schedule.update(defers_for_event.values())
|
||||
|
||||
# Dispatch to targets. Isolate dispatcher exceptions per group so
|
||||
# a failed remote call doesn't bubble out, abort the surrounding
|
||||
# transaction, and roll back the just-written defers/event_log.
|
||||
from ..services.http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
for tc, target_configs in _build_target_groups(event, link_data, provider_config, app_tz):
|
||||
for tc, target_configs in groups.values():
|
||||
if not target_configs:
|
||||
continue
|
||||
shaped_event = apply_tracking_display_filters(event, tc)
|
||||
if shaped_event is None:
|
||||
continue
|
||||
results = await dispatcher.dispatch(shaped_event, target_configs)
|
||||
try:
|
||||
results = await dispatcher.dispatch(shaped_event, target_configs)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Dispatcher raised for tracker %d: %s", tracker.id, err,
|
||||
)
|
||||
continue
|
||||
for r in results:
|
||||
if r.get("success"):
|
||||
dispatched += 1
|
||||
@@ -226,6 +308,18 @@ async def _dispatch_webhook_event(
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
|
||||
# can't roll back the persisted defer rows.
|
||||
if defers_to_schedule:
|
||||
from ..services.scheduler import schedule_deferred_drain
|
||||
for fire_at in defers_to_schedule:
|
||||
try:
|
||||
schedule_deferred_drain(fire_at)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to schedule deferred drain for %s", fire_at,
|
||||
)
|
||||
|
||||
return dispatched
|
||||
|
||||
|
||||
@@ -554,41 +648,3 @@ async def generic_webhook(token: str, request: Request):
|
||||
await log_session.commit()
|
||||
|
||||
return {"ok": True, "dispatched": dispatched}
|
||||
|
||||
|
||||
def _build_target_groups(
|
||||
event: ServiceEvent,
|
||||
link_data: list[dict[str, Any]],
|
||||
provider_config: dict[str, Any],
|
||||
app_tz: str = "UTC",
|
||||
) -> list[tuple[Any, list[TargetConfig]]]:
|
||||
"""Build TargetConfigs for dispatch, grouped by their TrackingConfig.
|
||||
|
||||
Targets sharing a TrackingConfig dispatch together so a single
|
||||
``apply_tracking_display_filters`` pass can shape one event for the
|
||||
whole group; targets with different TCs may see differently-shaped
|
||||
events (e.g. one with favorites_only, one without).
|
||||
"""
|
||||
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
|
||||
for ld in link_data:
|
||||
tc = ld["tracking_config"]
|
||||
if tc and not event_allowed_by_config(event, tc, app_tz):
|
||||
continue
|
||||
|
||||
tmpl = ld["template_config"]
|
||||
target_cfg = TargetConfig(
|
||||
type=ld["target_type"],
|
||||
config=ld["target_config"],
|
||||
template_slots=ld["template_slots"],
|
||||
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
|
||||
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
|
||||
provider_api_key=provider_config.get("api_token"),
|
||||
provider_internal_url=provider_config.get("url", ""),
|
||||
provider_external_url=provider_config.get("url", ""),
|
||||
receivers=ld["receivers"],
|
||||
)
|
||||
key = id(tc) if tc is not None else 0
|
||||
if key not in groups:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
return list(groups.values())
|
||||
|
||||
@@ -1369,6 +1369,12 @@ _INDEXES: list[tuple[str, str, str]] = [
|
||||
("ix_command_template_slot_config_id", "command_template_slot", "config_id"),
|
||||
("ix_action_rule_action_id", "action_rule", "action_id"),
|
||||
("ix_action_execution_action_started", "action_execution", "action_id, started_at DESC"),
|
||||
# Deferred-dispatch drain: WHERE status = 'pending' AND fire_at <= ?
|
||||
# ORDER BY fire_at. The composite (status, fire_at) is the only access
|
||||
# pattern; an individual fire_at index isn't needed.
|
||||
("ix_deferred_dispatch_status_fire_at", "deferred_dispatch", "status, fire_at"),
|
||||
("ix_deferred_dispatch_link_id", "deferred_dispatch", "link_id"),
|
||||
("ix_deferred_dispatch_event_log_id", "deferred_dispatch", "event_log_id"),
|
||||
]
|
||||
|
||||
|
||||
@@ -1397,6 +1403,95 @@ async def migrate_performance_indexes(engine: AsyncEngine) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def migrate_deferred_dispatch_event_log_fk(engine: AsyncEngine) -> None:
|
||||
"""Rebuild ``deferred_dispatch`` if its event_log FK lacks ON DELETE SET NULL.
|
||||
|
||||
Early builds of this feature created the table with a default ``NO ACTION``
|
||||
FK on ``event_log_id``. The daily event_log cleanup deletes rows past the
|
||||
retention horizon — with SQLite's enforced foreign_keys PRAGMA, a pending
|
||||
DeferredDispatch row pointing at an aging-out event_log row would block
|
||||
the cleanup with an FK violation.
|
||||
|
||||
SQLite can't ALTER a constraint without rebuilding the table. The table
|
||||
has zero rows in any prod install old enough to need this fix (the
|
||||
feature shipped in the same release as this migration), so a drop +
|
||||
recreate via ``create_all`` is safe.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
if not await _has_table(conn, "deferred_dispatch"):
|
||||
return
|
||||
# Read the original CREATE TABLE SQL to see whether SET NULL is wired.
|
||||
row = await conn.run_sync(
|
||||
lambda sync_conn: sync_conn.execute(
|
||||
text(
|
||||
"SELECT sql FROM sqlite_master "
|
||||
"WHERE type='table' AND name='deferred_dispatch'"
|
||||
)
|
||||
).fetchone()
|
||||
)
|
||||
ddl = (row[0] or "") if row else ""
|
||||
if "ON DELETE SET NULL" in ddl.upper():
|
||||
return
|
||||
# Confirm there's nothing to migrate — refuse to drop a populated
|
||||
# table even though the schema was wrong. Better to leave a warning
|
||||
# than to lose state.
|
||||
count_row = await conn.run_sync(
|
||||
lambda sync_conn: sync_conn.execute(
|
||||
text("SELECT COUNT(*) FROM deferred_dispatch")
|
||||
).fetchone()
|
||||
)
|
||||
if count_row and count_row[0]:
|
||||
logger.warning(
|
||||
"deferred_dispatch FK is missing ON DELETE SET NULL but the "
|
||||
"table holds %d rows; not auto-dropping. Inspect manually.",
|
||||
count_row[0],
|
||||
)
|
||||
return
|
||||
await conn.execute(text("DROP TABLE deferred_dispatch"))
|
||||
logger.info(
|
||||
"Dropped deferred_dispatch (empty) so create_all rebuilds it "
|
||||
"with ON DELETE SET NULL on event_log_id",
|
||||
)
|
||||
# Recreate the table from the SQLModel metadata in this same txn.
|
||||
from sqlmodel import SQLModel
|
||||
# Ensure the model is registered on metadata before we ask create_all
|
||||
# to build it. Lazy import to avoid a circular at module load time.
|
||||
from .models import DeferredDispatch # noqa: F401
|
||||
await conn.run_sync(
|
||||
SQLModel.metadata.create_all,
|
||||
tables=[SQLModel.metadata.tables["deferred_dispatch"]],
|
||||
)
|
||||
|
||||
|
||||
async def migrate_deferred_dispatch_unique_pending(engine: AsyncEngine) -> None:
|
||||
"""Add a partial unique index preventing duplicate pending defers.
|
||||
|
||||
Without this, two webhook handlers (or a webhook racing the watcher)
|
||||
can both call ``_find_pending_asset_rows`` and find nothing, then both
|
||||
INSERT — defeating coalescing. The partial index makes the second
|
||||
INSERT raise ``IntegrityError`` and the caller's transaction abort,
|
||||
after which a retry will see the now-visible row.
|
||||
|
||||
SQLite has supported ``CREATE UNIQUE INDEX ... WHERE ...`` since 3.8.
|
||||
Once the table exists this is safe to run on every boot.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
if not await _has_table(conn, "deferred_dispatch"):
|
||||
return
|
||||
try:
|
||||
await conn.execute(text(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS "
|
||||
"ux_deferred_dispatch_pending "
|
||||
"ON deferred_dispatch(link_id, collection_id, event_type) "
|
||||
"WHERE status = 'pending'"
|
||||
))
|
||||
except Exception: # pragma: no cover — log and continue
|
||||
logger.warning(
|
||||
"Failed to create partial unique index on deferred_dispatch",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
async def migrate_chat_action_to_column(engine: AsyncEngine) -> None:
|
||||
"""Move ``chat_action`` from ``config`` JSON to the dedicated column.
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint, Text
|
||||
from sqlalchemy import ForeignKey, UniqueConstraint, Text
|
||||
from sqlmodel import JSON, Column, Field, SQLModel
|
||||
|
||||
|
||||
@@ -494,6 +494,64 @@ class CommandTrackerListener(SQLModel, table=True):
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class DeferredDispatch(SQLModel, table=True):
|
||||
"""A dispatch held back by quiet hours, waiting for the window to end.
|
||||
|
||||
One row per ``(link, event_type, collection_id)`` for asset events — newly
|
||||
arriving events for the same key coalesce into the existing row's
|
||||
``event_payload`` (union of added/removed asset sets) instead of inserting
|
||||
a duplicate row. Non-asset events (push, pr_opened, ups_*, …) get a fresh
|
||||
row each time because they aren't logically cancellable.
|
||||
|
||||
At drain time the scheduler picks up rows where ``status='pending'`` and
|
||||
``fire_at <= now``, re-resolves the link/target/config against current
|
||||
state (so subsequent config edits apply), and dispatches.
|
||||
"""
|
||||
|
||||
__tablename__ = "deferred_dispatch"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int | None = Field(default=None, foreign_key="user.id", index=True)
|
||||
tracker_id: int = Field(foreign_key="notification_tracker.id", index=True)
|
||||
# The specific link this deferral targets. On drain we re-fetch by ID; if
|
||||
# the link was disabled or removed in the meantime we drop with a
|
||||
# ``deferred_then_dropped`` log row instead of dispatching to nothing.
|
||||
link_id: int = Field(
|
||||
foreign_key="notification_tracker_target.id", index=True,
|
||||
)
|
||||
# The event_log row written when the event was first detected. The drain
|
||||
# writes a follow-up event_log row referencing this id so the dashboard
|
||||
# can show "delivered at HH:MM, originally detected at HH:MM".
|
||||
#
|
||||
# ``ondelete="SET NULL"`` matters because the daily ``_cleanup_old_events``
|
||||
# job hard-deletes event_log rows past the retention horizon. Without
|
||||
# SET NULL, an old pending DeferredDispatch row referencing an aging-out
|
||||
# event_log row would either (a) prevent the delete with an FK violation
|
||||
# under SQLite's enforced foreign_keys PRAGMA, or (b) leave a dangling
|
||||
# reference on engines that don't enforce.
|
||||
event_log_id: int | None = Field(
|
||||
default=None,
|
||||
sa_column=Column(
|
||||
"event_log_id",
|
||||
ForeignKey("event_log.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
),
|
||||
)
|
||||
event_type: str = Field(index=True)
|
||||
collection_id: str = Field(default="", index=True)
|
||||
# ``dataclasses.asdict(ServiceEvent)`` with datetime/enum normalisation —
|
||||
# round-tripped via the helpers in ``services.deferred_dispatch``.
|
||||
event_payload: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
fire_at: datetime = Field(index=True)
|
||||
# ``pending`` until the drain runs; then ``fired``, ``dropped`` (link
|
||||
# gone / event-type disabled after defer), or ``cancelled`` (coalesced
|
||||
# away by a counter-event).
|
||||
status: str = Field(default="pending", index=True)
|
||||
fired_at: datetime | None = Field(default=None)
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class EventLog(SQLModel, table=True):
|
||||
"""Log of detected events."""
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@ async def lifespan(app: FastAPI):
|
||||
migrate_user_token_version,
|
||||
migrate_performance_indexes,
|
||||
migrate_chat_action_to_column,
|
||||
migrate_deferred_dispatch_event_log_fk,
|
||||
migrate_deferred_dispatch_unique_pending,
|
||||
migrate_schema_version,
|
||||
)
|
||||
from .database.snapshot import snapshot_and_prune
|
||||
@@ -100,6 +102,11 @@ async def lifespan(app: FastAPI):
|
||||
await migrate_user_token_version(engine)
|
||||
await migrate_performance_indexes(engine)
|
||||
await migrate_chat_action_to_column(engine)
|
||||
# FK-rebuild MUST run before the unique-index creation: drop+create_all
|
||||
# of deferred_dispatch wipes its indexes; the next migration re-establishes
|
||||
# the partial unique index.
|
||||
await migrate_deferred_dispatch_event_log_fk(engine)
|
||||
await migrate_deferred_dispatch_unique_pending(engine)
|
||||
await migrate_schema_version(engine)
|
||||
from .database.seeds import seed_all
|
||||
await seed_all()
|
||||
@@ -147,11 +154,8 @@ async def lifespan(app: FastAPI):
|
||||
await dispose_engine()
|
||||
|
||||
|
||||
try:
|
||||
from importlib.metadata import version as _pkg_version
|
||||
_APP_VERSION = _pkg_version("notify-bridge-server")
|
||||
except Exception: # pragma: no cover — editable install edge cases
|
||||
_APP_VERSION = "0.0.0+unknown"
|
||||
from .version import resolve_version as _resolve_version
|
||||
_APP_VERSION = _resolve_version()
|
||||
|
||||
app = FastAPI(title="Notify Bridge", version=_APP_VERSION, lifespan=lifespan)
|
||||
|
||||
|
||||
@@ -0,0 +1,798 @@
|
||||
"""Deferred-dispatch infrastructure for quiet-hours notifications.
|
||||
|
||||
When ``evaluate_event_gate`` returns ``QUIET_HOURS`` for a deferrable event
|
||||
type, the dispatch site calls :func:`defer_event` instead of dropping. That
|
||||
either inserts a new ``DeferredDispatch`` row or coalesces the event into an
|
||||
existing pending row for the same ``(link_id, collection_id)`` — asset add
|
||||
+ matching remove cancels out, asset add + asset add merges set-union.
|
||||
|
||||
An APScheduler one-shot ``date`` job per quiet-window-end fires
|
||||
:func:`drain_deferred_due` which:
|
||||
1. Re-resolves each pending row's link/target/configs against current state.
|
||||
2. Drops rows whose link/target was deleted or disabled in the meantime.
|
||||
3. Re-checks quiet hours (in case the user extended the window mid-flight)
|
||||
and pushes ``fire_at`` to the new end if still suppressed.
|
||||
4. Dispatches via the existing ``NotificationDispatcher``.
|
||||
5. Writes a follow-up ``event_log`` row referencing the original
|
||||
``event_log_id`` so the dashboard shows "delivered late".
|
||||
|
||||
Wall-clock event types (``scheduled_message``) are explicitly NOT in
|
||||
``_DEFERRABLE_EVENT_TYPES`` — delivering a "good morning" memory at 3 pm is
|
||||
worse than dropping it. Those keep the legacy drop-on-quiet-hours behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||
from notify_bridge_core.models.media import MediaAsset, MediaType
|
||||
from notify_bridge_core.notifications.dispatcher import (
|
||||
NotificationDispatcher,
|
||||
TargetConfig,
|
||||
)
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
DeferredDispatch,
|
||||
EventLog,
|
||||
NotificationTracker,
|
||||
ServiceProvider,
|
||||
)
|
||||
from .dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Policy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Change-driven event types that are safe to deliver after the quiet window
|
||||
# ends — the underlying state change (a photo was added, a PR was opened, the
|
||||
# UPS went on battery) remains relevant even hours later. Wall-clock event
|
||||
# types (``scheduled_message``) are deliberately excluded: a "good morning"
|
||||
# delivered at 3 pm is wrong, drop is more correct than late delivery.
|
||||
_DEFERRABLE_EVENT_TYPES: frozenset[str] = frozenset({
|
||||
# Immich
|
||||
"assets_added", "assets_removed",
|
||||
"collection_renamed", "collection_deleted", "sharing_changed",
|
||||
# Gitea
|
||||
"push",
|
||||
"issue_opened", "issue_closed", "issue_commented",
|
||||
"pr_opened", "pr_closed", "pr_merged", "pr_commented",
|
||||
"release_published",
|
||||
# Planka
|
||||
"card_created", "card_updated", "card_moved", "card_deleted",
|
||||
"card_commented", "comment_updated",
|
||||
"board_created", "board_updated", "board_deleted",
|
||||
"list_created", "list_updated", "list_deleted",
|
||||
"attachment_created", "card_label_added", "task_completed",
|
||||
# Generic webhook
|
||||
"webhook_received",
|
||||
# NUT (UPS)
|
||||
"ups_online", "ups_on_battery", "ups_low_battery",
|
||||
"ups_battery_restored", "ups_comms_lost", "ups_comms_restored",
|
||||
"ups_replace_battery", "ups_overload",
|
||||
})
|
||||
|
||||
# Per-tracker cap on the pending queue. A misconfigured short quiet window
|
||||
# plus a chatty upstream (e.g. mass-imported album) could otherwise grow
|
||||
# unbounded. On overflow we drop oldest (FIFO) — recent events still survive
|
||||
# to be delivered, ancient ones are sacrificed.
|
||||
_MAX_PENDING_PER_TRACKER = 1000
|
||||
|
||||
# Per-row timeout in the drain. Without this, a single hanging Telegram/SMTP
|
||||
# call could stall the whole drain for hours and leave the rest of the queue
|
||||
# stranded. Generous because legitimate large media uploads can take minutes.
|
||||
_DRAIN_DISPATCH_TIMEOUT_SECONDS = 120
|
||||
|
||||
|
||||
def is_deferrable(event_type: str) -> bool:
|
||||
"""Whether this event type should be deferred (vs. dropped) during quiet hours."""
|
||||
return event_type in _DEFERRABLE_EVENT_TYPES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ServiceEvent (de)serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# JSON column stores ``dataclasses.asdict(event)`` plus a normalisation pass
|
||||
# for datetimes (ISO strings) and enums (string values). Round-trip via the
|
||||
# reverse pass below.
|
||||
|
||||
def _normalize_for_json(value: Any) -> Any:
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, (EventType, MediaType, ServiceProviderType)):
|
||||
return value.value
|
||||
if isinstance(value, dict):
|
||||
return {k: _normalize_for_json(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_normalize_for_json(v) for v in value]
|
||||
return value
|
||||
|
||||
|
||||
def serialize_event(event: ServiceEvent) -> dict[str, Any]:
|
||||
"""Convert a ``ServiceEvent`` to a JSON-safe dict for ``DeferredDispatch.event_payload``."""
|
||||
return _normalize_for_json(dataclasses.asdict(event))
|
||||
|
||||
|
||||
def _parse_dt(s: Any) -> datetime:
|
||||
if isinstance(s, datetime):
|
||||
return s
|
||||
return datetime.fromisoformat(s)
|
||||
|
||||
|
||||
def _deserialize_asset(data: dict[str, Any]) -> MediaAsset:
|
||||
return MediaAsset(
|
||||
id=data["id"],
|
||||
type=MediaType(data["type"]),
|
||||
filename=data["filename"],
|
||||
created_at=_parse_dt(data["created_at"]),
|
||||
owner_name=data.get("owner_name"),
|
||||
description=data.get("description"),
|
||||
tags=list(data.get("tags") or []),
|
||||
thumbnail_url=data.get("thumbnail_url"),
|
||||
preview_url=data.get("preview_url"),
|
||||
full_url=data.get("full_url"),
|
||||
extra=dict(data.get("extra") or {}),
|
||||
)
|
||||
|
||||
|
||||
def deserialize_event(data: dict[str, Any]) -> ServiceEvent:
|
||||
"""Inverse of :func:`serialize_event`."""
|
||||
return ServiceEvent(
|
||||
event_type=EventType(data["event_type"]),
|
||||
provider_type=ServiceProviderType(data["provider_type"]),
|
||||
provider_name=data["provider_name"],
|
||||
collection_id=data["collection_id"],
|
||||
collection_name=data["collection_name"],
|
||||
timestamp=_parse_dt(data["timestamp"]),
|
||||
added_assets=[_deserialize_asset(a) for a in data.get("added_assets") or []],
|
||||
removed_asset_ids=list(data.get("removed_asset_ids") or []),
|
||||
added_count=int(data.get("added_count") or 0),
|
||||
removed_count=int(data.get("removed_count") or 0),
|
||||
old_name=data.get("old_name"),
|
||||
new_name=data.get("new_name"),
|
||||
old_shared=data.get("old_shared"),
|
||||
new_shared=data.get("new_shared"),
|
||||
extra=dict(data.get("extra") or {}),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coalescing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _added_ids(payload: dict[str, Any]) -> list[str]:
|
||||
return [a["id"] for a in payload.get("added_assets") or [] if "id" in a]
|
||||
|
||||
|
||||
def _coalesce_assets_added(
|
||||
new_event: ServiceEvent,
|
||||
existing_added_row: DeferredDispatch | None,
|
||||
existing_removed_row: DeferredDispatch | None,
|
||||
) -> tuple[str, DeferredDispatch | None, DeferredDispatch | None]:
|
||||
"""Apply add-then-remove cancellation and add-then-add union.
|
||||
|
||||
Returns ``(action, updated_added_row, updated_removed_row)`` where action
|
||||
is one of ``"insert"`` (caller must create a new row), ``"merge"`` (update
|
||||
existing rows in place — caller must session.add them).
|
||||
"""
|
||||
new_ids = [a.id for a in new_event.added_assets]
|
||||
new_ids_set = set(new_ids)
|
||||
|
||||
# 1) If a matching assets_removed row pending: subtract — that's a re-add.
|
||||
if existing_removed_row is not None:
|
||||
removed_ids = list(existing_removed_row.event_payload.get("removed_asset_ids") or [])
|
||||
kept = [rid for rid in removed_ids if rid not in new_ids_set]
|
||||
if len(kept) != len(removed_ids):
|
||||
payload = dict(existing_removed_row.event_payload)
|
||||
payload["removed_asset_ids"] = kept
|
||||
payload["removed_count"] = len(kept)
|
||||
existing_removed_row.event_payload = payload
|
||||
if not kept:
|
||||
# All previously-removed IDs are being re-added → entire
|
||||
# removal is cancelled. Mark for caller to delete.
|
||||
existing_removed_row.status = "cancelled"
|
||||
# The intersection re-adds are accounted for by the cancellation;
|
||||
# remaining new IDs (those NOT in removed list) still need to land
|
||||
# in the assets_added row.
|
||||
new_ids = [nid for nid in new_ids if nid not in set(removed_ids)]
|
||||
new_ids_set = set(new_ids)
|
||||
|
||||
if not new_ids:
|
||||
# All new added IDs cancelled an existing remove → nothing to enqueue.
|
||||
return ("merge", None, existing_removed_row)
|
||||
|
||||
if existing_added_row is None:
|
||||
return ("insert", None, existing_removed_row)
|
||||
|
||||
# 2) Union with existing assets_added — earliest fire_at wins.
|
||||
payload = dict(existing_added_row.event_payload)
|
||||
existing_assets = list(payload.get("added_assets") or [])
|
||||
seen = {a.get("id") for a in existing_assets}
|
||||
new_serialized = serialize_event(new_event)
|
||||
for a in new_serialized.get("added_assets") or []:
|
||||
if a.get("id") in new_ids_set and a.get("id") not in seen:
|
||||
existing_assets.append(a)
|
||||
seen.add(a.get("id"))
|
||||
payload["added_assets"] = existing_assets
|
||||
payload["added_count"] = len(existing_assets)
|
||||
existing_added_row.event_payload = payload
|
||||
return ("merge", existing_added_row, existing_removed_row)
|
||||
|
||||
|
||||
def _coalesce_assets_removed(
|
||||
new_event: ServiceEvent,
|
||||
existing_added_row: DeferredDispatch | None,
|
||||
existing_removed_row: DeferredDispatch | None,
|
||||
) -> tuple[str, DeferredDispatch | None, DeferredDispatch | None]:
|
||||
"""Mirror of :func:`_coalesce_assets_added` for removal events."""
|
||||
new_ids = list(new_event.removed_asset_ids)
|
||||
new_ids_set = set(new_ids)
|
||||
|
||||
# 1) If a matching assets_added row pending: subtract — that's an
|
||||
# add-then-remove within the window, cancel both sides.
|
||||
if existing_added_row is not None:
|
||||
added = list(existing_added_row.event_payload.get("added_assets") or [])
|
||||
kept_assets = [a for a in added if a.get("id") not in new_ids_set]
|
||||
if len(kept_assets) != len(added):
|
||||
payload = dict(existing_added_row.event_payload)
|
||||
payload["added_assets"] = kept_assets
|
||||
payload["added_count"] = len(kept_assets)
|
||||
existing_added_row.event_payload = payload
|
||||
if not kept_assets:
|
||||
existing_added_row.status = "cancelled"
|
||||
# IDs that were just added during the window don't need to flow
|
||||
# into the assets_removed row — they're a wash.
|
||||
cancelled_ids = {a.get("id") for a in added if a.get("id") in new_ids_set}
|
||||
new_ids = [nid for nid in new_ids if nid not in cancelled_ids]
|
||||
new_ids_set = set(new_ids)
|
||||
|
||||
if not new_ids:
|
||||
return ("merge", existing_added_row, None)
|
||||
|
||||
if existing_removed_row is None:
|
||||
return ("insert", existing_added_row, None)
|
||||
|
||||
# 2) Union with existing assets_removed — earliest fire_at wins.
|
||||
payload = dict(existing_removed_row.event_payload)
|
||||
existing_ids = list(payload.get("removed_asset_ids") or [])
|
||||
seen = set(existing_ids)
|
||||
for rid in new_ids:
|
||||
if rid not in seen:
|
||||
existing_ids.append(rid)
|
||||
seen.add(rid)
|
||||
payload["removed_asset_ids"] = existing_ids
|
||||
payload["removed_count"] = len(existing_ids)
|
||||
existing_removed_row.event_payload = payload
|
||||
return ("merge", existing_added_row, existing_removed_row)
|
||||
|
||||
|
||||
async def _find_pending_asset_rows(
|
||||
session: AsyncSession,
|
||||
link_id: int,
|
||||
collection_id: str,
|
||||
) -> tuple[DeferredDispatch | None, DeferredDispatch | None]:
|
||||
"""Return ``(assets_added_row, assets_removed_row)`` pending for this link+collection."""
|
||||
result = await session.exec(
|
||||
select(DeferredDispatch).where(
|
||||
DeferredDispatch.link_id == link_id,
|
||||
DeferredDispatch.collection_id == collection_id,
|
||||
DeferredDispatch.status == "pending",
|
||||
DeferredDispatch.event_type.in_(["assets_added", "assets_removed"]),
|
||||
)
|
||||
)
|
||||
added_row: DeferredDispatch | None = None
|
||||
removed_row: DeferredDispatch | None = None
|
||||
for row in result.all():
|
||||
if row.event_type == "assets_added":
|
||||
added_row = row
|
||||
elif row.event_type == "assets_removed":
|
||||
removed_row = row
|
||||
return added_row, removed_row
|
||||
|
||||
|
||||
async def _trim_queue_if_needed(
|
||||
session: AsyncSession,
|
||||
tracker_id: int,
|
||||
) -> None:
|
||||
"""Drop oldest pending rows beyond the per-tracker cap with a log row each.
|
||||
|
||||
Loads the parent tracker so the emitted event_log rows carry proper
|
||||
``tracker_name``/``provider_id``/``provider_name`` and slot into the
|
||||
dashboard's "by tracker" grouping — without these the drop rows show up
|
||||
under an unattributed bucket and confuse the audit trail.
|
||||
"""
|
||||
rows = (await session.exec(
|
||||
select(DeferredDispatch).where(
|
||||
DeferredDispatch.tracker_id == tracker_id,
|
||||
DeferredDispatch.status == "pending",
|
||||
).order_by(DeferredDispatch.fire_at.asc(), DeferredDispatch.id.asc())
|
||||
)).all()
|
||||
overflow = len(rows) - _MAX_PENDING_PER_TRACKER
|
||||
if overflow <= 0:
|
||||
return
|
||||
_LOGGER.warning(
|
||||
"Deferred queue for tracker %d exceeds cap (%d > %d); dropping %d oldest",
|
||||
tracker_id, len(rows), _MAX_PENDING_PER_TRACKER, overflow,
|
||||
)
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
tracker_name = tracker.name if tracker else ""
|
||||
provider_id = tracker.provider_id if tracker else None
|
||||
provider_name = ""
|
||||
if tracker is not None and provider_id is not None:
|
||||
provider = await session.get(ServiceProvider, provider_id)
|
||||
if provider is not None:
|
||||
provider_name = provider.name
|
||||
for row in rows[:overflow]:
|
||||
await _mark_dropped(
|
||||
session, row,
|
||||
tracker_name=tracker_name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
reason="queue_overflow",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enqueue (called from dispatch sites when gate returns QUIET_HOURS)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def defer_event(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
event: ServiceEvent,
|
||||
user_id: int | None,
|
||||
tracker_id: int,
|
||||
link_id: int,
|
||||
event_log_id: int | None,
|
||||
fire_at: datetime,
|
||||
) -> str:
|
||||
"""Persist a deferred dispatch (or coalesce into an existing one).
|
||||
|
||||
Caller is responsible for committing the session. Returns one of:
|
||||
|
||||
* ``"inserted"`` — a fresh DeferredDispatch row was created.
|
||||
* ``"merged"`` — coalesced into an existing row (union or partial cancel).
|
||||
* ``"cancelled"`` — the new event fully cancelled an existing pending one
|
||||
(add-then-remove or remove-then-readd of the same asset IDs). Both sides
|
||||
are gone after this call.
|
||||
* ``"non_deferrable"`` — event type is wall-clock; caller should drop it
|
||||
with a ``"suppressed_quiet_hours_nondeferrable"`` event_log row.
|
||||
"""
|
||||
event_type = event.event_type.value
|
||||
if not is_deferrable(event_type):
|
||||
return "non_deferrable"
|
||||
|
||||
fire_at_utc = fire_at.astimezone(timezone.utc) if fire_at.tzinfo else fire_at.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Asset events get set-merging across the same link+collection. Everything
|
||||
# else just gets a new row — those events aren't naturally cancellable.
|
||||
if event_type in ("assets_added", "assets_removed"):
|
||||
added_row, removed_row = await _find_pending_asset_rows(
|
||||
session, link_id, event.collection_id,
|
||||
)
|
||||
if event_type == "assets_added":
|
||||
action, upd_added, upd_removed = _coalesce_assets_added(
|
||||
event, added_row, removed_row,
|
||||
)
|
||||
else:
|
||||
action, upd_added, upd_removed = _coalesce_assets_removed(
|
||||
event, added_row, removed_row,
|
||||
)
|
||||
|
||||
# Apply pending updates. ``status="cancelled"`` rows are deleted
|
||||
# outright so the drain doesn't see them.
|
||||
fully_cancelled = False
|
||||
for row in (upd_added, upd_removed):
|
||||
if row is None:
|
||||
continue
|
||||
if row.status == "cancelled":
|
||||
await session.delete(row)
|
||||
fully_cancelled = True
|
||||
else:
|
||||
session.add(row)
|
||||
|
||||
if action == "insert":
|
||||
new_row = DeferredDispatch(
|
||||
user_id=user_id,
|
||||
tracker_id=tracker_id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
event_type=event_type,
|
||||
collection_id=event.collection_id,
|
||||
event_payload=serialize_event(event),
|
||||
fire_at=fire_at_utc,
|
||||
status="pending",
|
||||
)
|
||||
session.add(new_row)
|
||||
await _trim_queue_if_needed(session, tracker_id)
|
||||
return "inserted"
|
||||
|
||||
# action == "merge" — either updated existing or fully cancelled.
|
||||
return "cancelled" if fully_cancelled and (upd_added is None or upd_added.status == "cancelled") and (upd_removed is None or upd_removed.status == "cancelled") else "merged"
|
||||
|
||||
# Non-asset event: no coalescing, fresh row.
|
||||
new_row = DeferredDispatch(
|
||||
user_id=user_id,
|
||||
tracker_id=tracker_id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
event_type=event_type,
|
||||
collection_id=event.collection_id,
|
||||
event_payload=serialize_event(event),
|
||||
fire_at=fire_at_utc,
|
||||
status="pending",
|
||||
)
|
||||
session.add(new_row)
|
||||
await _trim_queue_if_needed(session, tracker_id)
|
||||
return "inserted"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drain (called by APScheduler date job at quiet_hours_end_at)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def drain_deferred_due(now: datetime | None = None) -> dict[str, int]:
|
||||
"""Dispatch all pending DeferredDispatch rows whose ``fire_at <= now``.
|
||||
|
||||
Re-resolves link/target/configs against current DB state so config edits
|
||||
between suppression and drain time take effect. Returns a small stats
|
||||
dict for logging.
|
||||
|
||||
Implementation note: rows are *re-fetched* by id inside each per-tracker
|
||||
session rather than carried across session boundaries. Carrying a row
|
||||
instance to a new session and calling ``session.add(row)`` on a detached
|
||||
PK-bearing instance triggers an INSERT (collision with the existing PK)
|
||||
on flush — a class of bug that's invisible until the first session
|
||||
closes, hence the up-front re-fetch.
|
||||
"""
|
||||
now_utc = (now or datetime.now(timezone.utc))
|
||||
if now_utc.tzinfo is None:
|
||||
now_utc = now_utc.replace(tzinfo=timezone.utc)
|
||||
|
||||
stats = {"fired": 0, "dropped": 0, "rescheduled": 0, "errors": 0}
|
||||
engine = get_engine()
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
# Only pull the row identity + grouping key. Loading the full ORM
|
||||
# objects in a session that's about to close just wastes work — we
|
||||
# re-fetch fresh attached instances in the per-tracker session below.
|
||||
ident_rows = (await session.exec(
|
||||
select(DeferredDispatch.id, DeferredDispatch.tracker_id).where(
|
||||
DeferredDispatch.status == "pending",
|
||||
DeferredDispatch.fire_at <= now_utc,
|
||||
).order_by(DeferredDispatch.fire_at.asc())
|
||||
)).all()
|
||||
|
||||
if not ident_rows:
|
||||
_LOGGER.debug("drain_deferred_due: no pending rows due")
|
||||
return stats
|
||||
|
||||
_LOGGER.info(
|
||||
"Draining %d deferred dispatches due at %s",
|
||||
len(ident_rows), now_utc.isoformat(),
|
||||
)
|
||||
|
||||
# Group by tracker so a single per-tracker session can re-fetch its rows
|
||||
# (attached) and re-resolve link state once.
|
||||
ids_by_tracker: dict[int, list[int]] = {}
|
||||
for row_id, tracker_id in ident_rows:
|
||||
if row_id is None:
|
||||
continue
|
||||
ids_by_tracker.setdefault(tracker_id, []).append(row_id)
|
||||
|
||||
from .watcher import _get_telegram_caches
|
||||
from .http_session import get_http_session
|
||||
url_cache, asset_cache = await _get_telegram_caches()
|
||||
shared_session = await get_http_session()
|
||||
dispatcher = NotificationDispatcher(
|
||||
url_cache=url_cache, asset_cache=asset_cache, session=shared_session,
|
||||
)
|
||||
|
||||
for tracker_id, row_ids in ids_by_tracker.items():
|
||||
async with AsyncSession(engine) as session:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
# Re-fetch rows freshly attached to THIS session.
|
||||
rows = (await session.exec(
|
||||
select(DeferredDispatch).where(DeferredDispatch.id.in_(row_ids))
|
||||
)).all()
|
||||
|
||||
if tracker is None or not tracker.enabled:
|
||||
# Tracker deleted or disabled between defer and drain — drop
|
||||
# all pending rows for it. Disable matches the live-path
|
||||
# invariant (watcher / webhooks / scheduled_dispatch all
|
||||
# short-circuit when ``tracker.enabled`` is False).
|
||||
reason = "tracker_removed" if tracker is None else "tracker_disabled_after_defer"
|
||||
for row in rows:
|
||||
await _mark_dropped(
|
||||
session, row,
|
||||
tracker=tracker, reason=reason,
|
||||
)
|
||||
stats["dropped"] += 1
|
||||
await session.commit()
|
||||
continue
|
||||
|
||||
provider = await session.get(ServiceProvider, tracker.provider_id)
|
||||
provider_config = dict(provider.config) if provider else {}
|
||||
provider_id = provider.id if provider else tracker.provider_id
|
||||
provider_name = provider.name if provider else ""
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
# Reload current link state. Broadcast links emit ONE entry per
|
||||
# child target sharing the SAME parent ``link_id`` — a plain
|
||||
# ``{link_id: ld}`` dict would silently drop N-1 children. The
|
||||
# drain dispatches to every expanded entry for the parent.
|
||||
link_data = await load_link_data(session, tracker_id)
|
||||
link_by_id: dict[int, list[dict[str, Any]]] = {}
|
||||
for ld in link_data:
|
||||
key = ld.get("link_id")
|
||||
if key is None:
|
||||
continue
|
||||
link_by_id.setdefault(key, []).append(ld)
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
await _process_row(
|
||||
session, row, tracker, provider_id, provider_name,
|
||||
provider_config, app_tz, link_by_id, dispatcher, stats,
|
||||
)
|
||||
except Exception as err: # noqa: BLE001 — keep draining other rows
|
||||
_LOGGER.exception(
|
||||
"Drain failed for deferred dispatch id=%s: %s", row.id, err,
|
||||
)
|
||||
stats["errors"] += 1
|
||||
|
||||
await session.commit()
|
||||
|
||||
_LOGGER.info("Drain complete: %s", stats)
|
||||
return stats
|
||||
|
||||
|
||||
async def _mark_dropped(
|
||||
session: AsyncSession,
|
||||
row: DeferredDispatch,
|
||||
*,
|
||||
tracker: NotificationTracker | None = None,
|
||||
tracker_name: str = "",
|
||||
provider_id: int | None = None,
|
||||
provider_name: str = "",
|
||||
reason: str,
|
||||
) -> None:
|
||||
"""Record a drop on the deferred row and emit a follow-up event_log entry.
|
||||
|
||||
``tracker``/``tracker_name``/``provider_id``/``provider_name`` populate
|
||||
the new event_log row's owner/provider columns so the dashboard "by
|
||||
tracker" grouping works for the drop path. Without these the row would
|
||||
have empty strings and slot into the "unknown" bucket.
|
||||
"""
|
||||
if tracker is not None:
|
||||
tracker_name = tracker_name or tracker.name
|
||||
if provider_id is None:
|
||||
provider_id = tracker.provider_id
|
||||
payload = row.event_payload if isinstance(row.event_payload, dict) else {}
|
||||
row.status = "dropped"
|
||||
row.fired_at = datetime.now(timezone.utc)
|
||||
session.add(row)
|
||||
session.add(EventLog(
|
||||
user_id=row.user_id,
|
||||
tracker_id=row.tracker_id,
|
||||
tracker_name=tracker_name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
event_type=row.event_type,
|
||||
collection_id=row.collection_id,
|
||||
collection_name=payload.get("collection_name", ""),
|
||||
assets_count=int(payload.get("added_count", 0))
|
||||
or int(payload.get("removed_count", 0)),
|
||||
details={
|
||||
"dispatch_status": "deferred_then_dropped",
|
||||
"reason": reason,
|
||||
"original_event_log_id": row.event_log_id,
|
||||
"provider_type": payload.get("provider_type", ""),
|
||||
},
|
||||
))
|
||||
|
||||
|
||||
async def _process_row(
|
||||
session: AsyncSession,
|
||||
row: DeferredDispatch,
|
||||
tracker: NotificationTracker,
|
||||
provider_id: int,
|
||||
provider_name: str,
|
||||
provider_config: dict[str, Any],
|
||||
app_tz: str,
|
||||
link_by_id: dict[int, list[dict[str, Any]]],
|
||||
dispatcher: NotificationDispatcher,
|
||||
stats: dict[str, int],
|
||||
) -> None:
|
||||
"""Drain a single row: re-resolve link, re-evaluate gate, dispatch.
|
||||
|
||||
``link_by_id`` maps parent link_id → list of expanded entries (one per
|
||||
broadcast child, or a single-element list for regular targets). Every
|
||||
entry produces its own target_config so a broadcast deferred row fans
|
||||
out to all current children at drain time.
|
||||
"""
|
||||
expanded = link_by_id.get(row.link_id)
|
||||
if not expanded:
|
||||
# Link removed/disabled between defer and drain.
|
||||
await _mark_dropped(
|
||||
session, row,
|
||||
tracker=tracker, provider_id=provider_id, provider_name=provider_name,
|
||||
reason="link_removed",
|
||||
)
|
||||
stats["dropped"] += 1
|
||||
return
|
||||
|
||||
# Every expanded entry for a parent link shares the same tracking_config,
|
||||
# so the gate decision and ``apply_tracking_display_filters`` shaping are
|
||||
# made once. Only the target_configs differ across children.
|
||||
tc = expanded[0].get("tracking_config")
|
||||
event = deserialize_event(row.event_payload)
|
||||
|
||||
if tc is not None:
|
||||
outcome = evaluate_event_gate(event, tc, app_tz)
|
||||
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
|
||||
await _mark_dropped(
|
||||
session, row,
|
||||
tracker=tracker, provider_id=provider_id, provider_name=provider_name,
|
||||
reason="event_type_disabled_after_defer",
|
||||
)
|
||||
stats["dropped"] += 1
|
||||
return
|
||||
if outcome.reason is GateReason.QUIET_HOURS and outcome.quiet_hours_end_at is not None:
|
||||
row.fire_at = outcome.quiet_hours_end_at
|
||||
session.add(row)
|
||||
stats["rescheduled"] += 1
|
||||
try:
|
||||
from .scheduler import schedule_deferred_drain
|
||||
schedule_deferred_drain(outcome.quiet_hours_end_at)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to reschedule drain for %s", outcome.quiet_hours_end_at,
|
||||
)
|
||||
return
|
||||
|
||||
shaped = apply_tracking_display_filters(event, tc)
|
||||
if shaped is None:
|
||||
# ``notify_favorites_only`` (or another display filter) dropped every
|
||||
# asset from the event. Inconsistent earlier behavior swallowed this
|
||||
# silently; we now route through the same "dropped + event_log"
|
||||
# pathway as link_removed so the dashboard shows why.
|
||||
await _mark_dropped(
|
||||
session, row,
|
||||
tracker=tracker, provider_id=provider_id, provider_name=provider_name,
|
||||
reason="filtered_after_defer",
|
||||
)
|
||||
stats["dropped"] += 1
|
||||
return
|
||||
|
||||
# Build one target_config per expanded child (regular targets → length 1;
|
||||
# broadcast → length N children).
|
||||
target_configs: list[TargetConfig] = []
|
||||
for ld in expanded:
|
||||
tmpl = ld.get("template_config")
|
||||
target_configs.append(TargetConfig(
|
||||
type=ld["target_type"],
|
||||
config=ld["target_config"],
|
||||
template_slots=ld.get("template_slots"),
|
||||
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
|
||||
date_only_format=(tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y"),
|
||||
provider_api_key=provider_config.get("api_key") or provider_config.get("api_token"),
|
||||
provider_internal_url=provider_config.get("url", ""),
|
||||
provider_external_url=provider_config.get("external_domain", "") or provider_config.get("url", ""),
|
||||
receivers=ld["receivers"],
|
||||
))
|
||||
|
||||
# Per-row timeout — a single hanging remote call (Telegram outage, slow
|
||||
# SMTP) must not stall the rest of the queue.
|
||||
try:
|
||||
results = await asyncio.wait_for(
|
||||
dispatcher.dispatch(shaped, target_configs),
|
||||
timeout=_DRAIN_DISPATCH_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.warning(
|
||||
"Drain dispatch for row %s timed out after %ds",
|
||||
row.id, _DRAIN_DISPATCH_TIMEOUT_SECONDS,
|
||||
)
|
||||
results = [{"success": False, "error": f"timeout after {_DRAIN_DISPATCH_TIMEOUT_SECONDS}s"}]
|
||||
|
||||
success = any(r.get("success") for r in results)
|
||||
|
||||
row.status = "fired" if success else "dropped"
|
||||
row.fired_at = datetime.now(timezone.utc)
|
||||
session.add(row)
|
||||
|
||||
if success:
|
||||
stats["fired"] += 1
|
||||
session.add(EventLog(
|
||||
user_id=row.user_id,
|
||||
tracker_id=row.tracker_id,
|
||||
tracker_name=tracker.name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
event_type=row.event_type,
|
||||
collection_id=row.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=event.added_count or event.removed_count or 0,
|
||||
details={
|
||||
"dispatch_status": "delivered_after_quiet_hours",
|
||||
"original_event_log_id": row.event_log_id,
|
||||
"deferred_for_seconds": int(
|
||||
(row.fired_at - row.created_at).total_seconds()
|
||||
),
|
||||
"provider_type": event.provider_type.value,
|
||||
},
|
||||
))
|
||||
else:
|
||||
stats["dropped"] += 1
|
||||
first_err = next((r.get("error") for r in results if not r.get("success")), "unknown")
|
||||
session.add(EventLog(
|
||||
user_id=row.user_id,
|
||||
tracker_id=row.tracker_id,
|
||||
tracker_name=tracker.name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
event_type=row.event_type,
|
||||
collection_id=row.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=event.added_count or event.removed_count or 0,
|
||||
details={
|
||||
"dispatch_status": "deferred_then_failed",
|
||||
"reason": str(first_err)[:200],
|
||||
"original_event_log_id": row.event_log_id,
|
||||
"provider_type": event.provider_type.value,
|
||||
},
|
||||
))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Startup: reschedule pending drain jobs found in the DB
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def load_pending_drain_jobs() -> int:
|
||||
"""At startup, scan ``DeferredDispatch`` for pending rows and (re)schedule drains.
|
||||
|
||||
Rows whose ``fire_at`` already passed get a single immediate-fire job; the
|
||||
rest get one job per distinct ``fire_at`` (minute-rounded) so all rows
|
||||
sharing a window end share a drain.
|
||||
"""
|
||||
from .scheduler import schedule_deferred_drain
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
rows = (await session.exec(
|
||||
select(DeferredDispatch.fire_at).where(
|
||||
DeferredDispatch.status == "pending",
|
||||
)
|
||||
)).all()
|
||||
if not rows:
|
||||
return 0
|
||||
unique_fire_ats: set[datetime] = set()
|
||||
for fa in rows:
|
||||
if isinstance(fa, datetime):
|
||||
unique_fire_ats.add(fa.astimezone(timezone.utc) if fa.tzinfo else fa.replace(tzinfo=timezone.utc))
|
||||
for fa in unique_fire_ats:
|
||||
schedule_deferred_drain(fa)
|
||||
_LOGGER.info(
|
||||
"Loaded %d pending deferred dispatches; scheduled %d drain job(s)",
|
||||
len(rows), len(unique_fire_ats),
|
||||
)
|
||||
return len(unique_fire_ats)
|
||||
@@ -5,7 +5,9 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, time, timezone
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, time, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
@@ -33,6 +35,35 @@ from ..database.models import (
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GateReason(str, Enum):
|
||||
"""Why ``evaluate_event_gate`` allowed or blocked a dispatch.
|
||||
|
||||
String-backed so it can be persisted in ``EventLog.details`` JSON and
|
||||
round-trip cleanly.
|
||||
"""
|
||||
|
||||
ALLOWED = "allowed"
|
||||
EVENT_TYPE_DISABLED = "event_type_disabled"
|
||||
QUIET_HOURS = "quiet_hours"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GateOutcome:
|
||||
"""Result of evaluating a (event, tracking_config) pair against dispatch gates.
|
||||
|
||||
``quiet_hours_end_at`` is set iff ``reason == QUIET_HOURS`` and gives the
|
||||
UTC datetime at which the current quiet window ends — used by the
|
||||
deferred-dispatch scheduler to know when to fire the held notification.
|
||||
"""
|
||||
|
||||
reason: GateReason
|
||||
quiet_hours_end_at: datetime | None = None
|
||||
|
||||
@property
|
||||
def allowed(self) -> bool:
|
||||
return self.reason is GateReason.ALLOWED
|
||||
|
||||
|
||||
def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
|
||||
"""Resolve an IANA tz string to a ZoneInfo, falling back to UTC on any error."""
|
||||
if not tz_name:
|
||||
@@ -44,6 +75,59 @@ def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
|
||||
return ZoneInfo("UTC")
|
||||
|
||||
|
||||
def quiet_hours_status(
|
||||
start: str | None,
|
||||
end: str | None,
|
||||
tz_name: str | None = "UTC",
|
||||
) -> datetime | None:
|
||||
"""Return the UTC datetime when the current quiet window ends, or None.
|
||||
|
||||
Returns ``None`` when:
|
||||
* either bound is missing,
|
||||
* the bounds are malformed,
|
||||
* the current local time is outside the configured window.
|
||||
|
||||
Returns a UTC ``datetime`` aligned to ``HH:MM`` (seconds=0, microseconds=0)
|
||||
representing the next end-of-window moment after "now" when the current
|
||||
time IS inside the window. For overnight windows (e.g. 22:00-06:00) the
|
||||
end may be tomorrow.
|
||||
"""
|
||||
if not start or not end:
|
||||
return None
|
||||
try:
|
||||
tz = _resolve_zoneinfo(tz_name)
|
||||
now_local = datetime.now(timezone.utc).astimezone(tz)
|
||||
t_start = time.fromisoformat(start)
|
||||
t_end = time.fromisoformat(end)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
# ``start == end`` (e.g. "00:00-00:00") has no consistent meaning: under
|
||||
# the normal-window branch the window is one instant wide; under the
|
||||
# overnight-window branch it's effectively always-on. Either is almost
|
||||
# certainly a user mistake, so treat it as "no window configured" rather
|
||||
# than silently deferring every notification all day.
|
||||
if t_start == t_end:
|
||||
return None
|
||||
|
||||
now_t = now_local.time()
|
||||
if t_start <= t_end:
|
||||
in_window = t_start <= now_t <= t_end
|
||||
else:
|
||||
in_window = now_t >= t_start or now_t <= t_end
|
||||
if not in_window:
|
||||
return None
|
||||
|
||||
end_today = now_local.replace(
|
||||
hour=t_end.hour, minute=t_end.minute, second=0, microsecond=0,
|
||||
)
|
||||
# If today's end already passed (overnight window, post-midnight half),
|
||||
# the actual end is tomorrow at the same wall-clock time.
|
||||
if end_today <= now_local:
|
||||
end_today = end_today + timedelta(days=1)
|
||||
return end_today.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def in_quiet_hours(
|
||||
start: str | None,
|
||||
end: str | None,
|
||||
@@ -51,23 +135,12 @@ def in_quiet_hours(
|
||||
) -> bool:
|
||||
"""Check if the current time (in the given timezone) is within the quiet window.
|
||||
|
||||
HH:MM strings are interpreted in the supplied timezone. If either bound is
|
||||
missing, quiet hours are disabled.
|
||||
Thin wrapper over ``quiet_hours_status`` preserved for back-compat with
|
||||
callers that only need the boolean. New code should prefer
|
||||
``quiet_hours_status`` (or ``evaluate_event_gate``) when the window end
|
||||
time matters.
|
||||
"""
|
||||
if not start or not end:
|
||||
return False
|
||||
try:
|
||||
tz = _resolve_zoneinfo(tz_name)
|
||||
now = datetime.now(timezone.utc).astimezone(tz).time()
|
||||
t_start = time.fromisoformat(start)
|
||||
t_end = time.fromisoformat(end)
|
||||
if t_start <= t_end:
|
||||
return t_start <= now <= t_end
|
||||
else:
|
||||
# Overnight window (e.g., 22:00 - 06:00)
|
||||
return now >= t_start or now <= t_end
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
return quiet_hours_status(start, end, tz_name) is not None
|
||||
|
||||
|
||||
async def get_app_timezone(session: AsyncSession) -> str:
|
||||
@@ -77,18 +150,13 @@ async def get_app_timezone(session: AsyncSession) -> str:
|
||||
return value or "UTC"
|
||||
|
||||
|
||||
def event_allowed_by_config(
|
||||
event: ServiceEvent,
|
||||
tc: TrackingConfig,
|
||||
tz_name: str | None = "UTC",
|
||||
) -> bool:
|
||||
"""Check if an event is allowed by the tracking config's flags + quiet hours."""
|
||||
# Quiet hours gate every event type when enabled.
|
||||
if tc.quiet_hours_enabled and in_quiet_hours(
|
||||
tc.quiet_hours_start, tc.quiet_hours_end, tz_name
|
||||
):
|
||||
return False
|
||||
def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
|
||||
"""Return True iff the tracking config's per-event-type flag allows this event.
|
||||
|
||||
Quiet hours are NOT considered here — this is the user's "do I care about
|
||||
this kind of event at all" gate. See ``evaluate_event_gate`` for the
|
||||
combined gate that also folds in quiet hours.
|
||||
"""
|
||||
event_type = event.event_type.value
|
||||
flag_map = {
|
||||
# Immich events
|
||||
@@ -140,6 +208,52 @@ def event_allowed_by_config(
|
||||
return flag_map.get(event_type, True)
|
||||
|
||||
|
||||
def evaluate_event_gate(
|
||||
event: ServiceEvent,
|
||||
tc: TrackingConfig,
|
||||
tz_name: str | None = "UTC",
|
||||
) -> GateOutcome:
|
||||
"""Decide whether an event should dispatch through the given tracking config.
|
||||
|
||||
Returns a :class:`GateOutcome` carrying both the verdict and — when blocked
|
||||
by quiet hours — the UTC datetime at which the window ends so the caller
|
||||
can schedule a deferred dispatch.
|
||||
|
||||
Order of checks: quiet hours first, then per-event-type flag. Quiet hours
|
||||
is the "louder" gate (it applies to every type), so reporting it first
|
||||
avoids the surprising case of "you disabled this event type" showing up
|
||||
when the user really just opened the quiet window.
|
||||
"""
|
||||
if tc.quiet_hours_enabled:
|
||||
end_at = quiet_hours_status(
|
||||
tc.quiet_hours_start, tc.quiet_hours_end, tz_name,
|
||||
)
|
||||
if end_at is not None:
|
||||
return GateOutcome(
|
||||
reason=GateReason.QUIET_HOURS,
|
||||
quiet_hours_end_at=end_at,
|
||||
)
|
||||
|
||||
if not _event_type_enabled(event, tc):
|
||||
return GateOutcome(reason=GateReason.EVENT_TYPE_DISABLED)
|
||||
|
||||
return GateOutcome(reason=GateReason.ALLOWED)
|
||||
|
||||
|
||||
def event_allowed_by_config(
|
||||
event: ServiceEvent,
|
||||
tc: TrackingConfig,
|
||||
tz_name: str | None = "UTC",
|
||||
) -> bool:
|
||||
"""Boolean back-compat wrapper around :func:`evaluate_event_gate`.
|
||||
|
||||
New call sites should use ``evaluate_event_gate`` directly so they can
|
||||
distinguish a quiet-hours suppression (deferrable) from an event-type
|
||||
disable (drop forever).
|
||||
"""
|
||||
return evaluate_event_gate(event, tc, tz_name).allowed
|
||||
|
||||
|
||||
# --- Display-time filters driven by TrackingConfig -------------------------
|
||||
#
|
||||
# These transform a ServiceEvent so the dispatched notification reflects the
|
||||
@@ -472,6 +586,7 @@ async def load_link_data(
|
||||
resolved = await _resolve_target(session, child_target)
|
||||
link_data.append({
|
||||
**resolved,
|
||||
"link_id": tt.id,
|
||||
"tracking_config": tracking_config,
|
||||
"template_config": template_config,
|
||||
"template_slots": template_slots,
|
||||
@@ -482,6 +597,7 @@ async def load_link_data(
|
||||
resolved = await _resolve_target(session, target)
|
||||
link_data.append({
|
||||
**resolved,
|
||||
"link_id": tt.id,
|
||||
"tracking_config": tracking_config,
|
||||
"template_config": template_config,
|
||||
"template_slots": template_slots,
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
"""Upstream release-check service.
|
||||
|
||||
Reads the configured release provider, asks it for the latest upstream release,
|
||||
and caches the result into :class:`AppSetting` rows so the API can serve the
|
||||
status without re-hitting the network. All failures are swallowed and surfaced
|
||||
through ``release_error`` — the server must stay up even if Gitea is down.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.release import (
|
||||
ReleaseErrorCode,
|
||||
ReleaseInfo,
|
||||
ReleaseProviderKind,
|
||||
build_release_provider,
|
||||
)
|
||||
from notify_bridge_core.release.base import is_newer
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..api.app_settings import get_setting
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import AppSetting
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Cached-state AppSetting keys (read by the API, written by the checker).
|
||||
KEY_LATEST_TAG = "release_latest_tag"
|
||||
KEY_LATEST_VERSION = "release_latest_version"
|
||||
KEY_LATEST_URL = "release_latest_url"
|
||||
KEY_LATEST_BODY = "release_latest_body"
|
||||
KEY_LATEST_NAME = "release_latest_name"
|
||||
KEY_LATEST_PUBLISHED_AT = "release_latest_published_at"
|
||||
KEY_LATEST_PRERELEASE = "release_latest_prerelease"
|
||||
KEY_CHECKED_AT = "release_checked_at"
|
||||
KEY_ERROR = "release_error"
|
||||
|
||||
# Operator-configured keys.
|
||||
KEY_PROVIDER_KIND = "release_provider_kind"
|
||||
KEY_PROVIDER_URL = "release_provider_url"
|
||||
KEY_PROVIDER_REPO = "release_provider_repo"
|
||||
KEY_INCLUDE_PRERELEASES = "release_include_prereleases"
|
||||
KEY_CHECK_INTERVAL_HOURS = "release_check_interval_hours"
|
||||
|
||||
# Allowed range for the interval (matches the UI hint).
|
||||
INTERVAL_MIN_HOURS = 1
|
||||
INTERVAL_MAX_HOURS = 168
|
||||
|
||||
# Minimum gap between checks. Independent of the configured interval — a flood
|
||||
# of /release/check API calls or scheduler misfires can't push real load on
|
||||
# upstream Gitea within this window.
|
||||
_MIN_CHECK_INTERVAL = timedelta(seconds=30)
|
||||
|
||||
# Serialises concurrent run_check invocations (scheduled job + manual force
|
||||
# check + provider-changed save can all fire close together).
|
||||
_run_lock = asyncio.Lock()
|
||||
|
||||
_CACHED_KEYS = (
|
||||
KEY_LATEST_TAG,
|
||||
KEY_LATEST_VERSION,
|
||||
KEY_LATEST_URL,
|
||||
KEY_LATEST_BODY,
|
||||
KEY_LATEST_NAME,
|
||||
KEY_LATEST_PUBLISHED_AT,
|
||||
KEY_LATEST_PRERELEASE,
|
||||
KEY_CHECKED_AT,
|
||||
KEY_ERROR,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReleaseStatus:
|
||||
"""Snapshot returned by :func:`load_status` and friends."""
|
||||
|
||||
provider: str
|
||||
current: str
|
||||
latest: str | None
|
||||
latest_tag: str | None
|
||||
latest_url: str | None
|
||||
latest_body: str | None
|
||||
latest_name: str | None
|
||||
latest_published_at: str | None
|
||||
latest_prerelease: bool
|
||||
checked_at: str | None
|
||||
update_available: bool
|
||||
error: str | None
|
||||
|
||||
|
||||
def _server_version() -> str:
|
||||
"""Resolve the running server version (delegates to the shared helper).
|
||||
|
||||
Routed through :mod:`notify_bridge_server.version` so the "current" the
|
||||
UI reports matches `/api/health` and is robust to stale editable installs.
|
||||
"""
|
||||
from ..version import resolve_version
|
||||
|
||||
return resolve_version()
|
||||
|
||||
|
||||
def parse_interval_hours(raw: str | None, default: int = 12) -> int:
|
||||
"""Clamp/parse the interval setting into a sensible integer."""
|
||||
|
||||
try:
|
||||
value = int((raw or "").strip() or default)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return max(INTERVAL_MIN_HOURS, min(INTERVAL_MAX_HOURS, value))
|
||||
|
||||
|
||||
def _coerce_provider_kind(raw: str | None) -> str:
|
||||
"""Normalise the stored kind to a known enum value (default: disabled)."""
|
||||
try:
|
||||
return ReleaseProviderKind(raw or "").value
|
||||
except ValueError:
|
||||
return ReleaseProviderKind.DISABLED.value
|
||||
|
||||
|
||||
async def load_status() -> ReleaseStatus:
|
||||
"""Read the latest cached status without performing a network call."""
|
||||
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
provider = await get_setting(session, KEY_PROVIDER_KIND)
|
||||
latest_tag = await get_setting(session, KEY_LATEST_TAG)
|
||||
latest_version = await get_setting(session, KEY_LATEST_VERSION)
|
||||
latest_url = await get_setting(session, KEY_LATEST_URL)
|
||||
latest_body = await get_setting(session, KEY_LATEST_BODY)
|
||||
latest_name = await get_setting(session, KEY_LATEST_NAME)
|
||||
latest_published_at = await get_setting(session, KEY_LATEST_PUBLISHED_AT)
|
||||
latest_prerelease = await get_setting(session, KEY_LATEST_PRERELEASE)
|
||||
checked_at = await get_setting(session, KEY_CHECKED_AT)
|
||||
error = await get_setting(session, KEY_ERROR)
|
||||
|
||||
current = _server_version()
|
||||
has_latest = bool(latest_version)
|
||||
update_available = bool(has_latest and is_newer(latest_version, current))
|
||||
return ReleaseStatus(
|
||||
provider=_coerce_provider_kind(provider),
|
||||
current=current,
|
||||
latest=latest_version or None,
|
||||
latest_tag=latest_tag or None,
|
||||
latest_url=latest_url or None,
|
||||
latest_body=latest_body or None,
|
||||
latest_name=latest_name or None,
|
||||
latest_published_at=latest_published_at or None,
|
||||
latest_prerelease=latest_prerelease == "1",
|
||||
checked_at=checked_at or None,
|
||||
update_available=update_available,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
|
||||
async def run_check(*, force: bool = False) -> ReleaseStatus:
|
||||
"""Hit the configured provider and persist the result, then return it.
|
||||
|
||||
Args:
|
||||
force: bypass the per-process rate limit. Used by the manual
|
||||
"Check now" admin action; the scheduled probe never forces.
|
||||
"""
|
||||
async with _run_lock:
|
||||
return await _run_check_locked(force=force)
|
||||
|
||||
|
||||
async def _run_check_locked(*, force: bool) -> ReleaseStatus:
|
||||
from .http_session import get_http_session
|
||||
|
||||
# Throttle: if the last check landed within _MIN_CHECK_INTERVAL and the
|
||||
# caller didn't ask for force, skip the network round-trip and return the
|
||||
# cached status. Force is still gated by the lock above, so an abusive
|
||||
# admin spamming /release/check serialises to one in-flight at a time.
|
||||
if not force:
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
last = await get_setting(session, KEY_CHECKED_AT)
|
||||
if last:
|
||||
try:
|
||||
last_dt = datetime.fromisoformat(last)
|
||||
if datetime.now(timezone.utc) - last_dt < _MIN_CHECK_INTERVAL:
|
||||
return await load_status()
|
||||
except ValueError:
|
||||
pass # corrupted timestamp → fall through and overwrite
|
||||
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
provider_kind = await get_setting(session, KEY_PROVIDER_KIND)
|
||||
provider_url = await get_setting(session, KEY_PROVIDER_URL)
|
||||
provider_repo = await get_setting(session, KEY_PROVIDER_REPO)
|
||||
include_prereleases = (await get_setting(session, KEY_INCLUDE_PRERELEASES)) == "1"
|
||||
|
||||
http = await get_http_session()
|
||||
provider = build_release_provider(
|
||||
provider_kind or ReleaseProviderKind.DISABLED.value,
|
||||
session=http,
|
||||
url=provider_url,
|
||||
repo=provider_repo,
|
||||
)
|
||||
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
if provider is None:
|
||||
# Disabled (no error to surface) vs misconfigured (operator action
|
||||
# required) are different states — the UI distinguishes them.
|
||||
kind = _coerce_provider_kind(provider_kind)
|
||||
err = (
|
||||
ReleaseErrorCode.DISABLED.value
|
||||
if kind == ReleaseProviderKind.DISABLED.value
|
||||
else ReleaseErrorCode.MISCONFIGURED.value
|
||||
)
|
||||
await persist_release_state(checked_at=timestamp, error=err, info=None)
|
||||
return await load_status()
|
||||
|
||||
try:
|
||||
info = await provider.fetch_latest(include_prereleases=include_prereleases)
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
|
||||
_LOGGER.warning("Release provider network error: %s", err)
|
||||
await persist_release_state(
|
||||
checked_at=timestamp,
|
||||
error=ReleaseErrorCode.NETWORK_ERROR.value,
|
||||
info=None,
|
||||
)
|
||||
return await load_status()
|
||||
except ValueError as err:
|
||||
_LOGGER.warning("Release provider parse/validation error: %s", err)
|
||||
await persist_release_state(
|
||||
checked_at=timestamp,
|
||||
error=ReleaseErrorCode.PARSE_ERROR.value,
|
||||
info=None,
|
||||
)
|
||||
return await load_status()
|
||||
|
||||
if info is None:
|
||||
await persist_release_state(
|
||||
checked_at=timestamp,
|
||||
error=ReleaseErrorCode.NO_RELEASE_FOUND.value,
|
||||
info=None,
|
||||
)
|
||||
return await load_status()
|
||||
|
||||
await persist_release_state(checked_at=timestamp, error=None, info=info)
|
||||
return await load_status()
|
||||
|
||||
|
||||
async def persist_release_state(
|
||||
*,
|
||||
checked_at: str,
|
||||
error: str | None,
|
||||
info: ReleaseInfo | None,
|
||||
) -> None:
|
||||
"""Write all cached-state keys in one transaction.
|
||||
|
||||
Public because the settings PUT handler invokes it to flush stale cache
|
||||
when the operator points the provider at a different repo — we don't want
|
||||
the previous repo's "latest" to keep advertising as available.
|
||||
"""
|
||||
|
||||
if info is None:
|
||||
rows: dict[str, str] = {
|
||||
KEY_LATEST_TAG: "",
|
||||
KEY_LATEST_VERSION: "",
|
||||
KEY_LATEST_URL: "",
|
||||
KEY_LATEST_BODY: "",
|
||||
KEY_LATEST_NAME: "",
|
||||
KEY_LATEST_PUBLISHED_AT: "",
|
||||
KEY_LATEST_PRERELEASE: "0",
|
||||
}
|
||||
else:
|
||||
rows = {
|
||||
KEY_LATEST_TAG: info.tag,
|
||||
KEY_LATEST_VERSION: info.version,
|
||||
KEY_LATEST_URL: info.url or "",
|
||||
KEY_LATEST_BODY: info.body or "",
|
||||
KEY_LATEST_NAME: info.name or "",
|
||||
KEY_LATEST_PUBLISHED_AT: info.published_at or "",
|
||||
KEY_LATEST_PRERELEASE: "1" if info.prerelease else "0",
|
||||
}
|
||||
rows[KEY_CHECKED_AT] = checked_at
|
||||
rows[KEY_ERROR] = error or ""
|
||||
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
for key, value in rows.items():
|
||||
row = await session.get(AppSetting, key)
|
||||
if row:
|
||||
row.value = value
|
||||
else:
|
||||
row = AppSetting(key=key, value=value)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
|
||||
def cached_keys() -> tuple[str, ...]:
|
||||
"""Return the keys the checker writes — used by API masking helpers."""
|
||||
return _CACHED_KEYS
|
||||
@@ -42,8 +42,9 @@ from ..database.models import (
|
||||
TrackingConfig,
|
||||
)
|
||||
from .dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
event_allowed_by_config,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
@@ -262,7 +263,11 @@ async def dispatch_scheduled_for_tracker(
|
||||
if tc is not None:
|
||||
if not getattr(tc, f"{kind}_enabled", True):
|
||||
continue
|
||||
if not event_allowed_by_config(event, tc, app_tz):
|
||||
# Scheduled / periodic / memory dispatches are wall-clock
|
||||
# by nature — a "good morning" delivered at 3 pm is wrong,
|
||||
# so quiet hours = drop (not defer) for these kinds. The
|
||||
# other gate (per-event-type flag) still applies.
|
||||
if not evaluate_event_gate(event, tc, app_tz).allowed:
|
||||
continue
|
||||
if tmpl is None:
|
||||
continue
|
||||
|
||||
@@ -153,6 +153,16 @@ async def start_scheduler() -> None:
|
||||
# Load scheduled backup job if enabled
|
||||
await _load_backup_job()
|
||||
|
||||
# Re-arm any deferred-dispatch drains that were pending across restart.
|
||||
from .deferred_dispatch import load_pending_drain_jobs
|
||||
await load_pending_drain_jobs()
|
||||
|
||||
# And install the periodic safety-net catch-up scan.
|
||||
_schedule_drain_catchup()
|
||||
|
||||
# Schedule the upstream release-check probe.
|
||||
await _schedule_release_check()
|
||||
|
||||
|
||||
def _schedule_event_cleanup() -> None:
|
||||
"""Schedule a daily job to delete EventLog entries older than 90 days."""
|
||||
@@ -1079,6 +1089,129 @@ async def unschedule_backup() -> None:
|
||||
_LOGGER.info("Unscheduled backup job")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deferred-dispatch drain
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# When ``defer_event`` enqueues a quiet-hours notification, the calling site
|
||||
# asks us to add a one-shot ``date`` job at ``quiet_hours_end_at``. We key the
|
||||
# job id by the minute-rounded end time so multiple defers that share the same
|
||||
# window-end share a single drain job (idempotent via ``replace_existing``).
|
||||
#
|
||||
# At fire time the job runs ``drain_deferred_due`` which scans all pending
|
||||
# rows and dispatches whatever is ready.
|
||||
#
|
||||
# A periodic catch-up scan runs every ``_DRAIN_CATCHUP_INTERVAL_SECONDS`` as
|
||||
# the safety net for failure modes the one-shot job can't cover:
|
||||
# * APScheduler's misfire grace exceeded (event loop blocked past fire_at;
|
||||
# the date job is silently discarded by the scheduler)
|
||||
# * Process killed between the deferred-row DB commit and the
|
||||
# ``schedule_deferred_drain`` call — row exists, job doesn't
|
||||
# * Clock drift / DST seam edge cases
|
||||
|
||||
_DEFERRED_DRAIN_PREFIX = "deferred_drain_"
|
||||
_DEFERRED_DRAIN_CATCHUP_JOB = "deferred_drain_catchup"
|
||||
# Generous so a temporarily-blocked event loop doesn't make the scheduler
|
||||
# discard our drain job. Once discarded the deferred rows would wait for the
|
||||
# next process restart or the catch-up scan below — survivable but visibly
|
||||
# late from the user's perspective.
|
||||
_DEFERRED_DRAIN_MISFIRE_GRACE_SECONDS = 3600
|
||||
# 5 min trade-off between "promptness of late delivery" and "extra DB churn".
|
||||
# The scan is a single indexed lookup on (status, fire_at).
|
||||
_DRAIN_CATCHUP_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
def _drain_job_id_for(fire_at_utc: datetime) -> str:
|
||||
return f"{_DEFERRED_DRAIN_PREFIX}{fire_at_utc.strftime('%Y%m%d%H%M')}"
|
||||
|
||||
|
||||
def schedule_deferred_drain(fire_at_utc: datetime) -> None:
|
||||
"""Add an idempotent one-shot drain job for ``fire_at_utc``.
|
||||
|
||||
Past times schedule a near-immediate firing (now+1s) — the drain query
|
||||
handles ``fire_at <= now`` regardless of which job fired, so a near-miss
|
||||
still picks up the work.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
if fire_at_utc.tzinfo is None:
|
||||
fire_at_utc = fire_at_utc.replace(tzinfo=timezone.utc)
|
||||
|
||||
scheduler = get_scheduler()
|
||||
job_id = _drain_job_id_for(fire_at_utc)
|
||||
run_at = fire_at_utc
|
||||
if run_at <= datetime.now(timezone.utc):
|
||||
from datetime import timedelta
|
||||
run_at = datetime.now(timezone.utc) + timedelta(seconds=1)
|
||||
|
||||
scheduler.add_job(
|
||||
_run_deferred_drain,
|
||||
"date",
|
||||
run_date=run_at,
|
||||
id=job_id,
|
||||
args=[fire_at_utc.isoformat()],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
# Override the global 5-min grace — see module-level comment.
|
||||
misfire_grace_time=_DEFERRED_DRAIN_MISFIRE_GRACE_SECONDS,
|
||||
)
|
||||
_LOGGER.debug("Scheduled deferred drain %s (fire_at=%s)", job_id, fire_at_utc.isoformat())
|
||||
|
||||
|
||||
def _schedule_drain_catchup() -> None:
|
||||
"""Install the periodic catch-up scan. See module comment."""
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.get_job(_DEFERRED_DRAIN_CATCHUP_JOB):
|
||||
return
|
||||
scheduler.add_job(
|
||||
_run_deferred_drain_catchup,
|
||||
IntervalTrigger(seconds=_DRAIN_CATCHUP_INTERVAL_SECONDS),
|
||||
id=_DEFERRED_DRAIN_CATCHUP_JOB,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True,
|
||||
)
|
||||
_LOGGER.info(
|
||||
"Scheduled deferred-dispatch catch-up scan every %ds",
|
||||
_DRAIN_CATCHUP_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
async def _run_deferred_drain(fire_at_iso: str) -> None:
|
||||
"""APScheduler entry point — log the original fire_at then drain due rows.
|
||||
|
||||
The ``fire_at_iso`` arg is only used for logging; the drain itself picks
|
||||
up every pending row whose ``fire_at`` has passed.
|
||||
"""
|
||||
from .deferred_dispatch import drain_deferred_due
|
||||
try:
|
||||
stats = await drain_deferred_due()
|
||||
_LOGGER.info("Deferred drain (fire_at=%s) stats: %s", fire_at_iso, stats)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception("Deferred drain (fire_at=%s) failed: %s", fire_at_iso, err)
|
||||
|
||||
|
||||
async def _run_deferred_drain_catchup() -> None:
|
||||
"""Periodic safety-net drain — see module comment.
|
||||
|
||||
Distinct from the per-fire-at job only in cadence and log line; calls the
|
||||
same ``drain_deferred_due`` which is a no-op when nothing is due.
|
||||
"""
|
||||
from .deferred_dispatch import drain_deferred_due
|
||||
try:
|
||||
stats = await drain_deferred_due()
|
||||
# Quiet at debug level when nothing happened — every 5 min is too
|
||||
# noisy at info on an idle system.
|
||||
if stats.get("fired") or stats.get("dropped") or stats.get("errors"):
|
||||
_LOGGER.info("Deferred catch-up stats: %s", stats)
|
||||
else:
|
||||
_LOGGER.debug("Deferred catch-up stats: %s", stats)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception("Deferred catch-up drain failed: %s", err)
|
||||
|
||||
|
||||
async def _run_scheduled_backup() -> None:
|
||||
"""Run a scheduled backup (called by APScheduler)."""
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession as _AS
|
||||
@@ -1116,3 +1249,66 @@ async def _run_scheduled_backup() -> None:
|
||||
|
||||
except Exception as e:
|
||||
_LOGGER.error("Scheduled backup failed: %s", e)
|
||||
|
||||
|
||||
# --- Release-check probe -----------------------------------------------------
|
||||
|
||||
_RELEASE_CHECK_JOB_ID = "upstream_release_check"
|
||||
_RELEASE_CHECK_ONESHOT_JOB_ID = "upstream_release_check_oneshot"
|
||||
_RELEASE_CHECK_ONESHOT_DELAY_SECONDS = 30
|
||||
|
||||
|
||||
async def _schedule_release_check() -> None:
|
||||
"""Register the interval + one-shot release-check jobs.
|
||||
|
||||
Reads the configured interval from AppSettings at startup. Idempotent —
|
||||
APScheduler de-dupes via ``replace_existing=True``.
|
||||
"""
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..api.app_settings import get_setting
|
||||
from ..database.engine import get_engine
|
||||
from .release_check import parse_interval_hours, run_check
|
||||
|
||||
async with AsyncSession(get_engine()) as session:
|
||||
raw = await get_setting(session, "release_check_interval_hours")
|
||||
interval_hours = parse_interval_hours(raw)
|
||||
|
||||
scheduler = get_scheduler()
|
||||
scheduler.add_job(
|
||||
run_check,
|
||||
IntervalTrigger(hours=interval_hours),
|
||||
id=_RELEASE_CHECK_JOB_ID,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
# One-shot probe shortly after start so admins see a fresh status without
|
||||
# waiting for the first interval tick. Mirrors the chat-title sync.
|
||||
scheduler.add_job(
|
||||
run_check,
|
||||
"date",
|
||||
run_date=datetime.now(timezone.utc) + timedelta(seconds=_RELEASE_CHECK_ONESHOT_DELAY_SECONDS),
|
||||
id=_RELEASE_CHECK_ONESHOT_JOB_ID,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled release-check every %sh (one-shot in %ss)",
|
||||
interval_hours, _RELEASE_CHECK_ONESHOT_DELAY_SECONDS)
|
||||
|
||||
|
||||
async def reschedule_release_check() -> None:
|
||||
"""Re-arm the release-check job after settings changed.
|
||||
|
||||
Called from the PUT /settings handler when the interval or provider config
|
||||
changes. Removes the existing interval job, lets ``_schedule_release_check``
|
||||
re-read the setting and rebuild it, and queues a fresh one-shot so the new
|
||||
config takes effect within seconds rather than at the next interval tick.
|
||||
"""
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.get_job(_RELEASE_CHECK_JOB_ID):
|
||||
scheduler.remove_job(_RELEASE_CHECK_JOB_ID)
|
||||
if scheduler.get_job(_RELEASE_CHECK_ONESHOT_JOB_ID):
|
||||
scheduler.remove_job(_RELEASE_CHECK_ONESHOT_JOB_ID)
|
||||
await _schedule_release_check()
|
||||
|
||||
@@ -22,8 +22,9 @@ from ..database.models import (
|
||||
ServiceProvider,
|
||||
)
|
||||
from .dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
event_allowed_by_config,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
@@ -205,11 +206,16 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
# Load app-level timezone for quiet-hours evaluation.
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
# Snapshot the data we need
|
||||
# Snapshot the data we need. These reads happen INSIDE the open
|
||||
# session so we get fresh attribute values; once the block exits, the
|
||||
# ORM instances become detached and any unfetched attribute access
|
||||
# would raise. Pulling primitives here is the deliberate isolation
|
||||
# boundary between the DB phase and the network phase.
|
||||
provider_type = provider.type
|
||||
provider_config = dict(provider.config)
|
||||
provider_name = provider.name
|
||||
tracker_name = tracker.name
|
||||
tracker_user_id = tracker.user_id
|
||||
tracker_filters = dict(tracker.filters) if tracker.filters else {}
|
||||
collection_ids = list(tracker.collection_ids or [])
|
||||
|
||||
@@ -317,6 +323,10 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
)
|
||||
session.add(new_ts)
|
||||
|
||||
# Capture the event_log row id alongside each event so the dispatch
|
||||
# loop below can stamp a "dispatch_status=deferred" pointer onto the
|
||||
# row if quiet hours suppresses it.
|
||||
event_log_id_by_event: dict[int, int] = {}
|
||||
for event in events:
|
||||
assets_count = event.added_count or event.removed_count or 0
|
||||
details: dict[str, Any] = {
|
||||
@@ -352,6 +362,8 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
details=details,
|
||||
)
|
||||
session.add(log)
|
||||
await session.flush()
|
||||
event_log_id_by_event[id(event)] = log.id
|
||||
|
||||
await session.commit()
|
||||
|
||||
@@ -377,21 +389,54 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
asset_cache=asset_cache,
|
||||
session=shared_session,
|
||||
)
|
||||
from .deferred_dispatch import defer_event, is_deferrable
|
||||
from .scheduler import schedule_deferred_drain
|
||||
from ..database.models import EventLog as _EventLog
|
||||
|
||||
for event in events:
|
||||
_LOGGER.info(
|
||||
"Dispatching event %s for %s (added=%d removed=%d)",
|
||||
event.event_type.value, event.collection_name,
|
||||
event.added_count, event.removed_count,
|
||||
)
|
||||
event_log_id = event_log_id_by_event.get(id(event))
|
||||
# Group targets by tracking-config identity so each unique TC
|
||||
# gets one event-transform pass; targets sharing a TC dispatch
|
||||
# together (preserves the gather-fan-out inside the dispatcher).
|
||||
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
|
||||
# Track defers in a single dict so we can persist them in one
|
||||
# session + commit at the end of the iteration. ``load_link_data``
|
||||
# emits multiple entries per broadcast link (one per child) sharing
|
||||
# the same parent ``link_id``; the deferred row is one-per-link, so
|
||||
# ``dict`` keying by ``link_id`` naturally dedupes.
|
||||
defers_for_event: dict[int, datetime] = {}
|
||||
scheduled_until: datetime | None = None
|
||||
|
||||
for ld in link_data:
|
||||
tc = ld["tracking_config"]
|
||||
if tc and not event_allowed_by_config(event, tc, app_tz):
|
||||
_LOGGER.info(" Skipped by tracking config filter")
|
||||
continue
|
||||
if tc is not None:
|
||||
outcome = evaluate_event_gate(event, tc, app_tz)
|
||||
if outcome.reason is GateReason.QUIET_HOURS:
|
||||
if is_deferrable(event.event_type.value) and outcome.quiet_hours_end_at is not None:
|
||||
link_id = ld.get("link_id")
|
||||
if link_id is not None:
|
||||
# Per-link earliest fire_at wins if a future
|
||||
# iteration ever supplies a different end.
|
||||
prior = defers_for_event.get(link_id)
|
||||
if prior is None or outcome.quiet_hours_end_at < prior:
|
||||
defers_for_event[link_id] = outcome.quiet_hours_end_at
|
||||
_LOGGER.info(
|
||||
" Deferred until %s (quiet hours)",
|
||||
outcome.quiet_hours_end_at.isoformat() if outcome.quiet_hours_end_at else "?",
|
||||
)
|
||||
else:
|
||||
_LOGGER.info(
|
||||
" Suppressed (quiet hours; event type not deferrable)",
|
||||
)
|
||||
continue
|
||||
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
|
||||
_LOGGER.info(" Skipped by tracking config filter")
|
||||
continue
|
||||
|
||||
tmpl = ld["template_config"]
|
||||
target_cfg = TargetConfig(
|
||||
@@ -410,6 +455,47 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
|
||||
# Persist defers + stamp the event_log row + schedule drains in a
|
||||
# single transaction. This keeps the "deferred" pill on the
|
||||
# dashboard consistent with the existence of pending rows even if
|
||||
# the process is killed mid-way (either both land or neither does).
|
||||
if defers_for_event:
|
||||
async with AsyncSession(engine) as defer_session:
|
||||
for link_id, fire_at in defers_for_event.items():
|
||||
await defer_event(
|
||||
defer_session,
|
||||
event=event,
|
||||
user_id=tracker_user_id,
|
||||
tracker_id=tracker_id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
fire_at=fire_at,
|
||||
)
|
||||
if scheduled_until is None or fire_at < scheduled_until:
|
||||
scheduled_until = fire_at
|
||||
# Stamp event_log row inside the SAME session so the
|
||||
# "deferred until" pill is only visible if the rows
|
||||
# actually persist.
|
||||
if event_log_id is not None and scheduled_until is not None:
|
||||
el = await defer_session.get(_EventLog, event_log_id)
|
||||
if el is not None:
|
||||
existing = dict(el.details or {})
|
||||
if not existing.get("dispatch_status"):
|
||||
existing["dispatch_status"] = "deferred"
|
||||
existing["deferred_until"] = scheduled_until.isoformat()
|
||||
el.details = existing
|
||||
defer_session.add(el)
|
||||
await defer_session.commit()
|
||||
# Drain job registration is best-effort: a failure here just
|
||||
# delays delivery until the next scan/restart, not data loss.
|
||||
for fire_at in {*defers_for_event.values()}:
|
||||
try:
|
||||
schedule_deferred_drain(fire_at)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to schedule deferred drain for %s", fire_at,
|
||||
)
|
||||
|
||||
for tc, target_configs in groups.values():
|
||||
if not target_configs:
|
||||
continue
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
"""Server version resolution.
|
||||
|
||||
Production Docker images install the wheel and ``importlib.metadata`` is the
|
||||
truth. Editable dev installs (``pip install -e packages/server``) record the
|
||||
version at install time and *don't auto-refresh* when the source ``pyproject.toml``
|
||||
bumps — so a developer that bumped from 0.3.x to 0.7.x without reinstalling
|
||||
will keep reporting 0.3.x via ``importlib.metadata``.
|
||||
|
||||
To make the running app match the source tree without forcing a reinstall,
|
||||
we read both and return the higher of the two. The dist-info wins in prod
|
||||
(no pyproject alongside), the source wins in dev when the editable install is
|
||||
stale.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from importlib.metadata import PackageNotFoundError, version as _pkg_version
|
||||
from pathlib import Path
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PACKAGE_NAME = "notify-bridge-server"
|
||||
_UNKNOWN = "0.0.0+unknown"
|
||||
|
||||
|
||||
def _read_source_version() -> str | None:
|
||||
"""Best-effort read of the source ``pyproject.toml`` version.
|
||||
|
||||
Returns ``None`` when the file isn't reachable (the normal prod case),
|
||||
so callers fall back to the installed metadata.
|
||||
"""
|
||||
# Module is at packages/server/src/notify_bridge_server/version.py,
|
||||
# pyproject sits at packages/server/pyproject.toml — three parents up.
|
||||
pyproject = Path(__file__).resolve().parents[2] / "pyproject.toml"
|
||||
if not pyproject.is_file():
|
||||
return None
|
||||
try:
|
||||
import tomllib # Python 3.11+ stdlib — server requires 3.12.
|
||||
|
||||
data = tomllib.loads(pyproject.read_text(encoding="utf-8"))
|
||||
version = data.get("project", {}).get("version")
|
||||
return str(version) if version else None
|
||||
except (OSError, ValueError) as err: # pragma: no cover — defensive
|
||||
_LOGGER.debug("Could not read source pyproject version: %s", err)
|
||||
return None
|
||||
|
||||
|
||||
def _segments(version: str) -> tuple[int, ...]:
|
||||
"""Best-effort tuple-of-ints for ordering. Suffixes (``-rc1``) are stripped."""
|
||||
if not version:
|
||||
return ()
|
||||
head = version.split("+", 1)[0].split("-", 1)[0]
|
||||
out: list[int] = []
|
||||
for piece in head.split("."):
|
||||
digits = "".join(c for c in piece if c.isdigit())
|
||||
if digits:
|
||||
out.append(int(digits))
|
||||
return tuple(out)
|
||||
|
||||
|
||||
def resolve_version() -> str:
|
||||
"""Return the version the running server should advertise.
|
||||
|
||||
Prefers the highest of (installed metadata, source pyproject) so an
|
||||
out-of-date editable install never lies to the UI. In production builds
|
||||
only the installed metadata is available, which is correct by definition.
|
||||
"""
|
||||
try:
|
||||
installed: str | None = _pkg_version(_PACKAGE_NAME)
|
||||
except PackageNotFoundError:
|
||||
installed = None
|
||||
source = _read_source_version()
|
||||
|
||||
candidates = [v for v in (installed, source) if v]
|
||||
if not candidates:
|
||||
return _UNKNOWN
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
# Two candidates — return the higher by numeric segments. Ties: prefer
|
||||
# source, since that's what the developer just edited.
|
||||
a, b = candidates
|
||||
return a if _segments(a) > _segments(b) else b
|
||||
@@ -0,0 +1,431 @@
|
||||
"""Tests for the quiet-hours deferred-dispatch pipeline.
|
||||
|
||||
Covers the four behaviours that distinguish the new feature from the legacy
|
||||
"drop on quiet hours" code path:
|
||||
|
||||
1. ``quiet_hours_status`` returns the correct UTC end datetime, including
|
||||
overnight windows that wrap past midnight.
|
||||
2. ``evaluate_event_gate`` distinguishes ``QUIET_HOURS`` (deferrable) from
|
||||
``EVENT_TYPE_DISABLED`` (drop forever).
|
||||
3. ``serialize_event`` / ``deserialize_event`` round-trip without losing
|
||||
asset metadata.
|
||||
4. ``defer_event`` coalesces ``assets_added`` + ``assets_removed`` of the
|
||||
same IDs for the same link+collection — the cancellation case that
|
||||
motivated the whole feature.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||
from notify_bridge_core.models.media import MediaAsset, MediaType
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Quiet-hours math
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_quiet_hours_status_inside_normal_window(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
# Pretend it's 13:00 UTC inside a 12:00-14:00 window.
|
||||
class _FixedDatetime(datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return datetime(2026, 5, 12, 13, 0, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr(dh, "datetime", _FixedDatetime)
|
||||
end_at = dh.quiet_hours_status("12:00", "14:00", "UTC")
|
||||
assert end_at == datetime(2026, 5, 12, 14, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_quiet_hours_status_start_equals_end_returns_none() -> None:
|
||||
"""``00:00-00:00`` is ambiguous (single instant vs always-on); treat as no window.
|
||||
|
||||
Code-review feedback: without this guard, the overnight-window branch would
|
||||
interpret it as "always quiet" and silently defer every notification all
|
||||
day. The conservative read is that the user misconfigured and we should
|
||||
behave as if quiet hours were off.
|
||||
"""
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
assert dh.quiet_hours_status("00:00", "00:00", "UTC") is None
|
||||
assert dh.quiet_hours_status("13:30", "13:30", "UTC") is None
|
||||
|
||||
|
||||
def test_quiet_hours_status_outside_window_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
class _FixedDatetime(datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return datetime(2026, 5, 12, 15, 0, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr(dh, "datetime", _FixedDatetime)
|
||||
assert dh.quiet_hours_status("12:00", "14:00", "UTC") is None
|
||||
|
||||
|
||||
def test_quiet_hours_status_overnight_window_post_midnight(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""22:00-06:00 window, current time 03:00 → window ends today at 06:00."""
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
class _FixedDatetime(datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return datetime(2026, 5, 12, 3, 0, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr(dh, "datetime", _FixedDatetime)
|
||||
end_at = dh.quiet_hours_status("22:00", "06:00", "UTC")
|
||||
assert end_at == datetime(2026, 5, 12, 6, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_quiet_hours_status_overnight_window_pre_midnight(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""22:00-06:00 window, current time 23:30 → window ends tomorrow at 06:00."""
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
class _FixedDatetime(datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return datetime(2026, 5, 12, 23, 30, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr(dh, "datetime", _FixedDatetime)
|
||||
end_at = dh.quiet_hours_status("22:00", "06:00", "UTC")
|
||||
assert end_at == datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gate enum / outcome
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_event(
|
||||
event_type: EventType = EventType.ASSETS_ADDED,
|
||||
*,
|
||||
added_assets: list[MediaAsset] | None = None,
|
||||
) -> ServiceEvent:
|
||||
return ServiceEvent(
|
||||
event_type=event_type,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name="test-immich",
|
||||
collection_id="col-1",
|
||||
collection_name="Album A",
|
||||
timestamp=datetime(2026, 5, 12, 12, 0, tzinfo=timezone.utc),
|
||||
added_assets=added_assets or [],
|
||||
added_count=len(added_assets or []),
|
||||
)
|
||||
|
||||
|
||||
def _make_asset(asset_id: str, *, filename: str | None = None) -> MediaAsset:
|
||||
return MediaAsset(
|
||||
id=asset_id,
|
||||
type=MediaType.IMAGE,
|
||||
filename=filename or f"{asset_id}.jpg",
|
||||
created_at=datetime(2026, 5, 12, 12, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class _FakeTrackingConfig:
|
||||
"""Minimal stand-in for TrackingConfig — only the fields the gate reads."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
quiet_hours_enabled: bool = False,
|
||||
quiet_hours_start: str | None = None,
|
||||
quiet_hours_end: str | None = None,
|
||||
track_assets_added: bool = True,
|
||||
) -> None:
|
||||
self.quiet_hours_enabled = quiet_hours_enabled
|
||||
self.quiet_hours_start = quiet_hours_start
|
||||
self.quiet_hours_end = quiet_hours_end
|
||||
self.track_assets_added = track_assets_added
|
||||
# The gate's flag map reads every track_* attribute; set the rest to
|
||||
# True so it doesn't accidentally block on an unrelated event type.
|
||||
for attr in (
|
||||
"track_assets_removed", "track_collection_renamed",
|
||||
"track_collection_deleted", "track_sharing_changed",
|
||||
"track_push", "track_issue_opened", "track_issue_closed",
|
||||
"track_issue_commented", "track_pr_opened", "track_pr_closed",
|
||||
"track_pr_merged", "track_pr_commented", "track_release_published",
|
||||
"track_card_created", "track_card_updated", "track_card_moved",
|
||||
"track_card_deleted", "track_card_commented", "track_comment_updated",
|
||||
"track_board_created", "track_board_updated", "track_board_deleted",
|
||||
"track_list_created", "track_list_updated", "track_list_deleted",
|
||||
"track_attachment_created", "track_card_label_added",
|
||||
"track_task_completed", "track_scheduled_message",
|
||||
"track_webhook_received", "track_ups_online", "track_ups_on_battery",
|
||||
"track_ups_low_battery", "track_ups_battery_restored",
|
||||
"track_ups_comms_lost", "track_ups_comms_restored",
|
||||
"track_ups_replace_battery", "track_ups_overload",
|
||||
):
|
||||
setattr(self, attr, True)
|
||||
|
||||
|
||||
def test_gate_quiet_hours_wins_over_event_type_flag(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
class _FixedDatetime(datetime):
|
||||
@classmethod
|
||||
def now(cls, tz=None):
|
||||
return datetime(2026, 5, 12, 13, 0, tzinfo=timezone.utc)
|
||||
|
||||
monkeypatch.setattr(dh, "datetime", _FixedDatetime)
|
||||
tc = _FakeTrackingConfig(
|
||||
quiet_hours_enabled=True,
|
||||
quiet_hours_start="12:00",
|
||||
quiet_hours_end="14:00",
|
||||
# Even with the event-type flag flipped off, quiet hours should be
|
||||
# the reported reason — it's the "louder" gate. The downstream defer
|
||||
# path treats this as a deferral candidate; flipping the order would
|
||||
# silently drop deferrable events when both gates are closed.
|
||||
track_assets_added=False,
|
||||
)
|
||||
outcome = dh.evaluate_event_gate(_make_event(), tc, "UTC")
|
||||
assert outcome.reason is dh.GateReason.QUIET_HOURS
|
||||
assert outcome.quiet_hours_end_at == datetime(2026, 5, 12, 14, 0, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def test_gate_event_type_disabled_when_quiet_hours_off() -> None:
|
||||
from notify_bridge_server.services import dispatch_helpers as dh
|
||||
|
||||
tc = _FakeTrackingConfig(quiet_hours_enabled=False, track_assets_added=False)
|
||||
outcome = dh.evaluate_event_gate(_make_event(), tc, "UTC")
|
||||
assert outcome.reason is dh.GateReason.EVENT_TYPE_DISABLED
|
||||
assert outcome.quiet_hours_end_at is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event payload round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_serialize_deserialize_roundtrips_assets_and_extras() -> None:
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
|
||||
asset = _make_asset("a1")
|
||||
asset.extra = {"city": "Minsk", "is_favorite": True, "rating": 5}
|
||||
event = _make_event(added_assets=[asset])
|
||||
event.extra = {"people": ["Alice"]}
|
||||
|
||||
payload = dd.serialize_event(event)
|
||||
restored = dd.deserialize_event(payload)
|
||||
|
||||
assert restored.event_type is EventType.ASSETS_ADDED
|
||||
assert restored.provider_type is ServiceProviderType.IMMICH
|
||||
assert restored.collection_id == "col-1"
|
||||
assert len(restored.added_assets) == 1
|
||||
assert restored.added_assets[0].id == "a1"
|
||||
assert restored.added_assets[0].extra["city"] == "Minsk"
|
||||
assert restored.extra["people"] == ["Alice"]
|
||||
assert restored.timestamp == event.timestamp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Coalescing — the add-then-remove cancellation that motivated the design
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_session():
|
||||
"""In-memory SQLite session for coalescing tests — no fixtures, just a clean DB."""
|
||||
# Importing models here registers them on SQLModel.metadata. We rely on
|
||||
# ``DeferredDispatch`` being declared so create_all picks it up.
|
||||
from notify_bridge_server.database import models # noqa: F401 — side effect
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
async with AsyncSession(engine) as session:
|
||||
yield session
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_then_remove_same_assets_cancels_pending(empty_session: AsyncSession) -> None:
|
||||
"""User adds {A, B}, then removes {A, B} — both pending rows should disappear.
|
||||
|
||||
Before this feature this scenario would either spam two late notifications
|
||||
("added" then "removed") or silently drop both. The cancellation path is
|
||||
the win that justified the coalescing module.
|
||||
"""
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
from notify_bridge_server.database.models import DeferredDispatch
|
||||
|
||||
fire_at = datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc)
|
||||
add_event = _make_event(
|
||||
EventType.ASSETS_ADDED,
|
||||
added_assets=[_make_asset("A"), _make_asset("B")],
|
||||
)
|
||||
result = await dd.defer_event(
|
||||
empty_session,
|
||||
event=add_event,
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=100, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
assert result == "inserted"
|
||||
|
||||
remove_event = ServiceEvent(
|
||||
event_type=EventType.ASSETS_REMOVED,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name="test-immich",
|
||||
collection_id="col-1",
|
||||
collection_name="Album A",
|
||||
timestamp=datetime(2026, 5, 12, 12, 5, tzinfo=timezone.utc),
|
||||
removed_asset_ids=["A", "B"],
|
||||
removed_count=2,
|
||||
)
|
||||
result = await dd.defer_event(
|
||||
empty_session,
|
||||
event=remove_event,
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=101, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
|
||||
pending = (await empty_session.exec(
|
||||
select(DeferredDispatch).where(DeferredDispatch.status == "pending")
|
||||
)).all()
|
||||
assert pending == [], "add-then-remove of same IDs should leave the queue empty"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_then_partial_remove_keeps_remainder(empty_session: AsyncSession) -> None:
|
||||
"""User adds {A, B, C}, then removes {B} — pending row should contain {A, C}."""
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
from notify_bridge_server.database.models import DeferredDispatch
|
||||
|
||||
fire_at = datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc)
|
||||
await dd.defer_event(
|
||||
empty_session,
|
||||
event=_make_event(EventType.ASSETS_ADDED, added_assets=[
|
||||
_make_asset("A"), _make_asset("B"), _make_asset("C"),
|
||||
]),
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=100, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
|
||||
remove_event = ServiceEvent(
|
||||
event_type=EventType.ASSETS_REMOVED,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name="test-immich",
|
||||
collection_id="col-1",
|
||||
collection_name="Album A",
|
||||
timestamp=datetime(2026, 5, 12, 12, 5, tzinfo=timezone.utc),
|
||||
removed_asset_ids=["B"],
|
||||
removed_count=1,
|
||||
)
|
||||
await dd.defer_event(
|
||||
empty_session,
|
||||
event=remove_event,
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=101, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
|
||||
rows = (await empty_session.exec(
|
||||
select(DeferredDispatch).where(DeferredDispatch.status == "pending")
|
||||
)).all()
|
||||
# Only the assets_added row survives (B subtracted). No assets_removed
|
||||
# row because B was just added — its removal is a wash.
|
||||
assert len(rows) == 1
|
||||
assert rows[0].event_type == "assets_added"
|
||||
remaining_ids = sorted(a["id"] for a in rows[0].event_payload["added_assets"])
|
||||
assert remaining_ids == ["A", "C"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_then_add_unions_assets(empty_session: AsyncSession) -> None:
|
||||
"""Two consecutive assets_added events should merge into one pending row."""
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
from notify_bridge_server.database.models import DeferredDispatch
|
||||
|
||||
fire_at = datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc)
|
||||
await dd.defer_event(
|
||||
empty_session,
|
||||
event=_make_event(EventType.ASSETS_ADDED, added_assets=[_make_asset("A")]),
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=100, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
await dd.defer_event(
|
||||
empty_session,
|
||||
event=_make_event(EventType.ASSETS_ADDED, added_assets=[
|
||||
_make_asset("B"), _make_asset("C"),
|
||||
]),
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=101, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
|
||||
rows = (await empty_session.exec(
|
||||
select(DeferredDispatch).where(DeferredDispatch.status == "pending")
|
||||
)).all()
|
||||
assert len(rows) == 1
|
||||
merged_ids = sorted(a["id"] for a in rows[0].event_payload["added_assets"])
|
||||
assert merged_ids == ["A", "B", "C"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_asset_event_is_not_coalesced(empty_session: AsyncSession) -> None:
|
||||
"""Two push events for the same repo should both be queued — historical facts."""
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
from notify_bridge_server.database.models import DeferredDispatch
|
||||
|
||||
fire_at = datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc)
|
||||
for i in range(2):
|
||||
push_event = ServiceEvent(
|
||||
event_type=EventType.PUSH,
|
||||
provider_type=ServiceProviderType.GITEA,
|
||||
provider_name="test-gitea",
|
||||
collection_id="repo-1",
|
||||
collection_name="my/repo",
|
||||
timestamp=datetime(2026, 5, 12, 12, i, tzinfo=timezone.utc),
|
||||
extra={"commit_sha": f"sha{i}"},
|
||||
)
|
||||
await dd.defer_event(
|
||||
empty_session,
|
||||
event=push_event,
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=100 + i, fire_at=fire_at,
|
||||
)
|
||||
await empty_session.commit()
|
||||
|
||||
rows = (await empty_session.exec(
|
||||
select(DeferredDispatch).where(DeferredDispatch.status == "pending")
|
||||
)).all()
|
||||
# Both rows survive — pushes don't cancel one another.
|
||||
assert len(rows) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scheduled_message_is_non_deferrable(empty_session: AsyncSession) -> None:
|
||||
"""``scheduled_message`` is wall-clock — defer_event should refuse to enqueue."""
|
||||
from notify_bridge_server.services import deferred_dispatch as dd
|
||||
from notify_bridge_server.database.models import DeferredDispatch
|
||||
|
||||
sched_event = ServiceEvent(
|
||||
event_type=EventType.SCHEDULED_MESSAGE,
|
||||
provider_type=ServiceProviderType.SCHEDULER,
|
||||
provider_name="sched",
|
||||
collection_id="",
|
||||
collection_name="",
|
||||
timestamp=datetime(2026, 5, 12, 12, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
result = await dd.defer_event(
|
||||
empty_session,
|
||||
event=sched_event,
|
||||
user_id=1, tracker_id=1, link_id=1,
|
||||
event_log_id=100,
|
||||
fire_at=datetime(2026, 5, 13, 6, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
assert result == "non_deferrable"
|
||||
await empty_session.commit()
|
||||
rows = (await empty_session.exec(select(DeferredDispatch))).all()
|
||||
assert rows == []
|
||||
@@ -0,0 +1,235 @@
|
||||
"""Tests for the release provider abstraction and Gitea probe."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.release import build_release_provider, is_valid_repo
|
||||
from notify_bridge_core.release.base import (
|
||||
ReleaseErrorCode,
|
||||
ReleaseProviderKind,
|
||||
compare_versions,
|
||||
is_newer,
|
||||
normalise_version,
|
||||
)
|
||||
from notify_bridge_core.release.gitea import GiteaReleaseProvider
|
||||
|
||||
|
||||
# --- pure utilities ---------------------------------------------------------
|
||||
|
||||
|
||||
def test_normalise_version_strips_v_prefix() -> None:
|
||||
assert normalise_version("v1.2.3") == "1.2.3"
|
||||
assert normalise_version("V1.2.3") == "1.2.3"
|
||||
assert normalise_version("1.2.3") == "1.2.3"
|
||||
assert normalise_version("") == ""
|
||||
# Only strip ``v`` when followed by a digit — guard against names like
|
||||
# ``vendor-1`` being mangled into ``endor-1``.
|
||||
assert normalise_version("vendor-1") == "vendor-1"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("a", "b", "expected"),
|
||||
[
|
||||
("0.7.3", "0.7.2", 1),
|
||||
("0.7.2", "0.7.3", -1),
|
||||
("0.7.2", "0.7.2", 0),
|
||||
("v0.7.3", "0.7.2", 1),
|
||||
("1.0.0", "0.9.99", 1),
|
||||
# Stable beats prerelease at equal numerics (tie-break).
|
||||
("0.7.2-rc1", "0.7.2", -1),
|
||||
("0.7.2", "0.7.2-rc1", 1),
|
||||
# Implicit prerelease form ``1.0a2`` must NOT extract ``2`` as a
|
||||
# third numeric segment — equal to ``1.0`` stable, then stable wins.
|
||||
("1.0a2", "1.0", -1),
|
||||
("", "0.0.0", 0),
|
||||
],
|
||||
)
|
||||
def test_compare_versions(a: str, b: str, expected: int) -> None:
|
||||
assert compare_versions(a, b) == expected
|
||||
|
||||
|
||||
def test_is_newer_is_strict() -> None:
|
||||
assert is_newer("0.7.3", "0.7.2") is True
|
||||
assert is_newer("0.7.2", "0.7.2") is False
|
||||
# A pre-release of the next minor should still be flagged as newer when
|
||||
# explicitly fetched with include_prereleases=True at the provider level.
|
||||
assert is_newer("0.7.3-rc1", "0.7.2") is True
|
||||
|
||||
|
||||
def test_is_valid_repo() -> None:
|
||||
assert is_valid_repo("alexei.dolgolyov/notify-bridge") is True
|
||||
assert is_valid_repo("a/b") is True
|
||||
assert is_valid_repo("a_b/c.d-e") is True
|
||||
assert is_valid_repo("") is False
|
||||
assert is_valid_repo("no-slash") is False
|
||||
# Path-traversal attempts.
|
||||
assert is_valid_repo("foo/bar/../admin") is False
|
||||
assert is_valid_repo("foo/bar/baz") is False
|
||||
assert is_valid_repo("foo/../bar") is False
|
||||
# Embedded special chars.
|
||||
assert is_valid_repo("foo@bar/baz") is False
|
||||
assert is_valid_repo("foo/bar?x=1") is False
|
||||
|
||||
|
||||
# --- registry ---------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registry_returns_none_for_disabled() -> None:
|
||||
assert build_release_provider("disabled", session=MagicMock(), url="x", repo="a/b") is None
|
||||
|
||||
|
||||
def test_registry_returns_none_for_unknown_kind() -> None:
|
||||
assert build_release_provider("svn", session=MagicMock(), url="x", repo="a/b") is None
|
||||
|
||||
|
||||
def test_registry_gitea_requires_url_and_valid_repo() -> None:
|
||||
sess = MagicMock()
|
||||
assert build_release_provider("gitea", session=sess, url="", repo="a/b") is None
|
||||
assert build_release_provider("gitea", session=sess, url="https://x", repo="") is None
|
||||
# Path traversal blocked by repo validation.
|
||||
assert build_release_provider("gitea", session=sess, url="https://x", repo="a/b/../c") is None
|
||||
provider = build_release_provider("gitea", session=sess, url="https://x", repo="a/b")
|
||||
assert isinstance(provider, GiteaReleaseProvider)
|
||||
assert provider.kind is ReleaseProviderKind.GITEA
|
||||
|
||||
|
||||
# --- Gitea provider ---------------------------------------------------------
|
||||
|
||||
|
||||
def _gitea_payload(**overrides: Any) -> list[dict[str, Any]]:
|
||||
base = {
|
||||
"tag_name": "v0.7.3",
|
||||
"name": "v0.7.3",
|
||||
"html_url": "https://git.example.com/owner/repo/releases/tag/v0.7.3",
|
||||
"body": "Notes",
|
||||
"published_at": "2026-05-01T00:00:00Z",
|
||||
"draft": False,
|
||||
"prerelease": False,
|
||||
}
|
||||
base.update(overrides)
|
||||
return [base]
|
||||
|
||||
|
||||
class _FakeContent:
|
||||
def __init__(self, raw: bytes) -> None:
|
||||
self._raw = raw
|
||||
|
||||
async def read(self, n: int = -1) -> bytes:
|
||||
return self._raw if n < 0 else self._raw[:n]
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status: int, payload: Any) -> None:
|
||||
self.status = status
|
||||
import json
|
||||
|
||||
self.content = _FakeContent(json.dumps(payload).encode("utf-8"))
|
||||
self._payload = payload
|
||||
|
||||
async def json(self) -> Any:
|
||||
return self._payload
|
||||
|
||||
async def __aenter__(self) -> "_FakeResponse":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _session_with(payload: Any, status: int = 200) -> MagicMock:
|
||||
"""Return a session whose `.get()` yields a fresh response per call.
|
||||
|
||||
Using ``side_effect`` rather than ``return_value`` ensures multiple
|
||||
awaited fetches don't share mutable response state across tests.
|
||||
"""
|
||||
sess = MagicMock()
|
||||
sess.get = MagicMock(side_effect=lambda *a, **kw: _FakeResponse(status, payload))
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _allow_private_urls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""SSRF guard rejects example.com → publicly resolvable, so tests pass.
|
||||
|
||||
But we explicitly enable the bypass to remove DNS-resolution flakiness
|
||||
from CI runs.
|
||||
"""
|
||||
monkeypatch.setenv("NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS", "1")
|
||||
# Reload the ssrf module to pick up the env var (it's read at import).
|
||||
import importlib
|
||||
|
||||
import notify_bridge_core.notifications.ssrf as ssrf_mod
|
||||
importlib.reload(ssrf_mod)
|
||||
|
||||
|
||||
async def test_gitea_fetch_latest_happy_path() -> None:
|
||||
sess = _session_with(_gitea_payload())
|
||||
provider = GiteaReleaseProvider(sess, "https://git.example.com/", "owner/repo")
|
||||
|
||||
info = await provider.fetch_latest(include_prereleases=False)
|
||||
assert info is not None
|
||||
assert info.tag == "v0.7.3"
|
||||
assert info.version == "0.7.3"
|
||||
assert info.url == "https://git.example.com/owner/repo/releases/tag/v0.7.3"
|
||||
assert info.prerelease is False
|
||||
|
||||
|
||||
async def test_gitea_skips_prereleases_by_default() -> None:
|
||||
payload = _gitea_payload(prerelease=True)
|
||||
sess = _session_with(payload)
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
assert await provider.fetch_latest(include_prereleases=False) is None
|
||||
|
||||
|
||||
async def test_gitea_includes_prereleases_when_asked() -> None:
|
||||
payload = _gitea_payload(prerelease=True)
|
||||
sess = _session_with(payload)
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
info = await provider.fetch_latest(include_prereleases=True)
|
||||
assert info is not None
|
||||
assert info.prerelease is True
|
||||
|
||||
|
||||
async def test_gitea_skips_drafts() -> None:
|
||||
payload = _gitea_payload(draft=True)
|
||||
sess = _session_with(payload)
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
assert await provider.fetch_latest(include_prereleases=True) is None
|
||||
|
||||
|
||||
async def test_gitea_returns_none_on_http_error() -> None:
|
||||
sess = _session_with([], status=500)
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
assert await provider.fetch_latest() is None
|
||||
|
||||
|
||||
async def test_gitea_test_returns_structured_status() -> None:
|
||||
sess = _session_with(_gitea_payload())
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
result = await provider.test()
|
||||
assert result["ok"] is True
|
||||
assert result["info"] is not None
|
||||
assert result["error"] is None
|
||||
|
||||
|
||||
async def test_gitea_test_reports_http_error() -> None:
|
||||
sess = _session_with([], status=404)
|
||||
provider = GiteaReleaseProvider(sess, "https://x.example.com", "a/b")
|
||||
result = await provider.test()
|
||||
assert result["ok"] is False
|
||||
assert result["info"] is None
|
||||
# Taxonomy code, not a raw exception string.
|
||||
assert result["error"] in {code.value for code in ReleaseErrorCode}
|
||||
|
||||
|
||||
def test_gitea_constructor_validates_repo_format() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
GiteaReleaseProvider(MagicMock(), "https://x.example.com", "no-slash")
|
||||
with pytest.raises(ValueError):
|
||||
GiteaReleaseProvider(MagicMock(), "https://x.example.com", "foo/bar/../baz")
|
||||
with pytest.raises(ValueError):
|
||||
GiteaReleaseProvider(MagicMock(), "", "owner/repo")
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Tests for the release_check service (interval clamping + status endpoints + persistence)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_parse_interval_hours_clamps_and_defaults() -> None:
|
||||
from notify_bridge_server.services.release_check import parse_interval_hours
|
||||
|
||||
assert parse_interval_hours("12") == 12
|
||||
assert parse_interval_hours("") == 12 # default
|
||||
assert parse_interval_hours(None) == 12
|
||||
assert parse_interval_hours("0") == 1 # clamped to min
|
||||
assert parse_interval_hours("9999") == 168 # clamped to max
|
||||
assert parse_interval_hours("not-a-number") == 12 # fallback to default
|
||||
assert parse_interval_hours("24") == 24
|
||||
|
||||
|
||||
def test_release_endpoint_anonymous_is_rejected(tmp_data_dir) -> None: # noqa: ARG001
|
||||
"""GET /api/settings/release requires auth — same as other settings."""
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/settings/release")
|
||||
# Either 401 (missing token) or 403 (not authenticated) is acceptable.
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
def test_release_force_check_requires_admin(tmp_data_dir) -> None: # noqa: ARG001
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/api/settings/release/check")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
def test_release_test_requires_admin(tmp_data_dir) -> None: # noqa: ARG001
|
||||
from notify_bridge_server.main import app
|
||||
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/settings/release/test",
|
||||
json={"provider_kind": "gitea", "provider_url": "https://x.example.com", "provider_repo": "a/b"},
|
||||
)
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
# --- Persistence round-trip -------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_release_state_round_trip(tmp_data_dir, monkeypatch) -> None: # noqa: ARG001
|
||||
"""Write a fake ReleaseInfo, read it back via load_status, assert flags."""
|
||||
from notify_bridge_core.release import ReleaseInfo
|
||||
from notify_bridge_server.database.engine import init_db
|
||||
from notify_bridge_server.services.release_check import (
|
||||
load_status,
|
||||
persist_release_state,
|
||||
)
|
||||
|
||||
await init_db()
|
||||
|
||||
info = ReleaseInfo(
|
||||
tag="v0.9.0",
|
||||
version="0.9.0",
|
||||
name="0.9.0 — Aurora",
|
||||
body="Release notes",
|
||||
url="https://example.com/x/y/releases/tag/v0.9.0",
|
||||
published_at="2026-06-01T00:00:00Z",
|
||||
prerelease=False,
|
||||
draft=False,
|
||||
)
|
||||
await persist_release_state(
|
||||
checked_at="2026-06-01T00:01:00+00:00",
|
||||
error=None,
|
||||
info=info,
|
||||
)
|
||||
|
||||
# Force the comparator to see an older "current" so update_available
|
||||
# comes out True regardless of the actual installed package version.
|
||||
monkeypatch.setattr(
|
||||
"notify_bridge_server.services.release_check._server_version",
|
||||
lambda: "0.7.0",
|
||||
)
|
||||
status = await load_status()
|
||||
assert status.latest == "0.9.0"
|
||||
assert status.latest_tag == "v0.9.0"
|
||||
assert status.update_available is True
|
||||
assert status.error is None
|
||||
assert status.latest_body == "Release notes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_release_state_clears_on_none_info(tmp_data_dir, monkeypatch) -> None: # noqa: ARG001
|
||||
"""A persist call with ``info=None`` must blank all the latest-* fields."""
|
||||
from notify_bridge_core.release import ReleaseInfo
|
||||
from notify_bridge_server.database.engine import init_db
|
||||
from notify_bridge_server.services.release_check import (
|
||||
load_status,
|
||||
persist_release_state,
|
||||
)
|
||||
|
||||
await init_db()
|
||||
|
||||
# Seed a populated row.
|
||||
await persist_release_state(
|
||||
checked_at="2026-06-01T00:00:00+00:00",
|
||||
error=None,
|
||||
info=ReleaseInfo(tag="v9.9.9", version="9.9.9"),
|
||||
)
|
||||
# Now wipe by passing info=None — mimics the "provider_changed" flow.
|
||||
await persist_release_state(
|
||||
checked_at="2026-06-01T00:02:00+00:00",
|
||||
error="provider_changed",
|
||||
info=None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"notify_bridge_server.services.release_check._server_version",
|
||||
lambda: "0.7.0",
|
||||
)
|
||||
status = await load_status()
|
||||
assert status.latest is None
|
||||
assert status.latest_tag is None
|
||||
assert status.update_available is False
|
||||
assert status.error == "provider_changed"
|
||||
|
||||
|
||||
# --- Version resolver -------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_version_prefers_source_pyproject() -> None:
|
||||
"""When pyproject.toml is alongside the source, prefer the higher of (installed, source)."""
|
||||
from notify_bridge_server.version import resolve_version
|
||||
|
||||
v = resolve_version()
|
||||
assert v != "0.0.0+unknown"
|
||||
# If the editable install is stale (e.g. 0.3.2) but pyproject says 0.7.2,
|
||||
# resolve_version must return 0.7.2 (or higher) — the resolver's
|
||||
# whole purpose. We test the "not stale" half of the contract here.
|
||||
parts = v.split(".")
|
||||
assert len(parts) >= 2
|
||||
assert parts[0].isdigit()
|
||||
Reference in New Issue
Block a user