refactor: comprehensive codebase review — security, performance, quality, UX

Security:
- Fix NUT protocol command injection (validate names against safe regex)
- Enable Jinja2 autoescape=True to prevent HTML injection via external data
- Add WebhookProviderConfig validation model

Performance:
- Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions)
- Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast)
- asyncio.gather for Gitea commands and notification dispatcher
- Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener
- LRU cache for compiled Jinja2 templates
- Daily EventLog cleanup job (90-day retention)
- 30s HTTP timeout on all external calls
- GROUP BY for target type counts (replaces 7 sequential queries)

Code quality:
- Extract get_owned_entity() helper (replaces 11 duplicate functions)
- Extract slot_helpers.py (load_slots, save_slots, render_template_preview)
- Extract command_utils.py (tracker lookup, last event, collection IDs)
- Extract http_session.py (shared session lifecycle)
- Provider connection validation dedup (3x → 1 helper)
- Command dispatch tables replacing if/elif chains
- Album+links fetch helper (fetch_albums_with_links)
- Provider dispatch polymorphism (list_provider_collections)
- Immutable _enrich_assets (no longer mutates in-place)
- Fix _format_assets return type + handler unpacking

Frontend:
- Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru)
- Mobile "More" nav panel with provider filter and search
- Shared Button.svelte component (4 variants, 2 sizes)
- Shared ErrorBanner.svelte component (8 pages updated)
- SvelteKit goto() replacing window.location.href
- Dashboard grid fixed for 4 cards, paginator opacity consistency

