"""Shared dispatch helpers used by both watcher and webhook handlers."""

from __future__ import annotations

import dataclasses
import logging
import random
from datetime import datetime, time, timezone
from typing import Any, Callable
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.models.media import MediaAsset
from notify_bridge_core.notifications.receiver import Receiver, build_receiver

from ..database.models import (
    EmailBot,
    MatrixBot,
    NotificationTarget,
    NotificationTracker,
    NotificationTrackerTarget,
    TargetReceiver,
    TelegramBot,
    TelegramChat,
    TemplateConfig,
    TemplateSlot,
    TrackingConfig,
)

_LOGGER = logging.getLogger(__name__)


def _resolve_zoneinfo(tz_name: str | None) -> ZoneInfo:
    """Resolve an IANA tz string to a ZoneInfo, falling back to UTC on any error.

    An empty or ``None`` name silently yields UTC. An unknown or malformed
    name logs a warning and also yields UTC, so callers never have to handle
    a timezone lookup failure themselves.
    """
    if not tz_name:
        return ZoneInfo("UTC")
    try:
        return ZoneInfo(tz_name)
    except (ZoneInfoNotFoundError, ValueError, OSError):
        # OSError: a key naming a tzdata *directory* (e.g. "America") can
        # surface as IsADirectoryError/PermissionError on some Python versions
        # and platforms rather than ZoneInfoNotFoundError.
        _LOGGER.warning("Unknown timezone %r; falling back to UTC", tz_name)
        return ZoneInfo("UTC")


def in_quiet_hours(
    start: str | None,
    end: str | None,
    tz_name: str | None = "UTC",
) -> bool:
    """Check if the current time (in the given timezone) is within the quiet window.

    HH:MM strings are interpreted in the supplied timezone. If either bound is
    missing, quiet hours are disabled.

    Both bounds are compared inclusively. When ``start`` is later than ``end``
    the window wraps past midnight (e.g. 22:00-06:00). Unparseable bounds
    disable quiet hours (``False``) instead of raising.
    """
    if not start or not end:
        return False
    try:
        tz = _resolve_zoneinfo(tz_name)
        now = datetime.now(timezone.utc).astimezone(tz).time()
        t_start = time.fromisoformat(start)
        t_end = time.fromisoformat(end)
        if t_start <= t_end:
            return t_start <= now <= t_end
        # Overnight window (e.g., 22:00 - 06:00)
        return now >= t_start or now <= t_end
    except (ValueError, TypeError):
        return False


async def get_app_timezone(session: AsyncSession) -> str:
    """Load the app-level timezone from AppSetting (falls back to UTC)."""
    # Imported lazily, as in the original module layout (the api package is
    # not needed at import time of this module).
    from ..api.app_settings import get_setting

    value = await get_setting(session, "timezone")
    return value or "UTC"


# Maps ServiceEvent.event_type values to the TrackingConfig boolean attribute
# that enables them. Event types absent from this table are always allowed.
_EVENT_FLAG_ATTRS: dict[str, str] = {
    # Immich events
    "assets_added": "track_assets_added",
    "assets_removed": "track_assets_removed",
    "collection_renamed": "track_collection_renamed",
    "collection_deleted": "track_collection_deleted",
    "sharing_changed": "track_sharing_changed",
    # Gitea events
    "push": "track_push",
    "issue_opened": "track_issue_opened",
    "issue_closed": "track_issue_closed",
    "issue_commented": "track_issue_commented",
    "pr_opened": "track_pr_opened",
    "pr_closed": "track_pr_closed",
    "pr_merged": "track_pr_merged",
    "pr_commented": "track_pr_commented",
    "release_published": "track_release_published",
    # Planka events
    "card_created": "track_card_created",
    "card_updated": "track_card_updated",
    "card_moved": "track_card_moved",
    "card_deleted": "track_card_deleted",
    "card_commented": "track_card_commented",
    "comment_updated": "track_comment_updated",
    "board_created": "track_board_created",
    "board_updated": "track_board_updated",
    "board_deleted": "track_board_deleted",
    "list_created": "track_list_created",
    "list_updated": "track_list_updated",
    "list_deleted": "track_list_deleted",
    "attachment_created": "track_attachment_created",
    "card_label_added": "track_card_label_added",
    "task_completed": "track_task_completed",
    # Scheduler events
    "scheduled_message": "track_scheduled_message",
    # Generic Webhook events
    "webhook_received": "track_webhook_received",
    # NUT (UPS) events
    "ups_online": "track_ups_online",
    "ups_on_battery": "track_ups_on_battery",
    "ups_low_battery": "track_ups_low_battery",
    "ups_battery_restored": "track_ups_battery_restored",
    "ups_comms_lost": "track_ups_comms_lost",
    "ups_comms_restored": "track_ups_comms_restored",
    "ups_replace_battery": "track_ups_replace_battery",
    "ups_overload": "track_ups_overload",
}


