"""Telegram bot command handler — provider-agnostic dispatcher.""" from __future__ import annotations import logging import time from functools import lru_cache from typing import Any import aiohttp from cachetools import TTLCache from jinja2.sandbox import SandboxedEnvironment from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from notify_bridge_core.notifications.telegram.client import TelegramClient from ..database.engine import get_engine from ..database.models import ( CommandConfig, CommandTemplateConfig, CommandTemplateSlot, CommandTracker, CommandTrackerListener, EventLog, NotificationTracker, ServiceProvider, TelegramBot, TelegramChat, ) from .base import CommandResponse from .parser import parse_command from .registry import get_rate_category _LOGGER = logging.getLogger(__name__) # Singleton Jinja2 environment for template rendering (Phase 4d) _JINJA_ENV = SandboxedEnvironment(autoescape=True) # Rate limit state with automatic TTL expiry (Phase 4e) _rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600) # Maximum responses per command to avoid Telegram rate limits _MAX_RESPONSES_PER_COMMAND = 5 # Commands that fetch assets from the service provider and usually reply # with media — "uploading photo" is the accurate UX hint while we wait on # the provider API + Telegram upload. _UPLOAD_PHOTO_COMMANDS = frozenset({ "latest", "random", "favorites", "memory", "search", "find", "person", "place", }) # Commands that fetch from the provider but reply with text only. # "typing" is accurate; we still want an indicator because the fetch is slow. _TYPING_COMMANDS = frozenset({"summary"}) def classify_command_chat_action(text: str) -> str | None: """Return the Telegram chat-action hint to show for this command, or None. The classification is by command name alone — good enough for the cases where a chat action is worthwhile (slow provider fetches). Fast DB-only commands (``/status``, ``/albums``, ``/events``, ``/people``) return ``None`` and skip the indicator entirely. """ cmd, _, _ = parse_command(text) if not cmd: return None if cmd in _UPLOAD_PHOTO_COMMANDS: return "upload_photo" if cmd in _TYPING_COMMANDS: return "typing" return None def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None: """Check rate limit. Returns seconds to wait, or None if OK.""" category = get_rate_category(cmd) cooldown = limits.get(category, limits.get("default", 10)) if cooldown <= 0: return None key = (bot_id, chat_id, category) now = time.time() last = _rate_limits.get(key, 0) if now - last < cooldown: return int(cooldown - (now - last)) + 1 _rate_limits[key] = now return None def _resolve_template( templates: dict[str, dict[str, str]], slot_name: str, locale: str, ) -> str | None: """Pick a template string for slot+locale, falling back to 'en'.""" locale_map = templates.get(slot_name, {}) return locale_map.get(locale) or locale_map.get("en") @lru_cache(maxsize=256) def _compile_template(template_str: str): """Cache compiled Jinja2 templates to avoid re-parsing identical strings.""" return _JINJA_ENV.from_string(template_str) def _render_cmd_template( templates: dict[str, dict[str, str]], slot_name: str, locale: str, context: dict[str, Any], ) -> str: """Render a locale-aware command template. Falls back to 'en'.""" template_str = _resolve_template(templates, slot_name, locale) if not template_str: # Missing template = user sees "[No template: X]" — this is an ERROR, # not a warning. Broken replies must stand out in production logs. _LOGGER.error("No command template found for slot '%s' locale '%s'", slot_name, locale) return f"[No template: {slot_name}]" try: tmpl = _compile_template(template_str) return tmpl.render(**context) except Exception: _LOGGER.error( "Failed to render command template '%s' locale=%s — user will see a broken reply", slot_name, locale, exc_info=True, ) return f"[Template error: {slot_name}]" # --------------------------------------------------------------------------- # Context resolution # --------------------------------------------------------------------------- async def _resolve_command_context( bot: TelegramBot, ) -> tuple[ list[tuple[CommandTracker, CommandConfig, ServiceProvider, CommandTrackerListener]], dict[int, dict[str, dict[str, str]]], ]: """Resolve all enabled command trackers, configs, and providers for a bot. Returns: (context_tuples, templates_by_config_id) templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}. """ engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(CommandTrackerListener).where( CommandTrackerListener.listener_type == "telegram_bot", CommandTrackerListener.listener_id == bot.id, ) ) listeners = result.all() if not listeners: return [], {} # Batch-fetch all referenced entities in 3 queries instead of N*3 tracker_ids = list({l.command_tracker_id for l in listeners}) tracker_result = await session.exec( select(CommandTracker).where(CommandTracker.id.in_(tracker_ids)) ) trackers_by_id = {t.id: t for t in tracker_result.all()} config_ids = list({ t.command_config_id for t in trackers_by_id.values() if t.enabled and t.command_config_id }) if config_ids: config_result = await session.exec( select(CommandConfig).where(CommandConfig.id.in_(config_ids)) ) configs_by_id = {c.id: c for c in config_result.all()} else: configs_by_id = {} provider_ids = list({ t.provider_id for t in trackers_by_id.values() if t.enabled and t.provider_id }) if provider_ids: provider_result = await session.exec( select(ServiceProvider).where(ServiceProvider.id.in_(provider_ids)) ) providers_by_id = {p.id: p for p in provider_result.all()} else: providers_by_id = {} tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider, CommandTrackerListener]] = [] for listener in listeners: tracker = trackers_by_id.get(listener.command_tracker_id) if not tracker or not tracker.enabled: continue config = configs_by_id.get(tracker.command_config_id) if not config: continue provider = providers_by_id.get(tracker.provider_id) if not provider: continue tuples.append((tracker, config, provider, listener)) # Load command template slots per config (not merged) templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {} seen_config_ids: set[int] = set() for _, config, _, _ in tuples: cfg_id = config.command_template_config_id if cfg_id and cfg_id not in seen_config_ids: seen_config_ids.add(cfg_id) slot_result = await session.exec( select(CommandTemplateSlot).where( CommandTemplateSlot.config_id == cfg_id ) ) slots: dict[str, dict[str, str]] = {} for s in slot_result.all(): slots.setdefault(s.slot_name, {})[s.locale] = s.template templates_by_config_id[cfg_id] = slots return tuples, templates_by_config_id def _templates_for_config( templates_by_config_id: dict[int, dict[str, dict[str, str]]], config: CommandConfig, ) -> dict[str, dict[str, str]]: """Get template slots for a specific command config.""" cfg_id = config.command_template_config_id if cfg_id and cfg_id in templates_by_config_id: return templates_by_config_id[cfg_id] return {} def _merge_all_templates( templates_by_config_id: dict[int, dict[str, dict[str, str]]], ) -> dict[str, dict[str, str]]: """Merge all template config slots into one dict (for universal commands).""" merged: dict[str, dict[str, str]] = {} for slots in templates_by_config_id.values(): for slot_name, locale_map in slots.items(): merged.setdefault(slot_name, {}).update(locale_map) return merged def _merge_enabled_commands( ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider, CommandTrackerListener]], ) -> tuple[list[str], dict[str, Any]]: """Merge enabled_commands (union) and rate_limits from all configs. Rate limits use the most restrictive (minimum) cooldown per category. """ if not ctx: return [], {} enabled: set[str] = set() merged_limits: dict[str, int] = {} for _, config, _, _ in ctx: enabled.update(config.enabled_commands or []) for category, cooldown in (config.rate_limits or {}).items(): if category not in merged_limits: merged_limits[category] = cooldown else: merged_limits[category] = min(merged_limits[category], cooldown) return sorted(enabled), merged_limits # --------------------------------------------------------------------------- # Event logging # --------------------------------------------------------------------------- def _format_command_subject(cmd: str, args: str) -> str: """Render the dashboard ``collection_name`` for a command event.""" args = (args or "").strip() return f"/{cmd} {args}".rstrip() if args else f"/{cmd}" def _normalize_issuer(issuer: dict[str, Any] | None) -> dict[str, Any] | None: """Strip a Telegram ``from`` payload to the fields the dashboard needs. Telegram's ``from`` carries plenty we don't want to persist (premium badge, language code already captured elsewhere, etc.). Keep just the identity bits and drop anything else so future Telegram changes can't accidentally start logging extra PII. """ if not issuer: return None keep = ("id", "username", "first_name", "last_name", "is_bot") out = {k: issuer[k] for k in keep if k in issuer and issuer[k] not in (None, "")} return out or None async def _log_command_event( *, bot: TelegramBot, chat_id: str, cmd: str, args: str, locale: str, event_type: str, responses: list[CommandResponse], ctx_tuples: list[ tuple[CommandTracker, CommandConfig, ServiceProvider, CommandTrackerListener] ], extra_details: dict[str, Any] | None = None, issuer: dict[str, Any] | None = None, ) -> None: """Persist a single ``EventLog`` row for a bot-command invocation. One row per user invocation. Per-tracker breakdown lives in ``details`` (``tracker_count`` / ``responses_count``). Best-effort: a logging failure must never block the user-visible reply, so we swallow. """ try: first_tracker: CommandTracker | None = None first_provider: ServiceProvider | None = None if ctx_tuples: first_tracker, _, first_provider, _ = ctx_tuples[0] media_total = sum(len(r.media or []) for r in responses) details: dict[str, Any] = { "command": cmd, "args": args or "", "chat_id": chat_id, "locale": locale, "tracker_count": len(ctx_tuples), "responses_count": len(responses), } normalized_issuer = _normalize_issuer(issuer) if normalized_issuer: details["issuer"] = normalized_issuer if extra_details: details.update(extra_details) engine = get_engine() async with AsyncSession(engine) as session: session.add(EventLog( user_id=bot.user_id, tracker_id=None, tracker_name="", action_id=None, action_name="", command_tracker_id=first_tracker.id if first_tracker else None, command_tracker_name=first_tracker.name if first_tracker else "", telegram_bot_id=bot.id, bot_name=bot.name or "", provider_id=first_provider.id if first_provider else None, provider_name=(first_provider.name if first_provider else "") or "", event_type=event_type, collection_id=str(chat_id), collection_name=_format_command_subject(cmd, args), assets_count=media_total, details=details, )) await session.commit() except Exception: # noqa: BLE001 — diagnostic only, never block reply _LOGGER.exception( "Failed to log command event bot=%d chat=%s cmd=/%s", bot.id, chat_id, cmd, ) # --------------------------------------------------------------------------- # Main dispatcher # --------------------------------------------------------------------------- async def handle_command( bot: TelegramBot, chat_id: str, text: str, language_code: str = "", *, issuer: dict[str, Any] | None = None, ) -> list[CommandResponse] | None: """Handle a bot command. Routes to provider-specific handlers. Returns a list of CommandResponse objects (one per tracker), or None. Universal commands (/start, /help) return a single-element list. Provider-specific commands dispatch per-tracker with per-tracker config. ``issuer`` is the Telegram ``from`` object (``{id, username, first_name, last_name, language_code}``) when known. Stored on the EventLog row so the dashboard can show *who* invoked the command. """ cmd, args, count_override = parse_command(text) if not cmd: return None ctx_tuples, templates_by_config_id = await _resolve_command_context(bot) enabled, rate_limits = _merge_enabled_commands(ctx_tuples) locale = language_code[:2].lower() if language_code else "en" if locale not in ("en", "ru"): locale = "en" # Merged templates for universal commands merged_templates = _merge_all_templates(templates_by_config_id) # Universal commands have no tracker/provider context. if cmd == "start": text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name}) responses = [CommandResponse(text=text_resp)] await _log_command_event( bot=bot, chat_id=chat_id, cmd=cmd, args=args, locale=locale, event_type="command_handled", responses=responses, ctx_tuples=[], issuer=issuer, ) return responses # Unknown / disabled command — caller treats this the same as "no # match" and we deliberately do NOT log it (avoids dashboard spam # from random ``/foo`` traffic). if cmd not in enabled and cmd != "start": return None # Rate limit check (once per command, shared across all trackers) wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits) if wait is not None: _LOGGER.info( "Rate-limited /%s for bot=%d chat=%s — %ds cooldown remaining", cmd, bot.id, chat_id, wait, ) text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait}) responses = [CommandResponse(text=text_resp)] await _log_command_event( bot=bot, chat_id=chat_id, cmd=cmd, args=args, locale=locale, event_type="command_rate_limited", responses=responses, ctx_tuples=ctx_tuples, extra_details={"wait_seconds": wait}, issuer=issuer, ) return responses # Universal commands — single merged response if cmd == "help": ctx = _cmd_help(enabled, locale, merged_templates) text_resp = _render_cmd_template(merged_templates, "help", locale, ctx) responses = [CommandResponse(text=text_resp)] await _log_command_event( bot=bot, chat_id=chat_id, cmd=cmd, args=args, locale=locale, event_type="command_handled", responses=responses, ctx_tuples=ctx_tuples, issuer=issuer, ) return responses # Provider-specific dispatch — per-tracker from .dispatch import get_handler # For paginated commands (/search, /find) a trailing integer means page, # not count. Preserve count_override meaning for all other commands. paginated_cmds = {"search", "find"} page = 1 if cmd in paginated_cmds and count_override: page = max(1, count_override) count_override = None from .command_utils import resolve_chat_album_scope responses: list[CommandResponse] = [] dispatched_ctx: list[ tuple[CommandTracker, CommandConfig, ServiceProvider, CommandTrackerListener] ] = [] try: for tracker, config, provider, listener in ctx_tuples: if len(responses) >= _MAX_RESPONSES_PER_COMMAND: _LOGGER.warning( "Truncated command responses at %d for bot=%d chat=%s cmd=/%s (listener context size=%d)", _MAX_RESPONSES_PER_COMMAND, bot.id, chat_id, cmd, len(ctx_tuples), ) break handler = get_handler(provider.type) if not handler or cmd not in handler.get_provider_commands(): continue tracker_templates = _templates_for_config(templates_by_config_id, config) count = min(count_override or config.default_count or 5, 20) response_mode = config.response_mode or "media" # Resolve the album scope for this (provider, bot, chat) triple. # - Explicit ``listener.allowed_album_ids`` override wins as-is. # - Otherwise derive from notification routing: only albums that # already deliver notifications to this chat are queryable from # it. Prevents commands leaking the full album catalog into # chats that were never set up to receive from those trackers. if listener is not None and listener.allowed_album_ids is not None: allowed_album_ids: set[str] = set(listener.allowed_album_ids) else: allowed_album_ids = await resolve_chat_album_scope( provider_id=provider.id, bot_id=bot.id, chat_id=chat_id, ) result = await handler.handle( cmd, args, count, locale, response_mode, provider, tracker_templates, bot, tracker, config, listener=listener, allowed_album_ids=allowed_album_ids, page=page, ) if result is not None: responses.append(result) dispatched_ctx.append((tracker, config, provider, listener)) except Exception as exc: # noqa: BLE001 — log then re-raise await _log_command_event( bot=bot, chat_id=chat_id, cmd=cmd, args=args, locale=locale, event_type="command_failed", responses=responses, ctx_tuples=ctx_tuples, extra_details={"error": f"{type(exc).__name__}: {exc}"}, issuer=issuer, ) raise if responses: await _log_command_event( bot=bot, chat_id=chat_id, cmd=cmd, args=args, locale=locale, event_type="command_handled", responses=responses, ctx_tuples=dispatched_ctx, issuer=issuer, ) return responses return None def _cmd_help( enabled: list[str], locale: str, templates: dict[str, dict[str, str]], ) -> dict[str, Any]: commands = [] for cmd in enabled: desc_text = _resolve_template(templates, f"desc_{cmd}", locale) or cmd entry: dict[str, str] = {"name": cmd, "description": desc_text} usage_text = _resolve_template(templates, f"usage_{cmd}", locale) if usage_text: entry["usage"] = usage_text commands.append(entry) return {"commands": commands} async def _get_notification_trackers_for_providers( provider_ids: set[int], ) -> list[NotificationTracker]: """Get notification trackers for the given provider IDs.""" if not provider_ids: return [] engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(NotificationTracker).where( NotificationTracker.provider_id.in_(provider_ids) ) ) return list(result.all()) async def send_reply( bot_token: str, chat_id: str, text: str, reply_to_message_id: int | None = None, session: aiohttp.ClientSession | None = None, ) -> None: """Send a text reply to a chat. Thin wrapper that goes through the single ``services.telegram_send`` entry point so commands and notifications share one routine — same HTTP session pool, same file_id caches. Command responses are listings (albums, people, events, ...) that embed multiple links; Telegram's default behavior of rendering a preview of the first URL is almost never what the user wants and clashes with the "Disable link previews" toggle operators set on their Telegram target. We always pass ``disable_web_page_preview=True`` here. """ from ..services.telegram_send import send_telegram_message result = await send_telegram_message( bot_token, chat_id, text, reply_to_message_id=reply_to_message_id, disable_web_page_preview=True, ) if not result.get("success"): # User-visible failure: the bot's reply never reached the chat. _LOGGER.error( "Telegram reply failed (chat=%s reply_to=%s len=%d): code=%s error=%r", chat_id, reply_to_message_id, len(text or ""), result.get("error_code"), result.get("error"), ) async def send_media_group( bot_token: str, chat_id: str, media_items: list[dict[str, Any]], reply_to_message_id: int | None = None, session: aiohttp.ClientSession | None = None, ) -> None: """Send media items via the shared Telegram routine. ``media_items`` must already be in TelegramClient asset format — each entry contains ``type`` (``"photo"``/``"video"``/``"document"``), ``url``, optional ``cache_key``, and optional ``headers``. Provider command handlers build this format via ``build_telegram_asset_entry`` — the same helper the notification dispatcher uses — so videos keep their ``"video"`` type and point at a real video URL instead of a still thumbnail. Uses ``services.telegram_send.send_telegram_media`` so the URL cache and asset cache are wired in exactly like the notification path. Repeated ``/latest`` / ``/random`` commands that match previously-sent assets hit the cache and skip the re-upload. """ if not media_items: # This is what happened in the /random blind spot: the text reply # was sent, but the media follow-up was silently skipped because # the caller passed an empty media list. Surface it so we can see # it in the log and correlate with the text message. _LOGGER.warning( "send_media_group called with 0 items (chat=%s reply_to=%s) — no media will be delivered", chat_id, reply_to_message_id, ) return from ..services.telegram_send import send_telegram_media result = await send_telegram_media( bot_token, chat_id, media_items, reply_to_message_id=reply_to_message_id, chat_action=None, ) if not result.get("success"): # User-visible failure: media promised by the text reply never arrived. _LOGGER.error( "Telegram media group failed (chat=%s items=%d reply_to=%s): code=%s error=%r failed_at_chunk=%s", chat_id, len(media_items), reply_to_message_id, result.get("error_code"), result.get("error"), result.get("failed_at_chunk"), ) def _normalize_locale(raw: str | None) -> str: """Mirror the locale normalization used by the message handler.""" locale = (raw or "")[:2].lower() if locale not in ("en", "ru"): locale = "en" return locale def _build_command_list( enabled: list[str], templates: dict[str, dict[str, str]], locale: str, ) -> list[dict[str, str]]: commands: list[dict[str, str]] = [] for cmd in enabled: desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd commands.append({"command": cmd, "description": desc}) return commands async def sync_chat_command_binding(bot: TelegramBot, chat: TelegramChat) -> bool: """Push Telegram's per-chat command binding for a single chat. Used for immediate refresh when the user toggles a chat's ``language_override`` or ``commands_enabled`` flag — avoids the 30 s debounce of the bot-wide sync. Only touches the chat-scoped binding (one Telegram API call); global per-language registrations stay untouched. The bot-wide sync (``register_commands_with_telegram``) remains the source of truth for everything else. Returns ``True`` when Telegram acknowledged the change. """ from ..services.http_session import get_http_session http = await get_http_session() client = TelegramClient(http, bot.token) scope = {"type": "chat", "chat_id": chat.chat_id} # Chat is opted out of commands → ensure no chat-scoped override # lingers. Telegram returns ok=true even if there was nothing to # delete, so this is safe to call unconditionally. if not chat.commands_enabled or not chat.language_override: result = await client.delete_my_commands(scope=scope) if not result.get("success"): _LOGGER.warning( "delete_my_commands(immediate) failed bot=%d chat=%s: %s", bot.id, chat.chat_id, result.get("error"), ) return bool(result.get("success")) # Override active → resolve the command list for this bot in the # override locale and push it scoped to this chat. ctx_tuples, templates_by_config_id = await _resolve_command_context(bot) enabled, _ = _merge_enabled_commands(ctx_tuples) templates = _merge_all_templates(templates_by_config_id) override_locale = _normalize_locale(chat.language_override) commands = _build_command_list(enabled, templates, override_locale) result = await client.set_my_commands(commands, scope=scope) if not result.get("success"): _LOGGER.warning( "set_my_commands(immediate) failed bot=%d chat=%s locale=%s: %s", bot.id, chat.chat_id, override_locale, result.get("error"), ) return bool(result.get("success")) async def register_commands_with_telegram(bot: TelegramBot) -> bool: """Register enabled commands with Telegram BotFather API via TelegramClient. Registration happens at three levels: 1. Default (no scope, no language) — fallback for any user. 2. Per-language (no scope, ``language_code=en|ru``) — Telegram picks based on the *user's* Telegram client language. 3. Per-chat (``scope=BotCommandScopeChat``) — when a chat has ``language_override`` set, register chat-scoped commands so the override takes effect regardless of each user's Telegram client language. This is the only level Telegram honors for "this chat should use RU even though the user's Telegram is in EN" — the per-language registration alone is keyed on the client locale, not on any per-chat preference we store. """ ctx_tuples, templates_by_config_id = await _resolve_command_context(bot) enabled, _ = _merge_enabled_commands(ctx_tuples) templates = _merge_all_templates(templates_by_config_id) from ..services.http_session import get_http_session http = await get_http_session() client = TelegramClient(http, bot.token) success = False # Register per-locale commands (keyed on user's Telegram client language) for locale in ("en", "ru"): commands = _build_command_list(enabled, templates, locale) result = await client.set_my_commands(commands, language_code=locale) if result.get("success"): success = True else: _LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error")) # Register default (no language_code) with EN descriptions en_commands = _build_command_list(enabled, templates, "en") result = await client.set_my_commands(en_commands) if result.get("success"): _LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username) success = True # Per-chat overrides: apply chat-scoped commands so language_override # wins over the user's Telegram client language. For chats with # commands enabled but no override, clear any prior chat-scoped # binding so they fall back to the per-language registration above. engine = get_engine() async with AsyncSession(engine) as session: chat_result = await session.exec( select(TelegramChat).where( TelegramChat.bot_id == bot.id, TelegramChat.commands_enabled == True, # noqa: E712 — SQLModel needs == for column comparison ) ) chats = list(chat_result.all()) override_count = 0 for chat in chats: scope = {"type": "chat", "chat_id": chat.chat_id} if chat.language_override: override_locale = _normalize_locale(chat.language_override) commands = _build_command_list(enabled, templates, override_locale) result = await client.set_my_commands(commands, scope=scope) if result.get("success"): override_count += 1 else: _LOGGER.warning( "Failed to register chat-scoped commands for bot=%d chat=%s locale=%s: %s", bot.id, chat.chat_id, override_locale, result.get("error"), ) else: # Clear any stale chat-scoped binding from a previous override # so this chat falls back to the per-language registration. # Telegram returns ok=true even when nothing was set; safe to # call unconditionally. result = await client.delete_my_commands(scope=scope) if not result.get("success"): _LOGGER.debug( "delete_my_commands for bot=%d chat=%s returned: %s", bot.id, chat.chat_id, result.get("error"), ) if override_count: _LOGGER.info( "Applied %d per-chat command override(s) for bot @%s", override_count, bot.bot_username, ) return success