Functionality:
- max_instances=1 on scheduler jobs (prevents duplicate events)
- Webhook provider in watcher (prevents error spam)
- Fix stale SQLModel reference in poller
- Gitea get_repo() direct API call
This commit is contained in:
2026-03-28 13:22:26 +03:00
parent 616b221c92
commit b803d004e1
65 changed files with 1934 additions and 1498 deletions
@@ -3,9 +3,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
@dataclass(frozen=True)
class CommandResponse:
"""A single response from one tracker's command execution."""
text: str | None = None
media: list[dict[str, Any]] = field(default_factory=list)
class ProviderCommandHandler(ABC):
@@ -14,6 +23,8 @@ class ProviderCommandHandler(ABC):
Each provider (Immich, Gitea, etc.) implements this interface to handle
its own set of commands. The dispatch layer routes commands to the
correct handler based on the provider type.
Each handler call receives a single (tracker, config, provider) context.
"""
provider_type: str
@@ -35,26 +46,28 @@ class ProviderCommandHandler(ABC):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
"""Handle a provider-specific command.
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
"""Handle a provider-specific command for a single tracker.
Args:
cmd: The command name (without '/').
args: Arguments after the command.
count: Number of results to return.
locale: User's locale ('en', 'ru').
response_mode: 'media' or 'text'.
providers_map: Provider instances keyed by ID.
cmd_templates: Template slots {slot_name: {locale: template}}.
response_mode: 'media' or 'text' (from this tracker's config).
provider: The service provider instance for this tracker.
cmd_templates: Template slots for this tracker's command template config.
bot: The Telegram bot instance.
ctx_tuples: Command context tuples for this provider type.
tracker: The command tracker being dispatched.
config: The command config for this tracker.
Returns:
Text response, media list, or None if unhandled.
A CommandResponse, or None if unhandled.
"""
def get_rate_categories(self) -> dict[str, str]:
@@ -0,0 +1,67 @@
"""Shared command handler utilities to reduce boilerplate across providers."""
from __future__ import annotations
import logging
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import EventLog, NotificationTracker, ServiceProvider
_LOGGER = logging.getLogger(__name__)
async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]:
"""Get notification trackers for a single provider."""
from .handler import _get_notification_trackers_for_providers
return await _get_notification_trackers_for_providers({provider_id})
async def get_last_event_str(tracker_ids: list[int]) -> str:
"""Get formatted timestamp of most recent event for given trackers.
Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist.
"""
if not tracker_ids:
return "-"
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc())
.limit(1)
)
last_event = result.first()
return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
def get_tracked_collection_ids(
provider: ServiceProvider,
trackers: list[NotificationTracker],
*,
max_items: int = 20,
) -> list[str]:
"""Get deduplicated collection IDs from trackers for a provider.
Iterates all trackers belonging to *provider*, collects IDs from both
``collection_ids`` and ``filters.collections``, deduplicates while
preserving order, and caps at *max_items*.
"""
seen: set[str] = set()
result: list[str] = []
for tracker in trackers:
if tracker.provider_id != provider.id:
continue
for cid in tracker.collection_ids or []:
if cid not in seen:
seen.add(cid)
result.append(cid)
for cid in (tracker.filters or {}).get("collections", []):
if cid not in seen:
seen.add(cid)
result.append(cid)
return result[:max_items]
@@ -2,27 +2,55 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import (
CommandConfig, CommandTracker, EventLog,
NotificationTracker, ServiceProvider, TelegramBot,
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
)
from ..services import make_gitea_provider
from .base import ProviderCommandHandler
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"}
def _get_tracked_repos(
provider: ServiceProvider,
trackers: list,
) -> list[tuple[ServiceProvider, str, str]]:
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
if not provider.config.get("api_token"):
return []
collection_ids = get_tracked_collection_ids(provider, trackers)
repos: list[tuple[ServiceProvider, str, str]] = []
for full_name in collection_ids:
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
return repos
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class GiteaCommandHandler(ProviderCommandHandler):
"""Handles Gitea-specific bot commands."""
@@ -44,91 +72,35 @@ class GiteaCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "repos":
ctx = await _cmd_repos(providers_map)
return _render_cmd_template(cmd_templates, "repos", locale, ctx)
if cmd == "issues":
ctx = await _cmd_issues(providers_map, count)
return _render_cmd_template(cmd_templates, "issues", locale, ctx)
if cmd == "prs":
ctx = await _cmd_prs(providers_map, count)
return _render_cmd_template(cmd_templates, "prs", locale, ctx)
if cmd == "commits":
ctx = await _cmd_commits(providers_map, count)
return _render_cmd_template(cmd_templates, "commits", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
def _get_tracked_repos(
providers_map: dict[int, ServiceProvider],
trackers: list[NotificationTracker],
) -> list[tuple[ServiceProvider, str, str]]:
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
repos: list[tuple[ServiceProvider, str, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.collection_ids or []):
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
# Also check filters.collections
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.filters or {}).get("collections", []):
parts = full_name.split("/", 1)
if len(parts) == 2:
entry = (provider, parts[0], parts[1])
if entry not in repos:
repos.append(entry)
return repos[:20] # Cap to prevent API hammering
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
# Get server version from first Gitea provider with token
# Get server version
server_version = "unknown"
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type == "gitea" and provider.config.get("api_token"):
gitea = make_gitea_provider(http, provider)
version = await gitea.client.get_server_version()
if version:
server_version = version
break
if provider.config.get("api_token"):
http = await get_http_session()
gitea = make_gitea_provider(http, provider)
version = await gitea.client.get_server_version()
if version:
server_version = version
# Last event
engine = get_engine()
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"repos_count": len(tracked_repos),
@@ -137,116 +109,139 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
}
async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_repos(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
repos_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
try:
all_repos = await gitea.client.get_repos(limit=50)
for r in all_repos:
if r.get("full_name") == f"{owner}/{repo}":
repos_data.append({
"full_name": r.get("full_name", ""),
"description": r.get("description", ""),
"stars": r.get("stars_count", 0),
"url": r.get("html_url", ""),
})
break
else:
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "",
"stars": 0,
"url": "",
})
except Exception:
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "?",
"stars": 0,
"url": "",
})
http = await get_http_session()
async def _fetch_repo(prov: ServiceProvider, owner: str, repo: str) -> dict[str, Any]:
gitea = make_gitea_provider(http, prov)
# Use direct get_repo endpoint instead of listing all repos
r = await gitea.client.get_repo(owner, repo)
if r:
return {
"full_name": r.get("full_name", ""),
"description": r.get("description", ""),
"stars": r.get("stars_count", 0),
"url": r.get("html_url", ""),
}
return {
"full_name": f"{owner}/{repo}",
"description": "",
"stars": 0,
"url": "",
}
tasks = [_fetch_repo(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, result)
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "?",
"stars": 0,
"url": "",
})
else:
repos_data.append(result)
return {"repos": repos_data}
async def _cmd_issues(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_issues(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_issues: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
issues = await gitea.client.get_repo_issues(owner, repo, limit=count)
for issue in issues:
all_issues.append({
"repo": f"{owner}/{repo}",
"number": issue.get("number", 0),
"title": issue.get("title", ""),
"url": issue.get("html_url", ""),
"user": issue.get("user", {}).get("login", ""),
"state": issue.get("state", ""),
})
http = await get_http_session()
async def _fetch_issues(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_issues(owner, repo, limit=count)
tasks = [_fetch_issues(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, result)
continue
for issue in result:
all_issues.append({
"repo": f"{owner}/{repo}",
"number": issue.get("number", 0),
"title": issue.get("title", ""),
"url": issue.get("html_url", ""),
"user": issue.get("user", {}).get("login", ""),
"state": issue.get("state", ""),
})
all_issues.sort(key=lambda i: i.get("number", 0), reverse=True)
return {"issues": all_issues[:count]}
async def _cmd_prs(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_prs(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_prs: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
prs = await gitea.client.get_repo_pulls(owner, repo, limit=count)
for pr in prs:
all_prs.append({
"repo": f"{owner}/{repo}",
"number": pr.get("number", 0),
"title": pr.get("title", ""),
"url": pr.get("html_url", ""),
"user": pr.get("user", {}).get("login", ""),
"state": pr.get("state", ""),
})
http = await get_http_session()
async def _fetch_prs(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_pulls(owner, repo, limit=count)
tasks = [_fetch_prs(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, result)
continue
for pr in result:
all_prs.append({
"repo": f"{owner}/{repo}",
"number": pr.get("number", 0),
"title": pr.get("title", ""),
"url": pr.get("html_url", ""),
"user": pr.get("user", {}).get("login", ""),
"state": pr.get("state", ""),
})
all_prs.sort(key=lambda p: p.get("number", 0), reverse=True)
return {"prs": all_prs[:count]}
async def _cmd_commits(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_commits(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_commits: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
commits = await gitea.client.get_repo_commits(owner, repo, limit=count)
for c in commits:
commit_data = c.get("commit", {})
all_commits.append({
"repo": f"{owner}/{repo}",
"short_id": c.get("sha", "")[:7],
"message": commit_data.get("message", "").split("\n")[0][:80],
"author": commit_data.get("author", {}).get("name", ""),
"url": c.get("html_url", ""),
})
http = await get_http_session()
async def _fetch_commits(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_commits(owner, repo, limit=count)
tasks = [_fetch_commits(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, result)
continue
for c in result:
commit_data = c.get("commit", {})
all_commits.append({
"repo": f"{owner}/{repo}",
"short_id": c.get("sha", "")[:7],
"message": commit_data.get("message", "").split("\n")[0][:80],
"author": commit_data.get("author", {}).get("name", ""),
"url": c.get("html_url", ""),
})
return {"commits": all_commits[:count]}
@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import time
from functools import lru_cache
from typing import Any
import aiohttp
@@ -25,17 +26,21 @@ from ..database.models import (
ServiceProvider,
TelegramBot,
)
from .base import CommandResponse
from .parser import parse_command
from .registry import get_rate_category
_LOGGER = logging.getLogger(__name__)
# Singleton Jinja2 environment for template rendering (Phase 4d)
_JINJA_ENV = SandboxedEnvironment(autoescape=False)
_JINJA_ENV = SandboxedEnvironment(autoescape=True)
# Rate limit state with automatic TTL expiry (Phase 4e)
_rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
# Maximum responses per command to avoid Telegram rate limits
_MAX_RESPONSES_PER_COMMAND = 5
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
"""Check rate limit. Returns seconds to wait, or None if OK."""
@@ -60,6 +65,12 @@ def _resolve_template(
return locale_map.get(locale) or locale_map.get("en")
@lru_cache(maxsize=256)
def _compile_template(template_str: str):
"""Cache compiled Jinja2 templates to avoid re-parsing identical strings."""
return _JINJA_ENV.from_string(template_str)
def _render_cmd_template(
templates: dict[str, dict[str, str]], slot_name: str, locale: str,
context: dict[str, Any],
@@ -70,20 +81,28 @@ def _render_cmd_template(
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
return f"[No template: {slot_name}]"
try:
tmpl = _JINJA_ENV.from_string(template_str)
tmpl = _compile_template(template_str)
return tmpl.render(**context)
except Exception as e:
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
return f"[Template error: {slot_name}]"
# ---------------------------------------------------------------------------
# Context resolution
# ---------------------------------------------------------------------------
async def _resolve_command_context(
bot: TelegramBot,
) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, dict[str, str]]]:
) -> tuple[
list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
dict[int, dict[str, dict[str, str]]],
]:
"""Resolve all enabled command trackers, configs, and providers for a bot.
Returns (context_tuples, cmd_template_slots).
cmd_template_slots is {slot_name: {locale: template}}.
Returns:
(context_tuples, templates_by_config_id)
templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}.
"""
engine = get_engine()
async with AsyncSession(engine) as session:
@@ -142,8 +161,8 @@ async def _resolve_command_context(
continue
tuples.append((tracker, config, provider))
# Load command template slots — merge from all configs
cmd_template_slots: dict[str, dict[str, str]] = {}
# Load command template slots per config (not merged)
templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {}
seen_config_ids: set[int] = set()
for _, config, _ in tuples:
cfg_id = config.command_template_config_id
@@ -154,98 +173,136 @@ async def _resolve_command_context(
CommandTemplateSlot.config_id == cfg_id
)
)
slots: dict[str, dict[str, str]] = {}
for s in slot_result.all():
cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template
slots.setdefault(s.slot_name, {})[s.locale] = s.template
templates_by_config_id[cfg_id] = slots
return tuples, cmd_template_slots
return tuples, templates_by_config_id
def _merge_command_context(
def _templates_for_config(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
config: CommandConfig,
) -> dict[str, dict[str, str]]:
"""Get template slots for a specific command config."""
cfg_id = config.command_template_config_id
if cfg_id and cfg_id in templates_by_config_id:
return templates_by_config_id[cfg_id]
return {}
def _merge_all_templates(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
) -> dict[str, dict[str, str]]:
"""Merge all template config slots into one dict (for universal commands)."""
merged: dict[str, dict[str, str]] = {}
for slots in templates_by_config_id.values():
for slot_name, locale_map in slots.items():
merged.setdefault(slot_name, {}).update(locale_map)
return merged
def _merge_enabled_commands(
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> tuple[list[str], str, int, dict[str, Any]]:
"""Merge enabled_commands from all configs and pick defaults from first config."""
) -> tuple[list[str], dict[str, Any]]:
"""Merge enabled_commands (union) and rate_limits from all configs.
Rate limits use the most restrictive (minimum) cooldown per category.
"""
if not ctx:
return [], "media", 5, {}
return [], {}
enabled: set[str] = set()
merged_limits: dict[str, int] = {}
for _, config, _ in ctx:
enabled.update(config.enabled_commands or [])
for category, cooldown in (config.rate_limits or {}).items():
if category not in merged_limits:
merged_limits[category] = cooldown
else:
merged_limits[category] = min(merged_limits[category], cooldown)
first_config = ctx[0][1]
response_mode = first_config.response_mode or "media"
default_count = first_config.default_count or 5
rate_limits = first_config.rate_limits or {}
return sorted(enabled), merged_limits
return sorted(enabled), response_mode, default_count, rate_limits
# ---------------------------------------------------------------------------
# Main dispatcher
# ---------------------------------------------------------------------------
async def handle_command(
bot: TelegramBot,
chat_id: str,
text: str,
language_code: str = "",
) -> str | list[dict[str, Any]] | None:
) -> list[CommandResponse] | None:
"""Handle a bot command. Routes to provider-specific handlers.
Returns text response, media list, or None.
Returns a list of CommandResponse objects (one per tracker), or None.
Universal commands (/start, /help) return a single-element list.
Provider-specific commands dispatch per-tracker with per-tracker config.
"""
cmd, args, count_override = parse_command(text)
if not cmd:
return None
ctx_tuples, cmd_templates = await _resolve_command_context(bot)
enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples)
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, rate_limits = _merge_enabled_commands(ctx_tuples)
locale = language_code[:2].lower() if language_code else "en"
if locale not in ("en", "ru"):
locale = "en"
# Merged templates for universal commands
merged_templates = _merge_all_templates(templates_by_config_id)
if cmd == "start":
return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name})
text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name})
return [CommandResponse(text=text_resp)]
if cmd not in enabled and cmd != "start":
return None
# Rate limit check
# Rate limit check (once per command, shared across all trackers)
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
if wait is not None:
return _render_cmd_template(cmd_templates, "rate_limited", locale, {"wait": wait})
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
return [CommandResponse(text=text_resp)]
count = min(count_override or default_count, 20)
# Build providers map from command context
providers_map: dict[int, ServiceProvider] = {}
for _, _, provider in ctx_tuples:
providers_map[provider.id] = provider
# Universal commands
# Universal commands — single merged response
if cmd == "help":
ctx = _cmd_help(enabled, locale, cmd_templates)
return _render_cmd_template(cmd_templates, "help", locale, ctx)
ctx = _cmd_help(enabled, locale, merged_templates)
text_resp = _render_cmd_template(merged_templates, "help", locale, ctx)
return [CommandResponse(text=text_resp)]
# Provider-specific dispatch
# Provider-specific dispatch — per-tracker
from .dispatch import get_handler
# Group ctx_tuples by provider type
by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {}
for t in ctx_tuples:
ptype = t[2].type
by_type.setdefault(ptype, []).append(t)
# Find which handler claims this command
for ptype, ptuples in by_type.items():
handler = get_handler(ptype)
if handler and cmd in handler.get_provider_commands():
# Build provider map filtered to this provider type
pmap = {p.id: p for _, _, p in ptuples}
result = await handler.handle(
cmd, args, count, locale, response_mode,
pmap, cmd_templates, bot, ptuples,
responses: list[CommandResponse] = []
for tracker, config, provider in ctx_tuples:
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
_LOGGER.warning(
"Truncated command responses at %d for bot %d cmd /%s",
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
)
if result is not None:
return result
break
return None
handler = get_handler(provider.type)
if not handler or cmd not in handler.get_provider_commands():
continue
tracker_templates = _templates_for_config(templates_by_config_id, config)
count = min(count_override or config.default_count or 5, 20)
response_mode = config.response_mode or "media"
result = await handler.handle(
cmd, args, count, locale, response_mode,
provider, tracker_templates, bot, tracker, config,
)
if result is not None:
responses.append(result)
return responses if responses else None
def _cmd_help(
@@ -283,17 +340,13 @@ async def send_reply(
session: aiohttp.ClientSession | None = None,
) -> None:
"""Send a text reply via TelegramClient."""
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
if session is None:
from ..services.http_session import get_http_session
session = await get_http_session()
client = TelegramClient(session, bot_token)
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
async def send_media_group(
@@ -319,52 +372,50 @@ async def send_media_group(
captions = [item.get("caption", "") for item in media_items if item.get("caption")]
caption = "\n".join(captions) if captions else None
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_notification(
chat_id, assets=assets, caption=caption,
reply_to_message_id=reply_to_message_id,
chat_action=None,
)
if not result.get("success"):
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
if session is None:
from ..services.http_session import get_http_session
session = await get_http_session()
client = TelegramClient(session, bot_token)
result = await client.send_notification(
chat_id, assets=assets, caption=caption,
reply_to_message_id=reply_to_message_id,
chat_action=None,
)
if not result.get("success"):
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
"""Register enabled commands with Telegram BotFather API via TelegramClient."""
ctx_tuples, templates = await _resolve_command_context(bot)
enabled, _, _, _ = _merge_command_context(ctx_tuples)
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, _ = _merge_enabled_commands(ctx_tuples)
templates = _merge_all_templates(templates_by_config_id)
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot.token)
success = False
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot.token)
success = False
# Register per-locale commands
for locale in ("en", "ru"):
commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
# Register default (no language_code) with EN descriptions
en_commands = []
# Register per-locale commands
for locale in ("en", "ru"):
commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands)
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
return success
# Register default (no language_code) with EN descriptions
en_commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
return success
@@ -6,70 +6,48 @@ import asyncio
import logging
from typing import Any
import aiohttp
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ...database.models import ServiceProvider, TelegramBot
from ...database.models import ServiceProvider
from ...services import make_immich_provider
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets, build_asset_dict
from ...services.http_session import get_http_session
from ..command_utils import get_trackers_for_provider
from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__)
async def _cmd_albums(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
if not trackers:
return {"albums": []}
albums_data: list[dict] = []
async with aiohttp.ClientSession() as http:
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
album_ids = tracker.collection_ids or []
if not album_ids:
continue
# Deduplicate album IDs while preserving order
seen: set[str] = set()
album_ids: list[str] = []
for tracker in trackers:
for aid in tracker.collection_ids or []:
if aid not in seen:
seen.add(aid)
album_ids.append(aid)
if not album_ids:
return {"albums": []}
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
album_results = await asyncio.gather(
*[immich.client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[immich.client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
})
elif result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
http = await get_http_session()
immich = make_immich_provider(http, provider)
albums_data = await fetch_albums_with_links(immich.client, album_ids, ext_domain)
return {"albums": albums_data}
async def cmd_favorites(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
providers_map: dict[int, ServiceProvider],
all_album_ids: list[str], count: int, locale: str,
response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /favorites command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
@@ -104,28 +82,6 @@ async def cmd_summary(
if not all_album_ids:
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []})
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in all_album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
albums_data: list[dict] = []
for album_id, result, links in zip(all_album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
albums_data = await fetch_albums_with_links(client, all_album_ids, ext, include_failed=False)
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
@@ -2,10 +2,12 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
from ...services import make_immich_provider
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
@@ -17,6 +19,53 @@ _IMMICH_COMMANDS = {
}
async def fetch_albums_with_links(
client: Any,
album_ids: list[str],
ext_domain: str,
*,
include_failed: bool = True,
) -> list[dict[str, Any]]:
"""Fetch albums and their shared links concurrently.
Returns a list of album data dicts with keys: name, asset_count, id,
public_url, and ``_album`` (the raw album object for callers that need
asset-level access).
When *include_failed* is True, albums that fail to fetch are included
with placeholder data (``"?"`` for counts). When False, they are
silently skipped.
"""
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
albums_data: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
if include_failed:
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?",
"id": album_id, "public_url": "", "_album": None,
})
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url, "_album": result,
})
return albums_data
def build_asset_dict(
asset: Any,
*,
@@ -56,8 +105,14 @@ def _format_assets(
assets: list[dict[str, Any]], cmd: str, query: str,
locale: str, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Format asset results as text or media payload."""
) -> str | dict[str, Any]:
"""Format asset results as text or a text-plus-media payload.
Returns:
str: rendered text when *response_mode* is ``"text"`` (or no assets).
dict: ``{"text": ..., "media": [...]}`` when *response_mode* is
``"media"`` and assets are present.
"""
if not assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
@@ -68,7 +123,7 @@ def _format_assets(
})
if response_mode == "media":
media_items = []
media_items: list[dict[str, Any]] = []
for asset in assets:
asset_id = asset.get("id", "")
media_items.append({
@@ -13,23 +13,22 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
EventLog, NotificationTarget, NotificationTrackerTarget,
ServiceProvider, TelegramBot, TrackingConfig,
EventLog, NotificationTracker, NotificationTrackerTarget,
ServiceProvider, TrackingConfig,
)
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets, build_asset_dict
from ..command_utils import get_trackers_for_provider
from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__)
async def _cmd_events(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
count: int, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return {"events": []}
@@ -57,32 +56,21 @@ async def cmd_latest(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
external_domain: str = "",
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /latest command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
latest_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
for aid, asset in list(result.assets.items())[:count]:
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if album_obj:
for aid, asset in list(album_obj.assets.items())[:count]:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
latest_assets.append(build_asset_dict(asset, public_url=asset_pub))
@@ -95,32 +83,21 @@ async def cmd_random(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
external_domain: str = "",
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /random command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "random", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
random_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
asset_list = list(result.assets.values())
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if album_obj:
asset_list = list(album_obj.assets.values())
sampled = rng.sample(asset_list, min(count, len(asset_list)))
for asset in sampled:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
@@ -130,40 +107,40 @@ async def cmd_random(
return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates)
async def _check_native_memory(bot: TelegramBot) -> bool:
"""Check if any tracker-target linked to this bot uses native memory source."""
async def _check_native_memory(provider_id: int) -> bool:
"""Check if any notification tracker for this provider uses native memory source."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(NotificationTarget).where(
NotificationTarget.type == "telegram",
NotificationTarget.user_id == bot.user_id,
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
)
)
targets = result.all()
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
if not bot_target_ids:
trackers = tracker_result.all()
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return False
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.target_id.in_(bot_target_ids)
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
)
)
for tt in tt_result.all():
if tt.tracking_config_id:
tc = await session.get(TrackingConfig, tt.tracking_config_id)
if tc and tc.memory_source == "native":
return True
return False
tc_ids = list({tt.tracking_config_id for tt in tt_result.all() if tt.tracking_config_id})
if not tc_ids:
return False
tc_result = await session.exec(
select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
)
return any(tc.memory_source == "native" for tc in tc_result.all())
async def cmd_memory(
bot: TelegramBot, client: Any, all_album_ids: list[str], count: int,
provider_id: int, client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /memory command with concurrent album fetching."""
use_native = await _check_native_memory(bot)
use_native = await _check_native_memory(provider_id)
today = datetime.now(timezone.utc)
memory_assets: list[dict[str, Any]] = []
@@ -2,26 +2,21 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
CommandConfig, CommandTracker, EventLog,
CommandConfig, CommandTracker,
ServiceProvider, TelegramBot,
)
from ...services import make_immich_provider
from ..base import ProviderCommandHandler
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ...services.http_session import get_http_session
from ..base import CommandResponse, ProviderCommandHandler
from ..command_utils import get_last_event_str, get_trackers_for_provider
from ..handler import _render_cmd_template
from .albums import _cmd_albums, cmd_favorites, cmd_summary
from .common import _IMMICH_COMMANDS
from .common import _IMMICH_COMMANDS, fetch_albums_with_links
from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random
from .search import cmd_find, cmd_person, cmd_place, cmd_search
@@ -29,21 +24,15 @@ _LOGGER = logging.getLogger(__name__)
async def _cmd_status(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
active = sum(1 for t in trackers if t.enabled)
total = len(trackers)
total_albums = sum(len(t.collection_ids or []) for t in trackers)
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"trackers_active": active, "trackers_total": total,
@@ -52,16 +41,13 @@ async def _cmd_status(
async def _cmd_people(
providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
all_people: dict[str, str] = {}
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
http = await get_http_session()
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
names = sorted(all_people.values())
return {"people": names}
@@ -87,106 +73,92 @@ class ImmichCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
if cmd == "status":
ctx = await _cmd_status(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
ctx = await _cmd_status(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "status", locale, ctx))
if cmd == "albums":
ctx = await _cmd_albums(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
ctx = await _cmd_albums(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "albums", locale, ctx))
if cmd == "events":
ctx = await _cmd_events(bot, providers_map, count, locale)
return _render_cmd_template(cmd_templates, "events", locale, ctx)
ctx = await _cmd_events(provider, count, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "events", locale, ctx))
if cmd == "people":
ctx = await _cmd_people(providers_map, locale)
return _render_cmd_template(cmd_templates, "people", locale, ctx)
ctx = await _cmd_people(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "people", locale, ctx))
if cmd in ("search", "find", "person", "place", "latest",
"random", "favorites", "summary", "memory"):
return await _cmd_immich(
bot, cmd, args, count, locale, response_mode,
providers_map, cmd_templates,
cmd, args, count, locale, response_mode,
provider, cmd_templates,
)
return None
async def _cmd_immich(
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
response_mode: str, providers_map: dict[int, ServiceProvider],
cmd: str, args: str, count: int, locale: str,
response_mode: str, provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> CommandResponse | None:
"""Handle commands that need Immich API access and may return media."""
if not providers_map:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
provider_ids = set(providers_map.keys())
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
notification_trackers = await get_trackers_for_provider(provider.id)
all_album_ids: list[str] = []
for t in notification_trackers:
all_album_ids.extend(t.collection_ids or [])
provider: ServiceProvider | None = None
for p in providers_map.values():
if p.type == "immich":
provider = p
break
if not provider:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
async with aiohttp.ClientSession() as http:
immich = make_immich_provider(http, provider)
client = immich.client
http = await get_http_session()
immich = make_immich_provider(http, provider)
client = immich.client
# Build asset_id → public_url map from tracked albums' shared links
asset_public_urls: dict[str, str] = {}
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in all_album_ids],
return_exceptions=True,
)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
for album_id, links, album in zip(all_album_ids, link_results, album_results):
if isinstance(links, Exception) or isinstance(album, Exception):
continue
pub_url = get_public_url(ext_domain, links)
if pub_url and album:
for asset_id in album.assets:
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
# Build asset_id → public_url map from tracked albums' shared links
asset_public_urls: dict[str, str] = {}
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
fetched = await fetch_albums_with_links(client, all_album_ids, ext_domain, include_failed=False)
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if pub_url and album_obj:
for asset_id in album_obj.assets:
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
if cmd == "search":
return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
# Wrap single-provider in a map for functions that still expect it
providers_map = {provider.id: provider}
if cmd == "find":
return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
result: str | dict[str, Any] | None = None
if cmd == "person":
return await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
if cmd == "search":
result = await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "find":
result = await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "person":
result = await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "place":
result = await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "favorites":
result = await cmd_favorites(providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
elif cmd == "latest":
result = await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "random":
result = await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "summary":
result = await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
elif cmd == "memory":
result = await cmd_memory(provider.id, client, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "place":
return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
if cmd == "favorites":
return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
if cmd == "latest":
return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
if cmd == "random":
return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
if cmd == "summary":
return await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
if cmd == "memory":
return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates)
return None
if result is None:
return None
# _format_assets returns {"text": ..., "media": [...]} for media mode
if isinstance(result, dict):
return CommandResponse(
text=result.get("text"),
media=result.get("media", []),
)
return CommandResponse(text=result)
@@ -9,14 +9,15 @@ from .common import _format_assets
def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]:
"""Add public_url to assets from the pre-built map."""
"""Add public_url to assets from the pre-built map. Returns new list without mutating inputs."""
if not asset_public_urls:
return assets
for asset in assets:
aid = asset.get("id", "")
if aid and aid in asset_public_urls and not asset.get("public_url"):
asset["public_url"] = asset_public_urls[aid]
return assets
return [
{**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")}
if asset.get("id", "") in asset_public_urls and not asset.get("public_url")
else asset
for asset in assets
]
async def cmd_search(
@@ -24,7 +25,7 @@ async def cmd_search(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /search command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""})
@@ -38,7 +39,7 @@ async def cmd_find(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /find command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""})
@@ -52,7 +53,7 @@ async def cmd_person(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /person command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
@@ -74,7 +75,7 @@ async def cmd_place(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /place command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
@@ -3,17 +3,31 @@
from __future__ import annotations
import logging
from collections.abc import Callable, Coroutine
from typing import Any
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
from ..services import make_nut_provider
from .base import ProviderCommandHandler
from .base import CommandResponse, ProviderCommandHandler
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_NUT_COMMANDS = {"status", "devices", "battery"}
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class NutCommandHandler(ProviderCommandHandler):
"""Handles NUT-specific bot commands."""
@@ -33,80 +47,73 @@ class NutCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "devices":
ctx = await _cmd_devices(providers_map)
return _render_cmd_template(cmd_templates, "devices", locale, ctx)
if cmd == "battery":
ctx = await _cmd_battery(providers_map)
return _render_cmd_template(cmd_templates, "battery", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
async def _query_all_ups(
providers_map: dict[int, ServiceProvider],
async def _query_ups(
provider: ServiceProvider,
) -> list[dict[str, Any]]:
"""Connect to all NUT providers and query UPS data."""
"""Connect to a NUT provider and query UPS data."""
from notify_bridge_core.providers.nut.models import NutUpsData
results: list[dict[str, Any]] = []
for provider in providers_map.values():
if provider.type != "nut":
continue
nut = make_nut_provider(provider)
nut = make_nut_provider(provider)
try:
client = nut._make_client()
await client.connect()
try:
client = nut._make_client()
await client.connect()
try:
devices = await client.list_ups()
for dev in devices:
variables = await client.list_var(dev.name)
data = NutUpsData.from_variables(dev.name, variables)
results.append({
"name": data.name,
"description": data.description,
"model": data.model,
"manufacturer": data.manufacturer,
"status": data.status,
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
"battery_runtime": data.battery_runtime_formatted,
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
})
finally:
await client.disconnect()
except Exception as exc:
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
devices = await client.list_ups()
for dev in devices:
variables = await client.list_var(dev.name)
data = NutUpsData.from_variables(dev.name, variables)
results.append({
"name": data.name,
"description": data.description,
"model": data.model,
"manufacturer": data.manufacturer,
"status": data.status,
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
"battery_runtime": data.battery_runtime_formatted,
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
})
finally:
await client.disconnect()
except Exception as exc:
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
return results
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
devices = await _query_all_ups(providers_map)
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices}
async def _cmd_devices(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
@_text_cmd
async def _cmd_devices(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices: list[dict[str, Any]] = []
for provider in providers_map.values():
if provider.type != "nut":
continue
nut = make_nut_provider(provider)
try:
device_list = await nut.list_collections()
devices.extend(device_list)
except Exception as exc:
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
nut = make_nut_provider(provider)
try:
device_list = await nut.list_collections()
devices.extend(device_list)
except Exception as exc:
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
return {"devices": devices}
async def _cmd_battery(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
devices = await _query_all_ups(providers_map)
@_text_cmd
async def _cmd_battery(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices}
@@ -3,26 +3,47 @@
from __future__ import annotations
import logging
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import (
CommandConfig, CommandTracker, EventLog,
NotificationTracker, ServiceProvider, TelegramBot,
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
)
from ..services import make_planka_provider
from .base import ProviderCommandHandler
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_PLANKA_COMMANDS = {"status", "boards", "cards", "lists"}
def _get_tracked_board_ids(
provider: ServiceProvider,
trackers: list,
) -> list[str]:
"""Get board IDs from tracked collection_ids for this provider."""
if not provider.config.get("api_key"):
return []
return get_tracked_collection_ids(provider, trackers)
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class PlankaCommandHandler(ProviderCommandHandler):
"""Handles Planka-specific bot commands."""
@@ -43,69 +64,26 @@ class PlankaCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "boards":
ctx = await _cmd_boards(providers_map)
return _render_cmd_template(cmd_templates, "boards", locale, ctx)
if cmd == "cards":
ctx = await _cmd_cards(providers_map, count)
return _render_cmd_template(cmd_templates, "cards", locale, ctx)
if cmd == "lists":
ctx = await _cmd_lists(providers_map)
return _render_cmd_template(cmd_templates, "lists", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
def _get_tracked_board_ids(
providers_map: dict[int, ServiceProvider],
trackers: list[NotificationTracker],
) -> list[tuple[ServiceProvider, str]]:
"""Get (provider, board_id) tuples from tracked collection_ids."""
boards: list[tuple[ServiceProvider, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "planka":
continue
if not provider.config.get("api_key"):
continue
for board_id in (tracker.collection_ids or []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
# Also check filters.collections
for board_id in (tracker.filters or {}).get("collections", []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
return boards[:20]
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
# Last event
engine = get_engine()
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"boards_count": len(tracked_boards),
@@ -113,81 +91,69 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
}
async def _cmd_boards(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_boards(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
boards_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
all_boards = await planka.client.get_boards()
for b in all_boards:
if str(b.get("id", "")) == board_id:
boards_data.append({"name": b.get("name", board_id)})
break
else:
boards_data.append({"name": board_id})
http = await get_http_session()
planka = make_planka_provider(http, provider)
all_boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in all_boards}
for board_id in tracked_boards:
boards_data.append({"name": board_names.get(board_id, board_id)})
return {"boards": boards_data}
async def _cmd_cards(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_cards(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
all_cards: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
cards = await planka.client.get_board_cards(board_id, limit=count)
lists = await planka.client.get_board_lists(board_id)
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
http = await get_http_session()
planka = make_planka_provider(http, provider)
boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
boards = await planka.client.get_boards()
board_name = board_id
for b in boards:
if str(b.get("id", "")) == board_id:
board_name = b.get("name", board_id)
break
for board_id in tracked_boards:
cards = await planka.client.get_board_cards(board_id, limit=count)
lists = await planka.client.get_board_lists(board_id)
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
board_name = board_names.get(board_id, board_id)
for card in cards:
list_id = str(card.get("listId", ""))
all_cards.append({
"name": card.get("name", ""),
"list_name": lists_by_id.get(list_id, ""),
"board_name": board_name,
})
for card in cards:
list_id = str(card.get("listId", ""))
all_cards.append({
"name": card.get("name", ""),
"list_name": lists_by_id.get(list_id, ""),
"board_name": board_name,
})
return {"cards": all_cards[:count]}
async def _cmd_lists(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_lists(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
all_lists: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
lists = await planka.client.get_board_lists(board_id)
http = await get_http_session()
planka = make_planka_provider(http, provider)
boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
boards = await planka.client.get_boards()
board_name = board_id
for b in boards:
if str(b.get("id", "")) == board_id:
board_name = b.get("name", board_id)
break
for board_id in tracked_boards:
lists = await planka.client.get_board_lists(board_id)
board_name = board_names.get(board_id, board_id)
for lst in lists:
all_lists.append({
"name": lst.get("name", ""),
"board_name": board_name,
})
for lst in lists:
all_lists.append({
"name": lst.get("name", ""),
"board_name": board_name,
})
return {"lists": all_lists}
@@ -6,7 +6,6 @@ import hmac
import logging
from typing import Any
import aiohttp
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -16,6 +15,7 @@ from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..database.engine import get_session
from ..database.models import TelegramBot, TelegramChat
from ..services.telegram import save_chat_from_webhook
from .base import CommandResponse
from .handler import handle_command, send_media_group, send_reply
_LOGGER = logging.getLogger(__name__)
@@ -89,15 +89,13 @@ async def telegram_webhook(
return {"ok": True, "skipped": "commands_disabled"}
effective_lang = chat_row.language_override or msg_language
message_id = message.get("message_id")
cmd_response = await handle_command(bot, chat_id, text, language_code=effective_lang)
if cmd_response is not None:
if isinstance(cmd_response, dict) and "media" in cmd_response:
await send_reply(bot.token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
await send_media_group(bot.token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
elif isinstance(cmd_response, list):
await send_media_group(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
else:
await send_reply(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
if responses:
for resp in responses:
if resp.text:
await send_reply(bot.token, chat_id, resp.text, reply_to_message_id=message_id)
if resp.media:
await send_media_group(bot.token, chat_id, resp.media, reply_to_message_id=message_id)
return {"ok": True}
return {"ok": True, "skipped": "not_a_command"}
@@ -105,13 +103,15 @@ async def telegram_webhook(
async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
"""Register webhook URL with Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot_token)
return await client.set_webhook(webhook_url, secret=secret)
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
return await client.set_webhook(webhook_url, secret=secret)
async def unregister_webhook(bot_token: str) -> dict:
"""Remove webhook from Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot_token)
return await client.delete_webhook()
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
return await client.delete_webhook()