def event_allowed_by_config(
    event: ServiceEvent,
    tc: TrackingConfig,
    tz_name: str | None = "UTC",
) -> bool:
    """Check if an event is allowed by the tracking config's flags + quiet hours.

    Args:
        event: The event about to be dispatched.
        tc: Tracking config holding the quiet-hours window and ``track_*`` flags.
        tz_name: IANA timezone in which the quiet-hours HH:MM bounds are read.

    Returns:
        ``False`` while quiet hours are enabled and active; otherwise the
        value of the ``track_*`` flag for the event type, or ``True`` for
        event types that have no flag.
    """
    # Quiet hours gate every event type when enabled.
    if tc.quiet_hours_enabled and in_quiet_hours(
        tc.quiet_hours_start, tc.quiet_hours_end, tz_name
    ):
        return False

    # Only the one relevant flag is read (previously every flag was read to
    # build a throwaway dict on each call).
    flag_attr = _EVENT_FLAG_ATTRS.get(event.event_type.value)
    if flag_attr is None:
        return True
    return getattr(tc, flag_attr)


# --- Display-time filters driven by TrackingConfig -------------------------
#
# These transform a ServiceEvent so the dispatched notification reflects the
# user's per-tracker "asset display" preferences. Event-tracking flags (which
# events fire at all) live in ``event_allowed_by_config`` above; the filters
# here only reshape an already-allowed event.

# Asset.extra keys stripped when ``include_asset_details=False``. These are
# the enrichment fields the default templates render as prose (city/country,
# ⭐ rating, ❤️ favorite). ``thumbhash``/``file_size``/``playback_size``/
# ``owner_id``/``cache_key`` stay — they are load-bearing for media send and
# caching, not user-facing prose.
_ASSET_DETAIL_KEYS: tuple[str, ...] = (
    "city",
    "country",
    "state",
    "latitude",
    "longitude",
    "is_favorite",
    "rating",
)


def _sort_key_for(order_by: str) -> Callable[[MediaAsset], Any] | None:
    """Return a function extracting the raw sort value for ``order_by``.

    The extracted value may be ``None`` (e.g. an unrated asset); callers are
    expected to keep such assets out of the comparison (``_sort_assets``
    places them last). Returns ``None`` for an unrecognised ``order_by``.
    """
    if order_by == "date":
        return lambda a: a.created_at
    if order_by == "name":
        return lambda a: a.filename.lower()
    if order_by == "rating":
        return lambda a: a.extra.get("rating")
    return None


def _sort_assets(
    assets: list[MediaAsset],
    order_by: str,
    order: str,
) -> list[MediaAsset]:
    """Sort MediaAssets by the configured key/direction.

    ``order_by="none"`` preserves the input order (the provider's own
    ordering, usually detection order). ``"random"`` shuffles a copy so
    repeated renders of the same event aren't identical. An unrecognised
    ``order_by`` also preserves input order. ``order="descending"`` reverses
    the sort; any other value sorts ascending.

    Assets whose sort value is ``None`` always go last, in their original
    relative order, regardless of direction. The input list is not modified.
    """
    if order_by in ("none", "") or len(assets) < 2:
        return list(assets)
    if order_by == "random":
        shuffled = list(assets)
        random.shuffle(shuffled)
        return shuffled
    key_fn = _sort_key_for(order_by)
    if key_fn is None:
        return list(assets)

    # Partition instead of encoding "is None" into the key: a leading None
    # flag in the key would flip missing values to the FRONT when
    # reverse=True. Sorting only the present values keeps missing ones last
    # for both directions (and avoids comparing None with real values).
    present: list[MediaAsset] = []
    missing: list[MediaAsset] = []
    for asset in assets:
        (missing if key_fn(asset) is None else present).append(asset)
    present.sort(key=key_fn, reverse=(order == "descending"))
    return present + missing


