"""Telegram bot command handler — implements all /commands.""" from __future__ import annotations import logging import random as rng import time from datetime import datetime, timezone from typing import Any import aiohttp from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from notify_bridge_core.notifications.telegram.media import TELEGRAM_API_BASE_URL from ..database.engine import get_engine from ..services import make_immich_provider from ..database.models import ( CommandConfig, CommandTemplateConfig, CommandTemplateSlot, CommandTracker, CommandTrackerListener, EventLog, NotificationTarget, NotificationTracker, NotificationTrackerTarget, ServiceProvider, TelegramBot, TrackingConfig, ) from .parser import parse_command from .registry import COMMAND_DESCRIPTIONS, get_rate_category _LOGGER = logging.getLogger(__name__) # Rate limit state: { (bot_id, chat_id, category): last_used_timestamp } _rate_limits: dict[tuple[int, str, str], float] = {} def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None: """Check rate limit. Returns seconds to wait, or None if OK.""" category = get_rate_category(cmd) cooldown = limits.get(category, limits.get("default", 10)) if cooldown <= 0: return None key = (bot_id, chat_id, category) now = time.time() last = _rate_limits.get(key, 0) if now - last < cooldown: return int(cooldown - (now - last)) + 1 _rate_limits[key] = now return None def _render_cmd_template( templates: dict[str, str], slot_name: str, context: dict[str, Any] ) -> str: """Render a command template. Returns rendered string or error placeholder.""" template_str = templates.get(slot_name) if not template_str: _LOGGER.warning("No command template found for slot '%s'", slot_name) return f"[No template: {slot_name}]" try: from jinja2.sandbox import SandboxedEnvironment env = SandboxedEnvironment(autoescape=False) tmpl = env.from_string(template_str) return tmpl.render(**context) except Exception as e: _LOGGER.warning("Failed to render command template '%s': %s", slot_name, e) return f"[Template error: {slot_name}]" async def _resolve_command_context( bot: TelegramBot, ) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, str]]: """Resolve all enabled command trackers, configs, and providers for a bot. Returns (context_tuples, cmd_template_slots). """ engine = get_engine() async with AsyncSession(engine) as session: # Find all listeners for this bot result = await session.exec( select(CommandTrackerListener).where( CommandTrackerListener.listener_type == "telegram_bot", CommandTrackerListener.listener_id == bot.id, ) ) listeners = result.all() if not listeners: return [], {} tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]] = [] for listener in listeners: tracker = await session.get(CommandTracker, listener.command_tracker_id) if not tracker or not tracker.enabled: continue config = await session.get(CommandConfig, tracker.command_config_id) if not config: continue provider = await session.get(ServiceProvider, tracker.provider_id) if not provider: continue tuples.append((tracker, config, provider)) # Load command template slots from the first config that has one cmd_template_slots: dict[str, str] = {} for _, config, _ in tuples: if config.command_template_config_id: slot_result = await session.exec( select(CommandTemplateSlot).where( CommandTemplateSlot.config_id == config.command_template_config_id ) ) cmd_template_slots = {s.slot_name: s.template for s in slot_result.all()} if cmd_template_slots: break return tuples, cmd_template_slots def _merge_command_context( ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], ) -> tuple[list[str], str, str, int, dict[str, Any]]: """Merge enabled_commands from all configs and pick defaults from first config. Returns (enabled_commands, locale, response_mode, default_count, rate_limits). """ if not ctx: return [], "en", "media", 5, {} # Union of all enabled commands across configs enabled: set[str] = set() for _, config, _ in ctx: enabled.update(config.enabled_commands or []) # Use first config's settings as defaults first_config = ctx[0][1] locale = first_config.locale or "en" response_mode = first_config.response_mode or "media" default_count = first_config.default_count or 5 rate_limits = first_config.rate_limits or {} return sorted(enabled), locale, response_mode, default_count, rate_limits async def handle_command( bot: TelegramBot, chat_id: str, text: str, ) -> str | list[dict[str, Any]] | None: """Handle a bot command. Returns text response, media list, or None.""" cmd, args, count_override = parse_command(text) if not cmd: return None ctx_tuples, cmd_templates = await _resolve_command_context(bot) enabled, locale, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples) if cmd == "start": return _render_cmd_template(cmd_templates, "start", {"locale": locale, "bot_name": bot.name}) if cmd not in enabled and cmd != "start": return None # Silently ignore disabled commands # Rate limit check wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits) if wait is not None: return _render_cmd_template(cmd_templates, "rate_limited", {"wait": wait, "locale": locale}) count = min(count_override or default_count, 20) # Build providers map from command context providers_map: dict[int, ServiceProvider] = {} for _, _, provider in ctx_tuples: providers_map[provider.id] = provider # Dispatch — each handler returns template context dict if cmd == "help": ctx = _cmd_help(enabled, locale) elif cmd == "status": ctx = await _cmd_status(bot, providers_map, locale) elif cmd == "albums": ctx = await _cmd_albums(bot, providers_map, locale) elif cmd == "events": ctx = await _cmd_events(bot, providers_map, count, locale) elif cmd == "people": ctx = await _cmd_people(providers_map, locale) elif cmd in ("search", "find", "person", "place", "latest", "random", "favorites", "summary", "memory"): return await _cmd_immich(bot, cmd, args, count, locale, response_mode, providers_map, cmd_templates) else: return None return _render_cmd_template(cmd_templates, cmd, {**ctx, "locale": locale}) def _cmd_help(enabled: list[str], locale: str) -> dict[str, Any]: commands = [] for cmd in enabled: desc = COMMAND_DESCRIPTIONS.get(cmd, {}) desc_text = desc.get(locale, desc.get("en", "")) commands.append({"name": cmd, "description": desc_text}) return {"commands": commands} async def _get_notification_trackers_for_providers( provider_ids: set[int], ) -> list[NotificationTracker]: """Get notification trackers for the given provider IDs.""" if not provider_ids: return [] engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(NotificationTracker).where( NotificationTracker.provider_id.in_(provider_ids) ) ) return list(result.all()) async def _check_native_memory(bot: TelegramBot) -> bool: """Check if any tracker-target linked to this bot uses native memory source.""" engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(NotificationTarget).where( NotificationTarget.type == "telegram", NotificationTarget.user_id == bot.user_id, ) ) targets = result.all() bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token} if not bot_target_ids: return False tt_result = await session.exec( select(NotificationTrackerTarget).where(NotificationTrackerTarget.target_id.in_(bot_target_ids)) ) for tt in tt_result.all(): if tt.tracking_config_id: tc = await session.get(TrackingConfig, tt.tracking_config_id) if tc and tc.memory_source == "native": return True return False async def _cmd_status(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]: provider_ids = set(providers_map.keys()) trackers = await _get_notification_trackers_for_providers(provider_ids) active = sum(1 for t in trackers if t.enabled) total = len(trackers) total_albums = sum(len(t.collection_ids or []) for t in trackers) engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(EventLog).order_by(EventLog.created_at.desc()).limit(1) ) last_event = result.first() last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-" return {"trackers_active": active, "trackers_total": total, "total_albums": total_albums, "last_event": last_str} async def _cmd_albums(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]: provider_ids = set(providers_map.keys()) trackers = await _get_notification_trackers_for_providers(provider_ids) if not trackers: return {"albums": []} albums_data: list[dict] = [] async with aiohttp.ClientSession() as http: for tracker in trackers: provider = providers_map.get(tracker.provider_id) if not provider or provider.type != "immich": continue immich = make_immich_provider(http, provider) for album_id in (tracker.collection_ids or []): try: album = await immich.client.get_album(album_id) if album: albums_data.append({"name": album.name, "asset_count": album.asset_count, "id": album_id}) except Exception: albums_data.append({"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id}) return {"albums": albums_data} async def _cmd_events(bot: TelegramBot, providers_map: dict[int, ServiceProvider], count: int, locale: str) -> dict[str, Any]: provider_ids = set(providers_map.keys()) trackers = await _get_notification_trackers_for_providers(provider_ids) tracker_ids = [t.id for t in trackers] if not tracker_ids: return {"events": []} engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(EventLog) .where(EventLog.tracker_id.in_(tracker_ids)) .order_by(EventLog.created_at.desc()) .limit(count) ) events = result.all() events_data = [{"type": e.event_type, "album": e.collection_name, "count": e.assets_count, "date": e.created_at.strftime("%m/%d %H:%M")} for e in events] return {"events": events_data} async def _cmd_people(providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]: all_people: dict[str, str] = {} async with aiohttp.ClientSession() as http: for provider in providers_map.values(): if provider.type != "immich": continue immich = make_immich_provider(http, provider) people = await immich.client.get_people() all_people.update(people) names = sorted(all_people.values()) return {"people": names} async def _cmd_immich( bot: TelegramBot, cmd: str, args: str, count: int, locale: str, response_mode: str, providers_map: dict[int, ServiceProvider], cmd_templates: dict[str, str], ) -> str | list[dict[str, Any]]: """Handle commands that need Immich API access and may return media.""" if not providers_map: return _render_cmd_template(cmd_templates, "no_results", {"command": cmd, "query": args, "locale": locale}) # Get notification trackers for album data provider_ids = set(providers_map.keys()) notification_trackers = await _get_notification_trackers_for_providers(provider_ids) all_album_ids: list[str] = [] for t in notification_trackers: all_album_ids.extend(t.collection_ids or []) # Pick the first immich provider provider: ServiceProvider | None = None for p in providers_map.values(): if p.type == "immich": provider = p break if not provider: return _render_cmd_template(cmd_templates, "no_results", {"command": cmd, "query": args, "locale": locale}) async with aiohttp.ClientSession() as http: immich = make_immich_provider(http, provider) client = immich.client if cmd == "search": if not args: return _render_cmd_template(cmd_templates, "no_results", {"command": cmd, "query": "", "locale": locale}) assets = await client.search_smart(args, album_ids=all_album_ids, limit=count) return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates) if cmd == "find": if not args: return _render_cmd_template(cmd_templates, "no_results", {"command": cmd, "query": "", "locale": locale}) assets = await client.search_metadata(args, album_ids=all_album_ids, limit=count) return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates) if cmd == "person": if not args: return _render_cmd_template(cmd_templates, "no_results", {"command": "person", "query": "", "locale": locale}) people = await client.get_people() person_id = None for pid, pname in people.items(): if args.lower() in pname.lower(): person_id = pid break if not person_id: return _render_cmd_template(cmd_templates, "no_results", {"command": "person", "query": args, "locale": locale}) assets = await client.search_by_person(person_id, limit=count) return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates) if cmd == "place": if not args: return _render_cmd_template(cmd_templates, "no_results", {"command": "place", "query": "", "locale": locale}) assets = await client.search_smart( f"photos taken in {args}", album_ids=all_album_ids, limit=count ) return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates) if cmd == "favorites": fav_assets: list[dict[str, Any]] = [] for album_id in all_album_ids[:10]: try: album = await client.get_album(album_id) if album: for aid, asset in list(album.assets.items())[:50]: if asset.is_favorite and len(fav_assets) < count: fav_assets.append({ "id": asset.id, "originalFileName": asset.filename, "type": asset.type, }) except Exception: pass if len(fav_assets) >= count: break return _format_assets(fav_assets, cmd, "", locale, response_mode, client, cmd_templates) if cmd == "latest": latest_assets: list[dict[str, Any]] = [] for album_id in all_album_ids[:10]: try: album = await client.get_album(album_id) if album: for aid, asset in list(album.assets.items())[:count]: latest_assets.append({ "id": asset.id, "originalFileName": asset.filename, "type": asset.type, "createdAt": asset.created_at, }) except Exception: pass latest_assets.sort(key=lambda a: a.get("createdAt", ""), reverse=True) return _format_assets(latest_assets[:count], cmd, "", locale, response_mode, client, cmd_templates) if cmd == "random": random_assets: list[dict[str, Any]] = [] for album_id in all_album_ids[:10]: try: album = await client.get_album(album_id) if album: asset_list = list(album.assets.values()) sampled = rng.sample(asset_list, min(count, len(asset_list))) for asset in sampled: random_assets.append({ "id": asset.id, "originalFileName": asset.filename, "type": asset.type, }) except Exception: pass rng.shuffle(random_assets) return _format_assets(random_assets[:count], cmd, "", locale, response_mode, client, cmd_templates) if cmd == "summary": albums_data: list[dict] = [] for album_id in all_album_ids: try: album = await client.get_album(album_id) if album: albums_data.append({"name": album.name, "asset_count": album.asset_count, "id": album_id}) except Exception: pass return _render_cmd_template(cmd_templates, "summary", {"albums": albums_data, "locale": locale}) if cmd == "memory": # Check if any linked tracking config uses native memories use_native = await _check_native_memory(bot) today = datetime.now(timezone.utc) memory_assets: list[dict[str, Any]] = [] if use_native: # Use Immich native memories API memories = await client.get_memories() tracked_ids = set(all_album_ids) if all_album_ids else None for mem in memories: year = mem.get("data", {}).get("year") for raw_asset in mem.get("assets", []): if tracked_ids: asset_albums = raw_asset.get("albums", []) if not any(a.get("id") in tracked_ids for a in asset_albums): continue memory_assets.append({ "id": raw_asset.get("id", ""), "originalFileName": raw_asset.get("originalFileName", ""), "type": raw_asset.get("type", "IMAGE"), "createdAt": raw_asset.get("fileCreatedAt", raw_asset.get("createdAt", "")), "year": year, }) else: # Album-scanning fallback month_day = (today.month, today.day) for album_id in all_album_ids[:10]: try: album = await client.get_album(album_id) if album: for aid, asset in album.assets.items(): try: dt = datetime.fromisoformat(asset.created_at.replace("Z", "+00:00")) if (dt.month, dt.day) == month_day and dt.year != today.year: memory_assets.append({ "id": asset.id, "originalFileName": asset.filename, "type": asset.type, "createdAt": asset.created_at, "year": dt.year, }) except (ValueError, AttributeError): pass except Exception: pass memory_assets = memory_assets[:count] if not memory_assets: return _render_cmd_template(cmd_templates, "no_results", {"command": "memory", "query": "", "locale": locale}) return _format_assets(memory_assets, cmd, "", locale, response_mode, client, cmd_templates) return None def _format_assets( assets: list[dict[str, Any]], cmd: str, query: str, locale: str, response_mode: str, client: Any, cmd_templates: dict[str, str], ) -> str | list[dict[str, Any]]: """Format asset results as text or media payload.""" if not assets: return _render_cmd_template(cmd_templates, "no_results", {"command": cmd, "query": query, "locale": locale}) if response_mode == "media": media_items = [] for asset in assets: asset_id = asset.get("id", "") filename = asset.get("originalFileName", "") year = asset.get("year", "") caption = f"{filename} ({year})" if year else filename media_items.append({ "type": "photo", "asset_id": asset_id, "caption": caption, "thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview", "api_key": client.api_key, }) return media_items # Text mode — render via template slot_map = {"find": "search", "person": "search", "place": "search"} slot_name = slot_map.get(cmd, cmd) return _render_cmd_template(cmd_templates, slot_name, { "assets": assets, "query": query, "command": cmd, "count": len(assets), "locale": locale, }) async def send_reply(bot_token: str, chat_id: str, text: str) -> None: """Send a text reply via Telegram Bot API, retrying without HTML on parse failure.""" async with aiohttp.ClientSession() as http: url = f"{TELEGRAM_API_BASE_URL}{bot_token}/sendMessage" payload: dict[str, Any] = {"chat_id": chat_id, "text": text, "parse_mode": "HTML"} try: async with http.post(url, json=payload) as resp: if resp.status != 200: result = await resp.json() _LOGGER.debug("Telegram reply failed: %s", result.get("description")) if "parse" in str(result.get("description", "")).lower(): payload.pop("parse_mode", None) async with http.post(url, json=payload) as retry_resp: if retry_resp.status != 200: _LOGGER.warning("Telegram reply failed on retry") except aiohttp.ClientError as err: _LOGGER.error("Failed to send Telegram reply: %s", err) async def send_media_group( bot_token: str, chat_id: str, media_items: list[dict[str, Any]], ) -> None: """Send media items as a Telegram media group (album). Falls back to individual sendPhoto calls if sendMediaGroup fails. Telegram allows max 10 items per media group. """ if not media_items: return async with aiohttp.ClientSession() as http: # Download all thumbnails first downloaded: list[tuple[bytes, str, str]] = [] # (photo_bytes, asset_id, caption) for item in media_items: asset_id = item.get("asset_id", "") caption = item.get("caption", "") thumb_url = item.get("thumbnail_url", "") api_key = item.get("api_key", "") try: async with http.get(thumb_url, headers={"x-api-key": api_key}) as resp: if resp.status != 200: _LOGGER.warning("Failed to download thumbnail for %s: HTTP %d", asset_id, resp.status) continue photo_bytes = await resp.read() downloaded.append((photo_bytes, asset_id, caption)) except aiohttp.ClientError: continue if not downloaded: return # Send in groups of 10 (Telegram limit) for i in range(0, len(downloaded), 10): chunk = downloaded[i:i + 10] if len(chunk) == 1: # Single photo — use sendPhoto photo_bytes, asset_id, caption = chunk[0] data = aiohttp.FormData() data.add_field("chat_id", chat_id) data.add_field("photo", photo_bytes, filename=f"{asset_id}.jpg", content_type="image/jpeg") if caption: data.add_field("caption", caption) try: async with http.post(f"{TELEGRAM_API_BASE_URL}{bot_token}/sendPhoto", data=data) as resp: if resp.status != 200: result = await resp.json() _LOGGER.warning("Failed to send photo: %s", result.get("description")) except aiohttp.ClientError as err: _LOGGER.warning("Failed to send photo: %s", err) else: # Multiple photos — use sendMediaGroup import json as _json data = aiohttp.FormData() data.add_field("chat_id", chat_id) media_array = [] for idx, (photo_bytes, asset_id, caption) in enumerate(chunk): attach_key = f"photo_{idx}" media_obj: dict[str, Any] = {"type": "photo", "media": f"attach://{attach_key}"} if caption: media_obj["caption"] = caption media_array.append(media_obj) data.add_field(attach_key, photo_bytes, filename=f"{asset_id}.jpg", content_type="image/jpeg") data.add_field("media", _json.dumps(media_array)) try: async with http.post(f"{TELEGRAM_API_BASE_URL}{bot_token}/sendMediaGroup", data=data) as resp: if resp.status != 200: result = await resp.json() _LOGGER.warning("Failed to send media group: %s", result.get("description")) except aiohttp.ClientError as err: _LOGGER.warning("Failed to send media group: %s", err) async def register_commands_with_telegram(bot: TelegramBot) -> bool: """Register enabled commands with Telegram BotFather API.""" ctx_tuples, _ = await _resolve_command_context(bot) enabled, locale, _, _, _ = _merge_command_context(ctx_tuples) commands = [] for cmd in enabled: desc = COMMAND_DESCRIPTIONS.get(cmd, {}) commands.append({ "command": cmd, "description": desc.get(locale, desc.get("en", cmd)), }) async with aiohttp.ClientSession() as http: url = f"{TELEGRAM_API_BASE_URL}{bot.token}/setMyCommands" payload: dict[str, Any] = {"commands": commands} try: async with http.post(url, json=payload) as resp: result = await resp.json() if result.get("ok"): _LOGGER.info("Registered %d commands for bot @%s", len(commands), bot.bot_username) # Also register for the other locale other_locale = "ru" if locale == "en" else "en" other_commands = [ {"command": c, "description": COMMAND_DESCRIPTIONS.get(c, {}).get(other_locale, c)} for c in enabled ] async with http.post(url, json={"commands": other_commands, "language_code": other_locale}) as r2: pass return True _LOGGER.warning("Failed to register commands: %s", result.get("description")) return False except aiohttp.ClientError as err: _LOGGER.error("Failed to register commands: %s", err) return False