"""Shared command handler utilities to reduce boilerplate across providers."""

from __future__ import annotations

import logging

from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from ..database.engine import get_engine
from ..database.models import (
    EventLog,
    NotificationTarget,
    NotificationTracker,
    NotificationTrackerTarget,
    ServiceProvider,
    TargetReceiver,
)

_LOGGER = logging.getLogger(__name__)


async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]:
    """Get notification trackers for a single provider.

    Thin wrapper that calls ``_get_notification_trackers_for_providers`` from
    the sibling ``handler`` module with a one-element provider-id set.
    """
    # Imported lazily at call time; presumably to avoid a circular import
    # with ``.handler`` — TODO confirm.
    from .handler import _get_notification_trackers_for_providers

    return await _get_notification_trackers_for_providers({provider_id})


async def get_last_event_str(
    tracker_ids: list[int],
    *,
    allowed_album_ids: set[str] | None = None,
) -> str:
    """Get formatted timestamp of most recent event for given trackers.

    Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist.

    When ``allowed_album_ids`` is provided, only events whose
    ``collection_id`` is in the set are considered — matches the per-chat
    scope applied via ``CommandTrackerListener.allowed_album_ids``.

    Args:
        tracker_ids: Trackers whose ``EventLog`` rows are searched. An empty
            list returns '-' without touching the database.
        allowed_album_ids: Optional album scope. ``None`` means "no album
            filter"; an empty set means "nothing is allowed" and returns '-'.
    """
    if not tracker_ids:
        return "-"
    # An explicit empty scope can never match an event; skip the DB round
    # trip (and the degenerate empty ``IN ()`` clause) altogether.
    if allowed_album_ids is not None and not allowed_album_ids:
        return "-"

    engine = get_engine()
    async with AsyncSession(engine) as session:
        query = (
            select(EventLog)
            .where(EventLog.tracker_id.in_(tracker_ids))
            .order_by(EventLog.created_at.desc())
        )
        if allowed_album_ids is not None:
            query = query.where(EventLog.collection_id.in_(list(allowed_album_ids)))
        # Newest first + LIMIT 1 => only the most recent matching event.
        result = await session.exec(query.limit(1))
        last_event = result.first()
        return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"


async def resolve_chat_album_scope(
    *,
    provider_id: int,
    bot_id: int,
    chat_id: str,
) -> set[str]:
    """Compute the album scope for a (provider, bot, chat) triple.

    Walks the notification-routing graph: find every notification tracker
    for ``provider_id`` that delivers to a Telegram receiver matching this
    ``(bot_id, chat_id)``, then union their ``collection_ids``. The result
    is the set of albums this specific chat legitimately sees notifications
    for — which is the natural "allowed albums" for commands issued in that
    chat.

    Returns:
        set of album ids. Empty set = "no tracker routes to this chat" —
        caller should treat as "show nothing" (defense in depth); otherwise
        a bot's chats would leak the provider's full album catalog.

    Notes:
        - Only enabled ``TargetReceiver`` rows are considered.
        - Both direct Telegram targets and broadcast targets that fan out to
          a Telegram child target are resolved. Exactly one level of
          broadcast fan-out is followed: a broadcast counts only if one of
          its non-disabled children is a *direct* Telegram target for this
          chat. Broadcasts nested inside other broadcasts are not followed,
          so the result does not depend on database row order.
        - ``chat_id`` and ``bot_id`` are compared as strings so int/str
          differences in stored config do not cause a silent mismatch.
        - Explicit ``CommandTrackerListener.allowed_album_ids`` override is
          NOT applied here — that's the dispatcher's job. This helper is the
          "derived" fallback.
    """
    wanted_chat = str(chat_id)
    wanted_bot = str(bot_id)

    engine = get_engine()
    async with AsyncSession(engine) as session:
        # 1a. Direct Telegram targets whose enabled receiver is this chat on
        #     this bot.
        direct_rows = (
            await session.exec(
                select(TargetReceiver, NotificationTarget)
                .join(
                    NotificationTarget,
                    TargetReceiver.target_id == NotificationTarget.id,
                )
                .where(
                    TargetReceiver.enabled == True,  # noqa: E712
                    NotificationTarget.type == "telegram",
                )
            )
        ).all()

        target_ids: set[int] = set()
        for recv, target in direct_rows:
            rc_chat = str(recv.config.get("chat_id", "") or "")
            rc_bot = target.config.get("bot_id")
            # bot_id is normalised to str like chat_id: the config is
            # presumably JSON-backed, and an int-vs-str mismatch there would
            # otherwise silently empty the scope.
            if (
                rc_chat == wanted_chat
                and rc_bot is not None
                and str(rc_bot) == wanted_bot
            ):
                target_ids.add(target.id)

        # 1b. Follow broadcast parents: any broadcast target whose active
        #     (non-disabled) children include one of our *direct* Telegram
        #     targets also counts as "routes to this chat". Match against a
        #     frozen snapshot: testing against the growing ``target_ids``
        #     made nested-broadcast inclusion depend on row order.
        direct_telegram_ids = frozenset(target_ids)
        if direct_telegram_ids:
            broadcast_rows = (
                await session.exec(
                    select(NotificationTarget).where(
                        NotificationTarget.type == "broadcast"
                    )
                )
            ).all()
            for b in broadcast_rows:
                children = set(b.config.get("child_target_ids", []) or [])
                disabled = set(b.config.get("disabled_child_ids", []) or [])
                if (children - disabled) & direct_telegram_ids:
                    target_ids.add(b.id)

        if not target_ids:
            return set()

        # 2. Trackers pointing at those targets.
        tracker_target_rows = (
            await session.exec(
                select(NotificationTrackerTarget).where(
                    NotificationTrackerTarget.target_id.in_(target_ids)
                )
            )
        ).all()
        tracker_ids = {tt.tracker_id for tt in tracker_target_rows}
        if not tracker_ids:
            return set()

        # 3. Keep only this provider's trackers and union their album ids,
        #    dropping empty/falsy entries.
        trackers = (
            await session.exec(
                select(NotificationTracker).where(
                    NotificationTracker.id.in_(tracker_ids),
                    NotificationTracker.provider_id == provider_id,
                )
            )
        ).all()
        return {
            aid
            for tr in trackers
            for aid in (tr.collection_ids or [])
            if aid
        }


def get_tracked_collection_ids(
    provider: ServiceProvider,
    trackers: list[NotificationTracker],
    *,
    max_items: int = 20,
) -> list[str]:
    """Get deduplicated collection IDs from trackers for a provider.

    Iterates all trackers belonging to *provider*, collects IDs from both
    ``collection_ids`` and ``filters.collections``, deduplicates while
    preserving first-seen order, and caps at *max_items*.

    Args:
        provider: Only trackers whose ``provider_id`` equals ``provider.id``
            are used.
        trackers: Candidate trackers, possibly for several providers.
        max_items: Maximum number of ids returned (applied as a slice).

    Returns:
        Ordered list of unique collection ids.
    """
    seen: set[str] = set()
    result: list[str] = []
    for tracker in trackers:
        if tracker.provider_id != provider.id:
            continue
        # ``or []`` on both sources: a stored null for either field must not
        # raise TypeError when iterated.
        filter_collections = (tracker.filters or {}).get("collections") or []
        for cid in (*(tracker.collection_ids or []), *filter_collections):
            if cid not in seen:
                seen.add(cid)
                result.append(cid)
    return result[:max_items]