def _transform_asset(
    asset: MediaAsset,
    *,
    strip_details: bool,
    strip_tags: bool,
) -> MediaAsset:
    """Return a copy of ``asset`` with details and/or tags removed.

    ``strip_details`` drops the ``_ASSET_DETAIL_KEYS`` entries from ``extra``
    and clears ``description``; ``strip_tags`` empties ``tags``. Fields not
    being stripped are carried over unchanged. The input asset is not mutated.
    """
    changes: dict[str, Any] = {}
    if strip_details:
        changes["extra"] = {
            k: v for k, v in asset.extra.items() if k not in _ASSET_DETAIL_KEYS
        }
        changes["description"] = None
    if strip_tags:
        changes["tags"] = []
    return dataclasses.replace(asset, **changes)


def apply_tracking_display_filters(
    event: ServiceEvent,
    tc: TrackingConfig | None,
) -> ServiceEvent | None:
    """Apply per-tracker display preferences to an already-allowed event.

    Semantics:

    * ``notify_favorites_only`` + ``assets_order_by`` + ``max_assets_to_show``
      only apply to ``ASSETS_ADDED`` events — the album-change path.
      Scheduled / periodic / memory events have their own limits and
      ordering (``scheduled_limit``, ``scheduled_order_by``, etc.), so
      reapplying the album-change cap would wrongly truncate them.
    * ``include_tags`` and ``include_asset_details`` apply to every event
      that carries assets, since they control rendering irrespective of
      how the assets were selected. Stripping tags also removes the
      event-level ``"people"`` entry from ``event.extra``.

    Returns:
        A new ``ServiceEvent`` with filters applied, or ``None`` if the event
        should be dropped entirely (``notify_favorites_only=True`` and none
        of the added assets are favorites). ``tc=None`` returns ``event``
        unchanged.
    """
    if tc is None:
        return event

    assets = list(event.added_assets)
    new_added_count = event.added_count

    is_change_event = event.event_type.value == "assets_added"
    if is_change_event:
        if tc.notify_favorites_only:
            assets = [a for a in assets if a.extra.get("is_favorite")]
            # The count reflects favorites only, taken before the display cap.
            new_added_count = len(assets)
            if not assets:
                return None
        assets = _sort_assets(assets, tc.assets_order_by, tc.assets_order)
        # NOTE(review): a negative cap means "no limit"; 0 hides every asset.
        if tc.max_assets_to_show >= 0:
            assets = assets[: tc.max_assets_to_show]

    strip_details = not tc.include_asset_details
    strip_tags = not tc.include_tags
    if (strip_details or strip_tags) and assets:
        assets = [
            _transform_asset(a, strip_details=strip_details, strip_tags=strip_tags)
            for a in assets
        ]

    new_extra = event.extra
    if strip_tags and "people" in event.extra:
        new_extra = {k: v for k, v in event.extra.items() if k != "people"}

    return dataclasses.replace(
        event,
        added_assets=assets,
        added_count=new_added_count,
        extra=new_extra,
    )


async def _telegram_chat_locales(
    session: AsyncSession,
    target: NotificationTarget,
    recv_rows: list[TargetReceiver],
) -> dict[str, str]:
    """Map receiver chat IDs (as ``str``) to a two-letter lowercase locale.

    The locale comes from the matching TelegramChat row of the target's bot,
    preferring ``language_override`` over ``language_code``. Chats without
    either are omitted. Returns ``{}`` when the target has no ``bot_id`` or
    no receiver has a ``chat_id``.
    """
    bot_id = target.config.get("bot_id")
    if not bot_id:
        return {}
    chat_ids = [
        str(r.config.get("chat_id", "")) for r in recv_rows if r.config.get("chat_id")
    ]
    if not chat_ids:
        return {}

    chat_result = await session.exec(
        select(TelegramChat).where(
            TelegramChat.bot_id == bot_id,
            TelegramChat.chat_id.in_(chat_ids),
        )
    )
    locales: dict[str, str] = {}
    for chat in chat_result.all():
        resolved = (
            getattr(chat, "language_override", "")
            or getattr(chat, "language_code", "")
            or ""
        )
        if resolved:
            # Key by str so the lookup (which uses str(chat_id)) matches even
            # if the column yields a non-str value.
            locales[str(chat.chat_id)] = resolved[:2].lower()
    return locales


async def _inject_bot_credentials(
    session: AsyncSession,
    target: NotificationTarget,
    target_config: dict[str, Any],
) -> None:
    """Copy credentials of the bot referenced by ``target`` into ``target_config``.

    * Email targets get an ``"smtp"`` dict built from the ``EmailBot`` row.
    * Matrix targets get ``homeserver_url`` and ``access_token`` from the
      ``MatrixBot`` row.

    ``target_config`` is modified in place. It is left untouched when the
    bot ID is missing or the bot row is not found.
    """
    if target.type == "email":
        email_bot_id = target.config.get("email_bot_id")
        if not email_bot_id:
            return
        email_bot = await session.get(EmailBot, email_bot_id)
        if email_bot:
            target_config["smtp"] = {
                "host": email_bot.smtp_host,
                "port": email_bot.smtp_port,
                "username": email_bot.smtp_username,
                "password": email_bot.smtp_password,
                "from_address": email_bot.email,
                "from_name": email_bot.name,
                "use_tls": email_bot.smtp_use_tls,
            }
    elif target.type == "matrix":
        matrix_bot_id = target.config.get("matrix_bot_id")
        if not matrix_bot_id:
            return
        matrix_bot = await session.get(MatrixBot, matrix_bot_id)
        if matrix_bot:
            target_config["homeserver_url"] = matrix_bot.homeserver_url
            target_config["access_token"] = matrix_bot.access_token


async def _resolve_target(
    session: AsyncSession,
    target: NotificationTarget,
) -> dict[str, Any]:
    """Resolve a single target into dispatch-ready data (config + receivers + credentials).

    Returns a dict with target_type, target_config, and receivers.
    Does NOT include tracking_config or template_slots — those come from the tracker link.

    Only enabled receivers are included. A receiver's explicit ``locale`` wins.
    For Telegram targets without one, the locale falls back to the matching
    TelegramChat's language.
    """
    recv_result = await session.exec(
        select(TargetReceiver).where(
            TargetReceiver.target_id == target.id,
            TargetReceiver.enabled == True,  # noqa: E712 — SQL expression, not a bool test
        )
    )
    recv_rows = list(recv_result.all())

    chat_locale_map: dict[str, str] = (
        await _telegram_chat_locales(session, target, recv_rows)
        if target.type == "telegram"
        else {}
    )

    receivers: list[Receiver] = []
    for r in recv_rows:
        locale = getattr(r, "locale", "") or ""
        if not locale and target.type == "telegram":
            locale = chat_locale_map.get(str(r.config.get("chat_id", "")), "")
        receivers.append(build_receiver(target.type, dict(r.config), locale))

    target_config = dict(target.config)
    # chat_action lives on the model column — single source of truth.
    # Strip any legacy/stale value from config so an old config-stored value
    # can't shadow the user's UI choice. When the column is unset, leave the
    # key absent so the dispatcher's "typing" fallback applies.
    target_config.pop("chat_action", None)
    chat_action = getattr(target, "chat_action", None)
    if chat_action:
        target_config["chat_action"] = chat_action

    # Inject bot credentials for bot-backed target types
    await _inject_bot_credentials(session, target, target_config)

    return {
        "target_type": target.type,
        "target_config": target_config,
        "receivers": receivers,
    }


def _enabled_broadcast_children(target: NotificationTarget) -> list[Any]:
    """Return a broadcast target's ``child_target_ids`` in order, minus disabled ones."""
    disabled_ids = set(target.config.get("disabled_child_ids", []))
    return [
        cid for cid in target.config.get("child_target_ids", []) if cid not in disabled_ids
    ]


async def load_link_data(
    session: AsyncSession,
    tracker_id: int,
    *,
    check_quiet_hours: bool = False,
    tz_name: str | None = None,
) -> list[dict[str, Any]]:
    """Load tracker-target link data for dispatch.

    Args:
        session: Active database session.
        tracker_id: ID of the tracker whose links to load.
        check_quiet_hours: If True, skip links currently in quiet hours.
        tz_name: IANA timezone for the links' HH:MM quiet-hours bounds. When
            omitted and quiet hours must be checked, the app-level timezone
            (``get_app_timezone``) is used.

    Returns:
        One dict per dispatch target, containing the ``_resolve_target`` keys
        (``target_type``, ``target_config``, ``receivers``) plus
        ``tracking_config``, ``template_config`` and ``template_slots``.
        Per-link config IDs override the tracker defaults. Broadcast targets
        are expanded into one entry per enabled, non-broadcast child, and
        each child inherits the broadcast link's configs.
    """
    # Load the tracker itself for default config IDs
    tracker = await session.get(NotificationTracker, tracker_id)
    default_tc_id = getattr(tracker, "default_tracking_config_id", None) if tracker else None
    default_tmpl_id = getattr(tracker, "default_template_config_id", None) if tracker else None

    tt_result = await session.exec(
        select(NotificationTrackerTarget).where(
            NotificationTrackerTarget.tracker_id == tracker_id
        )
    )
    tracker_targets = tt_result.all()

    # Link quiet hours are HH:MM in the app timezone (like tracking-config
    # quiet hours); previously they were always evaluated in UTC. Only hit
    # the settings table when some enabled link actually defines a window.
    if check_quiet_hours and tz_name is None and any(
        tt.enabled and tt.quiet_hours_start and tt.quiet_hours_end
        for tt in tracker_targets
    ):
        tz_name = await get_app_timezone(session)

    # Filter enabled links and quiet hours upfront
    active_links = [
        tt for tt in tracker_targets
        if tt.enabled
        and not (
            check_quiet_hours
            and in_quiet_hours(tt.quiet_hours_start, tt.quiet_hours_end, tz_name)
        )
    ]
    if not active_links:
        return []

    # Batch-load targets
    target_ids = list({tt.target_id for tt in active_links})
    target_result = await session.exec(
        select(NotificationTarget).where(NotificationTarget.id.in_(target_ids))
    )
    target_map = {t.id: t for t in target_result.all()}

    # Batch-load tracking configs (per-link + tracker default)
    tc_ids = list(
        {tid for tid in [tt.tracking_config_id for tt in active_links] + [default_tc_id] if tid}
    )
    tc_map: dict[int, TrackingConfig] = {}
    if tc_ids:
        tc_result = await session.exec(select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids)))
        tc_map = {tc.id: tc for tc in tc_result.all()}

    # Batch-load template configs (per-link + tracker default)
    tmpl_ids = list(
        {tid for tid in [tt.template_config_id for tt in active_links] + [default_tmpl_id] if tid}
    )
    tmpl_map: dict[int, TemplateConfig] = {}
    if tmpl_ids:
        tmpl_result = await session.exec(
            select(TemplateConfig).where(TemplateConfig.id.in_(tmpl_ids))
        )
        tmpl_map = {tc.id: tc for tc in tmpl_result.all()}

    # Batch-load template slots for all template configs:
    # config_id -> event key (slot_name minus "message_") -> locale -> template
    slots_by_config: dict[int, dict[str, dict[str, str]]] = {}
    if tmpl_ids:
        slot_result = await session.exec(
            select(TemplateSlot).where(TemplateSlot.config_id.in_(tmpl_ids))
        )
        for s in slot_result.all():
            event_key = s.slot_name.removeprefix("message_")
            slots_by_config.setdefault(s.config_id, {}).setdefault(event_key, {})[s.locale] = s.template

    # Pre-resolve broadcast children in one query to avoid N+1 per-child fetches
    broadcast_child_ids: set[Any] = set()
    for tt in active_links:
        target = target_map.get(tt.target_id)
        if target and target.type == "broadcast":
            broadcast_child_ids.update(_enabled_broadcast_children(target))

    child_target_map: dict[int, NotificationTarget] = {}
    if broadcast_child_ids:
        child_rows = await session.exec(
            select(NotificationTarget).where(NotificationTarget.id.in_(broadcast_child_ids))
        )
        child_target_map = {t.id: t for t in child_rows.all()}

    link_data: list[dict[str, Any]] = []
    for tt in active_links:
        target = target_map.get(tt.target_id)
        if not target:
            continue

        # Per-link config overrides tracker defaults
        tc_id = tt.tracking_config_id or default_tc_id
        tmpl_id = tt.template_config_id or default_tmpl_id
        template_config = tmpl_map.get(tmpl_id) if tmpl_id else None
        link_fields = {
            "tracking_config": tc_map.get(tc_id) if tc_id else None,
            "template_config": template_config,
            "template_slots": (
                slots_by_config.get(template_config.id) if template_config else None
            ),
        }

        if target.type == "broadcast":
            # Expand into child targets (pre-loaded above). Missing children
            # and nested broadcasts are skipped (no recursive expansion).
            dispatch_targets = []
            for child_id in _enabled_broadcast_children(target):
                child_target = child_target_map.get(child_id)
                if child_target and child_target.type != "broadcast":
                    dispatch_targets.append(child_target)
        else:
            dispatch_targets = [target]

        for dispatch_target in dispatch_targets:
            resolved = await _resolve_target(session, dispatch_target)
            link_data.append({**resolved, **link_fields})

    return link_data