refactor: provider-agnostic bot command system + Gitea commands
Refactored the monolithic command handler (707 lines) into a pluggable provider-handler architecture: - Abstract ProviderCommandHandler interface (base.py) - Handler dispatch registry routes commands by provider type - Extracted all Immich logic into ImmichCommandHandler - New GiteaCommandHandler with /status, /repos, /issues, /prs, /commits - Multi-provider routing: groups context by provider type, finds handler - handler.py reduced to ~280 line thin orchestrator Gitea commands: - Extended GiteaClient with get_repo_issues, get_repo_pulls, get_repo_commits - 30 Jinja2 command templates (15 EN + 15 RU) - Gitea capabilities updated with 6 commands + 15 command_slots - Default command config + command template config seeded on startup - Rate limiting: Gitea API commands share "api" category (15s cooldown) Also: - Command configs API accepts "gitea" provider type - System command configs (user_id=0) visible to all users - Webhook URL shown on Gitea provider card and edit form - Scan interval hidden for webhook-based providers
This commit is contained in:
@@ -119,6 +119,8 @@
|
||||
"webhookSecretRequired": "Webhook secret is required",
|
||||
"apiToken": "API Token",
|
||||
"apiTokenHint": "Optional. Needed for connection testing and repository listing.",
|
||||
"webhookUrl": "Webhook URL",
|
||||
"webhookUrlHint": "Set this as the Target URL in Gitea webhook settings (relative to your bridge host).",
|
||||
"testAndSave": "Test & Save",
|
||||
"saveWithoutTest": "Save without testing"
|
||||
},
|
||||
|
||||
@@ -119,6 +119,8 @@
|
||||
"webhookSecretRequired": "Секрет вебхука обязателен",
|
||||
"apiToken": "API токен",
|
||||
"apiTokenHint": "Необязательно. Нужен для проверки подключения и получения списка репозиториев.",
|
||||
"webhookUrl": "URL вебхука",
|
||||
"webhookUrlHint": "Укажите этот URL в настройках вебхука Gitea (относительно хоста bridge).",
|
||||
"testAndSave": "Проверить и сохранить",
|
||||
"saveWithoutTest": "Сохранить без проверки"
|
||||
},
|
||||
|
||||
@@ -45,6 +45,7 @@
|
||||
}: Props = $props();
|
||||
|
||||
let isScheduler = $derived(providerType === 'scheduler');
|
||||
let isWebhook = $derived(providerType === 'gitea');
|
||||
|
||||
// Custom variable management for scheduler
|
||||
function addVariable() {
|
||||
@@ -162,10 +163,12 @@
|
||||
</fieldset>
|
||||
{:else}
|
||||
<div class="grid grid-cols-2 gap-3">
|
||||
{#if !isWebhook}
|
||||
<div>
|
||||
<label for="trk-interval" class="block text-sm font-medium mb-1">{t('notificationTracker.scanInterval')}<Hint text={t('hints.scanInterval')} /></label>
|
||||
<input id="trk-interval" type="number" bind:value={form.scan_interval} min="10" max="3600" class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" />
|
||||
</div>
|
||||
{/if}
|
||||
<div>
|
||||
<label for="trk-batch" class="block text-sm font-medium mb-1">{t('notificationTracker.batchDuration')}<Hint text={t('hints.batchDuration')} /></label>
|
||||
<input id="trk-batch" type="number" bind:value={form.batch_duration} min="0" max="3600" class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" />
|
||||
|
||||
@@ -166,6 +166,13 @@
|
||||
<input id="prv-token" bind:value={form.api_token} type="password" class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" />
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.apiTokenHint')}</p>
|
||||
</div>
|
||||
{#if editing}
|
||||
<div class="bg-[var(--color-muted)] rounded-md p-3">
|
||||
<label class="block text-sm font-medium mb-1">{t('providers.webhookUrl')}</label>
|
||||
<code class="text-xs select-all break-all">/api/webhooks/gitea/{editing}</code>
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
<button type="submit" disabled={submitting}
|
||||
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
|
||||
@@ -196,6 +203,9 @@
|
||||
{#if provider.config?.url}
|
||||
<a href={provider.config.url} target="_blank" rel="noopener" class="text-xs text-[var(--color-muted-foreground)] font-mono hover:text-[var(--color-primary)] hover:underline">{provider.config.url}</a>
|
||||
{/if}
|
||||
{#if provider.type === 'gitea'}
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] font-mono mt-0.5">{t('providers.webhookUrl')}: <span class="select-all">/api/webhooks/gitea/{provider.id}</span></p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-1">
|
||||
|
||||
@@ -150,7 +150,6 @@ GITEA_CAPABILITIES = ProviderCapabilities(
|
||||
{"name": "message_pr_commented", "description": "Comment on pull request"},
|
||||
{"name": "message_release_published", "description": "Release published"},
|
||||
],
|
||||
command_slots=[],
|
||||
events=[
|
||||
{"name": "push", "description": "Code pushed to repository"},
|
||||
{"name": "issue_opened", "description": "Issue opened"},
|
||||
@@ -162,7 +161,31 @@ GITEA_CAPABILITIES = ProviderCapabilities(
|
||||
{"name": "pr_commented", "description": "Comment on pull request"},
|
||||
{"name": "release_published", "description": "Release published"},
|
||||
],
|
||||
commands=[],
|
||||
command_slots=[
|
||||
{"name": "start", "description": "/start greeting message"},
|
||||
{"name": "help", "description": "/help command listing"},
|
||||
{"name": "status", "description": "/status tracker summary"},
|
||||
{"name": "repos", "description": "/repos tracked repositories"},
|
||||
{"name": "issues", "description": "/issues open issues"},
|
||||
{"name": "prs", "description": "/prs open pull requests"},
|
||||
{"name": "commits", "description": "/commits recent commits"},
|
||||
{"name": "rate_limited", "description": "Rate limit warning message"},
|
||||
{"name": "no_results", "description": "Empty results fallback"},
|
||||
{"name": "desc_help", "description": "Menu description for /help"},
|
||||
{"name": "desc_status", "description": "Menu description for /status"},
|
||||
{"name": "desc_repos", "description": "Menu description for /repos"},
|
||||
{"name": "desc_issues", "description": "Menu description for /issues"},
|
||||
{"name": "desc_prs", "description": "Menu description for /prs"},
|
||||
{"name": "desc_commits", "description": "Menu description for /commits"},
|
||||
],
|
||||
commands=[
|
||||
{"name": "status", "description": "Show tracker status"},
|
||||
{"name": "repos", "description": "List tracked repositories"},
|
||||
{"name": "issues", "description": "Recent open issues"},
|
||||
{"name": "prs", "description": "Open pull requests"},
|
||||
{"name": "commits", "description": "Recent commits"},
|
||||
{"name": "help", "description": "Show commands"},
|
||||
],
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -85,5 +85,57 @@ class GiteaClient:
|
||||
return repos
|
||||
|
||||
|
||||
async def get_repo_issues(
|
||||
self, owner: str, repo: str, state: str = "open", limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch issues for a repository."""
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/v1/repos/{owner}/{repo}/issues",
|
||||
headers=self._headers,
|
||||
params={"type": "issues", "state": state, "limit": str(limit)},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
_LOGGER.warning("Failed to fetch issues for %s/%s: HTTP %s", owner, repo, response.status)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, err)
|
||||
return []
|
||||
|
||||
async def get_repo_pulls(
|
||||
self, owner: str, repo: str, state: str = "open", limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch pull requests for a repository."""
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/v1/repos/{owner}/{repo}/pulls",
|
||||
headers=self._headers,
|
||||
params={"state": state, "limit": str(limit)},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
_LOGGER.warning("Failed to fetch PRs for %s/%s: HTTP %s", owner, repo, response.status)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, err)
|
||||
return []
|
||||
|
||||
async def get_repo_commits(
|
||||
self, owner: str, repo: str, limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch recent commits for a repository."""
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/v1/repos/{owner}/{repo}/commits",
|
||||
headers=self._headers,
|
||||
params={"limit": str(limit)},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
_LOGGER.warning("Failed to fetch commits for %s/%s: HTTP %s", owner, repo, response.status)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, err)
|
||||
return []
|
||||
|
||||
|
||||
class GiteaApiError(Exception):
|
||||
"""Raised when a Gitea API call fails."""
|
||||
|
||||
+7
@@ -0,0 +1,7 @@
|
||||
📝 <b>Recent Commits</b>
|
||||
{%- for c in commits %}
|
||||
• <b>{{ c.repo }}</b> <code>{{ c.short_id }}</code>: {{ c.message }} ({{ c.author }})
|
||||
{%- endfor %}
|
||||
{%- if not commits %}
|
||||
No recent commits found.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
Recent commits
|
||||
+1
@@ -0,0 +1 @@
|
||||
Show available commands
|
||||
+1
@@ -0,0 +1 @@
|
||||
Recent open issues
|
||||
+1
@@ -0,0 +1 @@
|
||||
Open pull requests
|
||||
+1
@@ -0,0 +1 @@
|
||||
List tracked repositories
|
||||
+1
@@ -0,0 +1 @@
|
||||
Show tracker status
|
||||
@@ -0,0 +1,4 @@
|
||||
📋 <b>Available commands:</b>
|
||||
{%- for cmd in commands %}
|
||||
/{{ cmd.name }} — {{ cmd.description }}
|
||||
{%- endfor %}
|
||||
@@ -0,0 +1,7 @@
|
||||
🐛 <b>Open Issues</b>
|
||||
{%- for issue in issues %}
|
||||
• <b>{{ issue.repo }}</b> <a href="{{ issue.url }}">#{{ issue.number }}</a>: {{ issue.title }} ({{ issue.user }})
|
||||
{%- endfor %}
|
||||
{%- if not issues %}
|
||||
No open issues found.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
No results found.
|
||||
@@ -0,0 +1,7 @@
|
||||
🔀 <b>Open Pull Requests</b>
|
||||
{%- for pr in prs %}
|
||||
• <b>{{ pr.repo }}</b> <a href="{{ pr.url }}">#{{ pr.number }}</a>: {{ pr.title }} ({{ pr.user }})
|
||||
{%- endfor %}
|
||||
{%- if not prs %}
|
||||
No open pull requests found.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
⏳ Please wait {{ wait }}s before using this command again.
|
||||
@@ -0,0 +1,7 @@
|
||||
📦 <b>Tracked Repositories</b>
|
||||
{%- for repo in repos %}
|
||||
• <b>{{ repo.full_name }}</b>{% if repo.description %} — {{ repo.description }}{% endif %}
|
||||
{%- endfor %}
|
||||
{%- if not repos %}
|
||||
No repositories tracked.
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,2 @@
|
||||
👋 Hi! I'm your Notify Bridge bot for <b>Gitea</b>.
|
||||
Use /help to see available commands.
|
||||
@@ -0,0 +1,4 @@
|
||||
📊 <b>Gitea Status</b>
|
||||
Repositories tracked: {{ repos_count }}
|
||||
Server: Gitea v{{ server_version }}
|
||||
Last event: {{ last_event }}
|
||||
@@ -7,41 +7,71 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULTS_DIR = Path(__file__).parent
|
||||
|
||||
# Response template slot names (file stem = slot name)
|
||||
# Per-provider slot names
|
||||
PROVIDER_COMMAND_SLOTS: dict[str, list[str]] = {
|
||||
"immich": [
|
||||
# Response templates
|
||||
"start", "help", "status", "albums", "events", "people",
|
||||
"search", "latest", "favorites", "random", "summary", "memory",
|
||||
"rate_limited", "no_results",
|
||||
# Description slots
|
||||
"desc_status", "desc_albums", "desc_events", "desc_summary",
|
||||
"desc_latest", "desc_memory", "desc_random", "desc_search",
|
||||
"desc_find", "desc_person", "desc_place", "desc_favorites",
|
||||
"desc_people", "desc_help",
|
||||
],
|
||||
"gitea": [
|
||||
# Response templates
|
||||
"start", "help", "status", "repos", "issues", "prs", "commits",
|
||||
"rate_limited", "no_results",
|
||||
# Description slots
|
||||
"desc_help", "desc_status", "desc_repos", "desc_issues",
|
||||
"desc_prs", "desc_commits",
|
||||
],
|
||||
}
|
||||
|
||||
# Backward-compatible aliases
|
||||
COMMAND_SLOT_NAMES = [
|
||||
"start", "help", "status", "albums", "events", "people",
|
||||
"search", "latest", "favorites", "random", "summary", "memory",
|
||||
"rate_limited", "no_results",
|
||||
]
|
||||
|
||||
# Description slots for Telegram command menu (desc_{cmd} -> short text)
|
||||
COMMAND_DESC_SLOT_NAMES = [
|
||||
"desc_status", "desc_albums", "desc_events", "desc_summary",
|
||||
"desc_latest", "desc_memory", "desc_random", "desc_search",
|
||||
"desc_find", "desc_person", "desc_place", "desc_favorites",
|
||||
"desc_people", "desc_help",
|
||||
]
|
||||
|
||||
# All slot names (response + description)
|
||||
ALL_SLOT_NAMES = COMMAND_SLOT_NAMES + COMMAND_DESC_SLOT_NAMES
|
||||
|
||||
|
||||
def load_default_command_templates(locale: str = "en") -> dict[str, str]:
|
||||
"""Load default command template strings for a locale.
|
||||
def load_default_command_templates(
|
||||
locale: str = "en",
|
||||
provider_type: str = "immich",
|
||||
) -> dict[str, str]:
|
||||
"""Load default command template strings for a locale and provider type.
|
||||
|
||||
For "immich", templates are in {locale}/ (root, backward compat).
|
||||
For other providers, templates are in {locale}/{provider_type}/.
|
||||
|
||||
Returns dict mapping slot_name -> template string.
|
||||
"""
|
||||
locale_dir = _DEFAULTS_DIR / locale
|
||||
if provider_type == "immich":
|
||||
locale_dir = _DEFAULTS_DIR / locale
|
||||
else:
|
||||
locale_dir = _DEFAULTS_DIR / locale / provider_type
|
||||
|
||||
if not locale_dir.is_dir():
|
||||
_LOGGER.warning("No default command templates for locale '%s'", locale)
|
||||
_LOGGER.warning("No default command templates for locale '%s' provider '%s'", locale, provider_type)
|
||||
return {}
|
||||
|
||||
slot_names = PROVIDER_COMMAND_SLOTS.get(provider_type, ALL_SLOT_NAMES)
|
||||
templates: dict[str, str] = {}
|
||||
for slot_name in ALL_SLOT_NAMES:
|
||||
for slot_name in slot_names:
|
||||
filepath = locale_dir / f"{slot_name}.jinja2"
|
||||
if filepath.exists():
|
||||
templates[slot_name] = filepath.read_text(encoding="utf-8").strip()
|
||||
else:
|
||||
_LOGGER.debug("Missing default command template: %s/%s.jinja2", locale, slot_name)
|
||||
_LOGGER.debug("Missing default command template: %s/%s.jinja2", locale_dir.name, slot_name)
|
||||
|
||||
return templates
|
||||
|
||||
+7
@@ -0,0 +1,7 @@
|
||||
📝 <b>Последние коммиты</b>
|
||||
{%- for c in commits %}
|
||||
• <b>{{ c.repo }}</b> <code>{{ c.short_id }}</code>: {{ c.message }} ({{ c.author }})
|
||||
{%- endfor %}
|
||||
{%- if not commits %}
|
||||
Коммитов не найдено.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
Последние коммиты
|
||||
+1
@@ -0,0 +1 @@
|
||||
Показать доступные команды
|
||||
+1
@@ -0,0 +1 @@
|
||||
Открытые задачи
|
||||
+1
@@ -0,0 +1 @@
|
||||
Открытые пулл-реквесты
|
||||
+1
@@ -0,0 +1 @@
|
||||
Список репозиториев
|
||||
+1
@@ -0,0 +1 @@
|
||||
Статус трекера
|
||||
@@ -0,0 +1,4 @@
|
||||
📋 <b>Доступные команды:</b>
|
||||
{%- for cmd in commands %}
|
||||
/{{ cmd.name }} — {{ cmd.description }}
|
||||
{%- endfor %}
|
||||
@@ -0,0 +1,7 @@
|
||||
🐛 <b>Открытые задачи</b>
|
||||
{%- for issue in issues %}
|
||||
• <b>{{ issue.repo }}</b> <a href="{{ issue.url }}">#{{ issue.number }}</a>: {{ issue.title }} ({{ issue.user }})
|
||||
{%- endfor %}
|
||||
{%- if not issues %}
|
||||
Открытых задач не найдено.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
Результатов не найдено.
|
||||
@@ -0,0 +1,7 @@
|
||||
🔀 <b>Открытые пулл-реквесты</b>
|
||||
{%- for pr in prs %}
|
||||
• <b>{{ pr.repo }}</b> <a href="{{ pr.url }}">#{{ pr.number }}</a>: {{ pr.title }} ({{ pr.user }})
|
||||
{%- endfor %}
|
||||
{%- if not prs %}
|
||||
Открытых пулл-реквестов не найдено.
|
||||
{%- endif %}
|
||||
+1
@@ -0,0 +1 @@
|
||||
⏳ Подождите {{ wait }} сек. перед повторным использованием команды.
|
||||
@@ -0,0 +1,7 @@
|
||||
📦 <b>Отслеживаемые репозитории</b>
|
||||
{%- for repo in repos %}
|
||||
• <b>{{ repo.full_name }}</b>{% if repo.description %} — {{ repo.description }}{% endif %}
|
||||
{%- endfor %}
|
||||
{%- if not repos %}
|
||||
Репозитории не отслеживаются.
|
||||
{%- endif %}
|
||||
@@ -0,0 +1,2 @@
|
||||
👋 Привет! Я ваш бот Notify Bridge для <b>Gitea</b>.
|
||||
Используйте /help для списка команд.
|
||||
@@ -0,0 +1,4 @@
|
||||
📊 <b>Статус Gitea</b>
|
||||
Отслеживаемые репозитории: {{ repos_count }}
|
||||
Сервер: Gitea v{{ server_version }}
|
||||
Последнее событие: {{ last_event }}
|
||||
@@ -43,9 +43,12 @@ async def list_command_configs(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all command configs for the current user."""
|
||||
"""List all command configs for the current user (including system defaults)."""
|
||||
from sqlmodel import or_
|
||||
result = await session.exec(
|
||||
select(CommandConfig).where(CommandConfig.user_id == user.id)
|
||||
select(CommandConfig).where(
|
||||
or_(CommandConfig.user_id == user.id, CommandConfig.user_id == 0)
|
||||
)
|
||||
)
|
||||
return [_config_response(c) for c in result.all()]
|
||||
|
||||
@@ -58,7 +61,7 @@ async def create_command_config(
|
||||
):
|
||||
"""Create a new command config."""
|
||||
# Validate provider_type
|
||||
valid_types = ("immich",)
|
||||
valid_types = ("immich", "gitea")
|
||||
if body.provider_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -159,7 +162,7 @@ async def _get_user_config(
|
||||
session: AsyncSession, config_id: int, user_id: int
|
||||
) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, config_id)
|
||||
if not config or config.user_id != user_id:
|
||||
if not config or (config.user_id != user_id and config.user_id != 0):
|
||||
raise HTTPException(status_code=404, detail="Command config not found")
|
||||
return config
|
||||
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
"""Abstract provider command handler interface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot
|
||||
|
||||
|
||||
class ProviderCommandHandler(ABC):
|
||||
"""Base class for provider-specific bot command handlers.
|
||||
|
||||
Each provider (Immich, Gitea, etc.) implements this interface to handle
|
||||
its own set of commands. The dispatch layer routes commands to the
|
||||
correct handler based on the provider type.
|
||||
"""
|
||||
|
||||
provider_type: str
|
||||
|
||||
@abstractmethod
|
||||
def get_provider_commands(self) -> set[str]:
|
||||
"""Return the set of command names this handler owns.
|
||||
|
||||
These are provider-specific commands (e.g., 'albums' for Immich,
|
||||
'repos' for Gitea). Universal commands like 'help' and 'start'
|
||||
are handled by the main dispatcher.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
cmd: str,
|
||||
args: str,
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
"""Handle a provider-specific command.
|
||||
|
||||
Args:
|
||||
cmd: The command name (without '/').
|
||||
args: Arguments after the command.
|
||||
count: Number of results to return.
|
||||
locale: User's locale ('en', 'ru').
|
||||
response_mode: 'media' or 'text'.
|
||||
providers_map: Provider instances keyed by ID.
|
||||
cmd_templates: Template slots {slot_name: {locale: template}}.
|
||||
bot: The Telegram bot instance.
|
||||
ctx_tuples: Command context tuples for this provider type.
|
||||
|
||||
Returns:
|
||||
Text response, media list, or None if unhandled.
|
||||
"""
|
||||
|
||||
def get_rate_categories(self) -> dict[str, str]:
|
||||
"""Return rate limit category mapping for this provider's commands.
|
||||
|
||||
Keys are command names, values are category strings.
|
||||
Commands not listed default to 'default' category.
|
||||
"""
|
||||
return {}
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Provider command handler registry and routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from .base import ProviderCommandHandler
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_REGISTRY: dict[str, ProviderCommandHandler] = {}
|
||||
|
||||
|
||||
def register_handler(handler: ProviderCommandHandler) -> None:
|
||||
"""Register a provider command handler."""
|
||||
_REGISTRY[handler.provider_type] = handler
|
||||
_LOGGER.debug("Registered command handler for provider type: %s", handler.provider_type)
|
||||
|
||||
|
||||
def get_handler(provider_type: str) -> ProviderCommandHandler | None:
|
||||
"""Get the command handler for a provider type."""
|
||||
return _REGISTRY.get(provider_type)
|
||||
|
||||
|
||||
def get_all_handlers() -> dict[str, ProviderCommandHandler]:
|
||||
"""Get all registered handlers."""
|
||||
return dict(_REGISTRY)
|
||||
|
||||
|
||||
def _auto_register() -> None:
|
||||
"""Auto-register all built-in handlers."""
|
||||
from .immich_handler import ImmichCommandHandler
|
||||
from .gitea_handler import GiteaCommandHandler
|
||||
|
||||
register_handler(ImmichCommandHandler())
|
||||
register_handler(GiteaCommandHandler())
|
||||
|
||||
|
||||
# Auto-register on import
|
||||
_auto_register()
|
||||
@@ -0,0 +1,252 @@
|
||||
"""Gitea-specific bot command handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
CommandConfig, CommandTracker, EventLog,
|
||||
NotificationTracker, ServiceProvider, TelegramBot,
|
||||
)
|
||||
from ..services import make_gitea_provider
|
||||
from .base import ProviderCommandHandler
|
||||
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"}
|
||||
|
||||
|
||||
class GiteaCommandHandler(ProviderCommandHandler):
|
||||
"""Handles Gitea-specific bot commands."""
|
||||
|
||||
provider_type = "gitea"
|
||||
|
||||
def get_provider_commands(self) -> set[str]:
|
||||
return _GITEA_COMMANDS
|
||||
|
||||
def get_rate_categories(self) -> dict[str, str]:
|
||||
return {
|
||||
"repos": "api", "issues": "api",
|
||||
"prs": "api", "commits": "api",
|
||||
}
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
cmd: str,
|
||||
args: str,
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
if cmd == "repos":
|
||||
ctx = await _cmd_repos(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "repos", locale, ctx)
|
||||
if cmd == "issues":
|
||||
ctx = await _cmd_issues(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "issues", locale, ctx)
|
||||
if cmd == "prs":
|
||||
ctx = await _cmd_prs(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "prs", locale, ctx)
|
||||
if cmd == "commits":
|
||||
ctx = await _cmd_commits(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "commits", locale, ctx)
|
||||
return None
|
||||
|
||||
|
||||
def _get_tracked_repos(
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
trackers: list[NotificationTracker],
|
||||
) -> list[tuple[ServiceProvider, str, str]]:
|
||||
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
|
||||
repos: list[tuple[ServiceProvider, str, str]] = []
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "gitea":
|
||||
continue
|
||||
if not provider.config.get("api_token"):
|
||||
continue
|
||||
for full_name in (tracker.collection_ids or []):
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
repos.append((provider, parts[0], parts[1]))
|
||||
# Also check filters.collections
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "gitea":
|
||||
continue
|
||||
if not provider.config.get("api_token"):
|
||||
continue
|
||||
for full_name in (tracker.filters or {}).get("collections", []):
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
entry = (provider, parts[0], parts[1])
|
||||
if entry not in repos:
|
||||
repos.append(entry)
|
||||
return repos[:20] # Cap to prevent API hammering
|
||||
|
||||
|
||||
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
# Get server version from first Gitea provider with token
|
||||
server_version = "unknown"
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider in providers_map.values():
|
||||
if provider.type == "gitea" and provider.config.get("api_token"):
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
version = await gitea.client.get_server_version()
|
||||
if version:
|
||||
server_version = version
|
||||
break
|
||||
|
||||
# Last event
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if tracker_ids:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
else:
|
||||
last_event = None
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
|
||||
return {
|
||||
"repos_count": len(tracked_repos),
|
||||
"server_version": server_version,
|
||||
"last_event": last_str,
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
repos_data: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
try:
|
||||
all_repos = await gitea.client.get_repos(limit=50)
|
||||
for r in all_repos:
|
||||
if r.get("full_name") == f"{owner}/{repo}":
|
||||
repos_data.append({
|
||||
"full_name": r.get("full_name", ""),
|
||||
"description": r.get("description", ""),
|
||||
"stars": r.get("stars_count", 0),
|
||||
"url": r.get("html_url", ""),
|
||||
})
|
||||
break
|
||||
else:
|
||||
repos_data.append({
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
})
|
||||
except Exception:
|
||||
repos_data.append({
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "?",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
})
|
||||
|
||||
return {"repos": repos_data}
|
||||
|
||||
|
||||
async def _cmd_issues(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
all_issues: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
issues = await gitea.client.get_repo_issues(owner, repo, limit=count)
|
||||
for issue in issues:
|
||||
all_issues.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": issue.get("number", 0),
|
||||
"title": issue.get("title", ""),
|
||||
"url": issue.get("html_url", ""),
|
||||
"user": issue.get("user", {}).get("login", ""),
|
||||
"state": issue.get("state", ""),
|
||||
})
|
||||
|
||||
all_issues.sort(key=lambda i: i.get("number", 0), reverse=True)
|
||||
return {"issues": all_issues[:count]}
|
||||
|
||||
|
||||
async def _cmd_prs(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
all_prs: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
prs = await gitea.client.get_repo_pulls(owner, repo, limit=count)
|
||||
for pr in prs:
|
||||
all_prs.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": pr.get("number", 0),
|
||||
"title": pr.get("title", ""),
|
||||
"url": pr.get("html_url", ""),
|
||||
"user": pr.get("user", {}).get("login", ""),
|
||||
"state": pr.get("state", ""),
|
||||
})
|
||||
|
||||
all_prs.sort(key=lambda p: p.get("number", 0), reverse=True)
|
||||
return {"prs": all_prs[:count]}
|
||||
|
||||
|
||||
async def _cmd_commits(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
all_commits: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
commits = await gitea.client.get_repo_commits(owner, repo, limit=count)
|
||||
for c in commits:
|
||||
commit_data = c.get("commit", {})
|
||||
all_commits.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"short_id": c.get("sha", "")[:7],
|
||||
"message": commit_data.get("message", "").split("\n")[0][:80],
|
||||
"author": commit_data.get("author", {}).get("name", ""),
|
||||
"url": c.get("html_url", ""),
|
||||
})
|
||||
|
||||
return {"commits": all_commits[:count]}
|
||||
@@ -1,11 +1,9 @@
|
||||
"""Telegram bot command handler — implements all /commands."""
|
||||
"""Telegram bot command handler — provider-agnostic dispatcher."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random as rng
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -14,7 +12,6 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.notifications.telegram.media import TELEGRAM_API_BASE_URL
|
||||
from ..database.engine import get_engine
|
||||
from ..services import make_immich_provider
|
||||
from ..database.models import (
|
||||
CommandConfig,
|
||||
CommandTemplateConfig,
|
||||
@@ -22,12 +19,9 @@ from ..database.models import (
|
||||
CommandTracker,
|
||||
CommandTrackerListener,
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
TrackingConfig,
|
||||
)
|
||||
from .parser import parse_command
|
||||
from .registry import get_rate_category
|
||||
@@ -90,7 +84,6 @@ async def _resolve_command_context(
|
||||
"""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
# Find all listeners for this bot
|
||||
result = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.listener_type == "telegram_bot",
|
||||
@@ -115,19 +108,20 @@ async def _resolve_command_context(
|
||||
continue
|
||||
tuples.append((tracker, config, provider))
|
||||
|
||||
# Load command template slots from the first config that has one
|
||||
# Load command template slots — merge from all configs
|
||||
cmd_template_slots: dict[str, dict[str, str]] = {}
|
||||
seen_config_ids: set[int] = set()
|
||||
for _, config, _ in tuples:
|
||||
if config.command_template_config_id:
|
||||
cfg_id = config.command_template_config_id
|
||||
if cfg_id and cfg_id not in seen_config_ids:
|
||||
seen_config_ids.add(cfg_id)
|
||||
slot_result = await session.exec(
|
||||
select(CommandTemplateSlot).where(
|
||||
CommandTemplateSlot.config_id == config.command_template_config_id
|
||||
CommandTemplateSlot.config_id == cfg_id
|
||||
)
|
||||
)
|
||||
for s in slot_result.all():
|
||||
cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
if cmd_template_slots:
|
||||
break
|
||||
|
||||
return tuples, cmd_template_slots
|
||||
|
||||
@@ -135,19 +129,14 @@ async def _resolve_command_context(
|
||||
def _merge_command_context(
|
||||
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> tuple[list[str], str, int, dict[str, Any]]:
|
||||
"""Merge enabled_commands from all configs and pick defaults from first config.
|
||||
|
||||
Returns (enabled_commands, response_mode, default_count, rate_limits).
|
||||
"""
|
||||
"""Merge enabled_commands from all configs and pick defaults from first config."""
|
||||
if not ctx:
|
||||
return [], "media", 5, {}
|
||||
|
||||
# Union of all enabled commands across configs
|
||||
enabled: set[str] = set()
|
||||
for _, config, _ in ctx:
|
||||
enabled.update(config.enabled_commands or [])
|
||||
|
||||
# Use first config's settings as defaults
|
||||
first_config = ctx[0][1]
|
||||
response_mode = first_config.response_mode or "media"
|
||||
default_count = first_config.default_count or 5
|
||||
@@ -162,10 +151,9 @@ async def handle_command(
|
||||
text: str,
|
||||
language_code: str = "",
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
"""Handle a bot command. Returns text response, media list, or None.
|
||||
"""Handle a bot command. Routes to provider-specific handlers.
|
||||
|
||||
language_code is the Telegram user's language (from message.from.language_code).
|
||||
Used to pick the right locale for template rendering.
|
||||
Returns text response, media list, or None.
|
||||
"""
|
||||
cmd, args, count_override = parse_command(text)
|
||||
if not cmd:
|
||||
@@ -174,10 +162,7 @@ async def handle_command(
|
||||
ctx_tuples, cmd_templates = await _resolve_command_context(bot)
|
||||
enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples)
|
||||
|
||||
# Derive locale from Telegram user language, falling back to "en"
|
||||
locale = language_code[:2].lower() if language_code else "en"
|
||||
# Only use locale if we actually have templates for it, otherwise fall back
|
||||
# (_render_cmd_template handles per-slot fallback, but let's normalize)
|
||||
if locale not in ("en", "ru"):
|
||||
locale = "en"
|
||||
|
||||
@@ -185,7 +170,7 @@ async def handle_command(
|
||||
return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name})
|
||||
|
||||
if cmd not in enabled and cmd != "start":
|
||||
return None # Silently ignore disabled commands
|
||||
return None
|
||||
|
||||
# Rate limit check
|
||||
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
||||
@@ -199,24 +184,34 @@ async def handle_command(
|
||||
for _, _, provider in ctx_tuples:
|
||||
providers_map[provider.id] = provider
|
||||
|
||||
# Dispatch — each handler returns template context dict
|
||||
# Universal commands
|
||||
if cmd == "help":
|
||||
ctx = _cmd_help(enabled, locale, cmd_templates)
|
||||
elif cmd == "status":
|
||||
ctx = await _cmd_status(bot, providers_map, locale)
|
||||
elif cmd == "albums":
|
||||
ctx = await _cmd_albums(bot, providers_map, locale)
|
||||
elif cmd == "events":
|
||||
ctx = await _cmd_events(bot, providers_map, count, locale)
|
||||
elif cmd == "people":
|
||||
ctx = await _cmd_people(providers_map, locale)
|
||||
elif cmd in ("search", "find", "person", "place", "latest", "random",
|
||||
"favorites", "summary", "memory"):
|
||||
return await _cmd_immich(bot, cmd, args, count, locale, response_mode, providers_map, cmd_templates)
|
||||
else:
|
||||
return None
|
||||
return _render_cmd_template(cmd_templates, "help", locale, ctx)
|
||||
|
||||
return _render_cmd_template(cmd_templates, cmd, locale, {**ctx})
|
||||
# Provider-specific dispatch
|
||||
from .dispatch import get_handler
|
||||
|
||||
# Group ctx_tuples by provider type
|
||||
by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {}
|
||||
for t in ctx_tuples:
|
||||
ptype = t[2].type
|
||||
by_type.setdefault(ptype, []).append(t)
|
||||
|
||||
# Find which handler claims this command
|
||||
for ptype, ptuples in by_type.items():
|
||||
handler = get_handler(ptype)
|
||||
if handler and cmd in handler.get_provider_commands():
|
||||
# Build provider map filtered to this provider type
|
||||
pmap = {p.id: p for _, _, p in ptuples}
|
||||
result = await handler.handle(
|
||||
cmd, args, count, locale, response_mode,
|
||||
pmap, cmd_templates, bot, ptuples,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _cmd_help(
|
||||
@@ -245,325 +240,6 @@ async def _get_notification_trackers_for_providers(
|
||||
return list(result.all())
|
||||
|
||||
|
||||
async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
"""Check if any tracker-target linked to this bot uses native memory source."""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(NotificationTarget).where(
|
||||
NotificationTarget.type == "telegram",
|
||||
NotificationTarget.user_id == bot.user_id,
|
||||
)
|
||||
)
|
||||
targets = result.all()
|
||||
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
|
||||
if not bot_target_ids:
|
||||
return False
|
||||
tt_result = await session.exec(
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.target_id.in_(bot_target_ids))
|
||||
)
|
||||
for tt in tt_result.all():
|
||||
if tt.tracking_config_id:
|
||||
tc = await session.get(TrackingConfig, tt.tracking_config_id)
|
||||
if tc and tc.memory_source == "native":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _cmd_status(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
active = sum(1 for t in trackers if t.enabled)
|
||||
total = len(trackers)
|
||||
total_albums = sum(len(t.collection_ids or []) for t in trackers)
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
|
||||
return {"trackers_active": active, "trackers_total": total, "total_albums": total_albums, "last_event": last_str}
|
||||
|
||||
|
||||
async def _cmd_albums(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
if not trackers:
|
||||
return {"albums": []}
|
||||
|
||||
albums_data: list[dict] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
for album_id in (tracker.collection_ids or []):
|
||||
try:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums_data.append({"name": album.name, "asset_count": album.asset_count, "id": album_id})
|
||||
except Exception:
|
||||
albums_data.append({"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id})
|
||||
|
||||
return {"albums": albums_data}
|
||||
|
||||
|
||||
async def _cmd_events(bot: TelegramBot, providers_map: dict[int, ServiceProvider], count: int, locale: str) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if not tracker_ids:
|
||||
return {"events": []}
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc())
|
||||
.limit(count)
|
||||
)
|
||||
events = result.all()
|
||||
|
||||
events_data = [{"type": e.event_type, "album": e.collection_name, "count": e.assets_count,
|
||||
"date": e.created_at.strftime("%m/%d %H:%M")} for e in events]
|
||||
|
||||
return {"events": events_data}
|
||||
|
||||
|
||||
async def _cmd_people(providers_map: dict[int, ServiceProvider], locale: str) -> dict[str, Any]:
|
||||
all_people: dict[str, str] = {}
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider in providers_map.values():
|
||||
if provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
people = await immich.client.get_people()
|
||||
all_people.update(people)
|
||||
|
||||
names = sorted(all_people.values())
|
||||
return {"people": names}
|
||||
|
||||
|
||||
async def _cmd_immich(
|
||||
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
|
||||
response_mode: str, providers_map: dict[int, ServiceProvider],
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Handle commands that need Immich API access and may return media."""
|
||||
if not providers_map:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
# Get notification trackers for album data
|
||||
provider_ids = set(providers_map.keys())
|
||||
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
|
||||
all_album_ids: list[str] = []
|
||||
for t in notification_trackers:
|
||||
all_album_ids.extend(t.collection_ids or [])
|
||||
|
||||
# Pick the first immich provider
|
||||
provider: ServiceProvider | None = None
|
||||
for p in providers_map.values():
|
||||
if p.type == "immich":
|
||||
provider = p
|
||||
break
|
||||
if not provider:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
immich = make_immich_provider(http, provider)
|
||||
client = immich.client
|
||||
|
||||
if cmd == "search":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
|
||||
assets = await client.search_smart(args, album_ids=all_album_ids, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "find":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
|
||||
assets = await client.search_metadata(args, album_ids=all_album_ids, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "person":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
|
||||
people = await client.get_people()
|
||||
person_id = None
|
||||
for pid, pname in people.items():
|
||||
if args.lower() in pname.lower():
|
||||
person_id = pid
|
||||
break
|
||||
if not person_id:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": args})
|
||||
assets = await client.search_by_person(person_id, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "place":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
|
||||
assets = await client.search_smart(
|
||||
f"photos taken in {args}", album_ids=all_album_ids, limit=count
|
||||
)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "favorites":
|
||||
fav_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in list(album.assets.items())[:50]:
|
||||
if asset.is_favorite and len(fav_assets) < count:
|
||||
fav_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
if len(fav_assets) >= count:
|
||||
break
|
||||
return _format_assets(fav_assets, cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "latest":
|
||||
latest_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in list(album.assets.items())[:count]:
|
||||
latest_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type, "createdAt": asset.created_at,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
latest_assets.sort(key=lambda a: a.get("createdAt", ""), reverse=True)
|
||||
return _format_assets(latest_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "random":
|
||||
random_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
asset_list = list(album.assets.values())
|
||||
sampled = rng.sample(asset_list, min(count, len(asset_list)))
|
||||
for asset in sampled:
|
||||
random_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
rng.shuffle(random_assets)
|
||||
return _format_assets(random_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "summary":
|
||||
albums_data: list[dict] = []
|
||||
for album_id in all_album_ids:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
albums_data.append({"name": album.name, "asset_count": album.asset_count, "id": album_id})
|
||||
except Exception:
|
||||
pass
|
||||
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
|
||||
|
||||
if cmd == "memory":
|
||||
# Check if any linked tracking config uses native memories
|
||||
use_native = await _check_native_memory(bot)
|
||||
|
||||
today = datetime.now(timezone.utc)
|
||||
memory_assets: list[dict[str, Any]] = []
|
||||
|
||||
if use_native:
|
||||
# Use Immich native memories API
|
||||
memories = await client.get_memories()
|
||||
tracked_ids = set(all_album_ids) if all_album_ids else None
|
||||
for mem in memories:
|
||||
year = mem.get("data", {}).get("year")
|
||||
for raw_asset in mem.get("assets", []):
|
||||
if tracked_ids:
|
||||
asset_albums = raw_asset.get("albums", [])
|
||||
if not any(a.get("id") in tracked_ids for a in asset_albums):
|
||||
continue
|
||||
memory_assets.append({
|
||||
"id": raw_asset.get("id", ""),
|
||||
"originalFileName": raw_asset.get("originalFileName", ""),
|
||||
"type": raw_asset.get("type", "IMAGE"),
|
||||
"createdAt": raw_asset.get("fileCreatedAt", raw_asset.get("createdAt", "")),
|
||||
"year": year,
|
||||
})
|
||||
else:
|
||||
# Album-scanning fallback
|
||||
month_day = (today.month, today.day)
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in album.assets.items():
|
||||
try:
|
||||
dt = datetime.fromisoformat(asset.created_at.replace("Z", "+00:00"))
|
||||
if (dt.month, dt.day) == month_day and dt.year != today.year:
|
||||
memory_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type, "createdAt": asset.created_at,
|
||||
"year": dt.year,
|
||||
})
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
memory_assets = memory_assets[:count]
|
||||
if not memory_assets:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "memory", "query": ""})
|
||||
return _format_assets(memory_assets, cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_assets(
|
||||
assets: list[dict[str, Any]], cmd: str, query: str,
|
||||
locale: str, response_mode: str, client: Any,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Format asset results as text or media payload."""
|
||||
if not assets:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
|
||||
|
||||
if response_mode == "media":
|
||||
media_items = []
|
||||
for asset in assets:
|
||||
asset_id = asset.get("id", "")
|
||||
filename = asset.get("originalFileName", "")
|
||||
year = asset.get("year", "")
|
||||
caption = f"{filename} ({year})" if year else filename
|
||||
media_items.append({
|
||||
"type": "photo",
|
||||
"asset_id": asset_id,
|
||||
"caption": caption,
|
||||
"thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview",
|
||||
"api_key": client.api_key,
|
||||
})
|
||||
return media_items
|
||||
|
||||
# Text mode — render via template
|
||||
slot_map = {"find": "search", "person": "search", "place": "search"}
|
||||
slot_name = slot_map.get(cmd, cmd)
|
||||
return _render_cmd_template(cmd_templates, slot_name, locale, {
|
||||
"assets": assets, "query": query, "command": cmd, "count": len(assets),
|
||||
})
|
||||
|
||||
|
||||
async def send_reply(bot_token: str, chat_id: str, text: str) -> None:
|
||||
"""Send a text reply via Telegram Bot API, retrying without HTML on parse failure."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
@@ -586,17 +262,12 @@ async def send_reply(bot_token: str, chat_id: str, text: str) -> None:
|
||||
async def send_media_group(
|
||||
bot_token: str, chat_id: str, media_items: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Send media items as a Telegram media group (album).
|
||||
|
||||
Falls back to individual sendPhoto calls if sendMediaGroup fails.
|
||||
Telegram allows max 10 items per media group.
|
||||
"""
|
||||
"""Send media items as a Telegram media group (album)."""
|
||||
if not media_items:
|
||||
return
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
# Download all thumbnails first
|
||||
downloaded: list[tuple[bytes, str, str]] = [] # (photo_bytes, asset_id, caption)
|
||||
downloaded: list[tuple[bytes, str, str]] = []
|
||||
for item in media_items:
|
||||
asset_id = item.get("asset_id", "")
|
||||
caption = item.get("caption", "")
|
||||
@@ -615,12 +286,10 @@ async def send_media_group(
|
||||
if not downloaded:
|
||||
return
|
||||
|
||||
# Send in groups of 10 (Telegram limit)
|
||||
for i in range(0, len(downloaded), 10):
|
||||
chunk = downloaded[i:i + 10]
|
||||
|
||||
if len(chunk) == 1:
|
||||
# Single photo — use sendPhoto
|
||||
photo_bytes, asset_id, caption = chunk[0]
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("chat_id", chat_id)
|
||||
@@ -635,7 +304,6 @@ async def send_media_group(
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to send photo: %s", err)
|
||||
else:
|
||||
# Multiple photos — use sendMediaGroup
|
||||
import json as _json
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("chat_id", chat_id)
|
||||
@@ -658,12 +326,7 @@ async def send_media_group(
|
||||
|
||||
|
||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
"""Register enabled commands with Telegram BotFather API.
|
||||
|
||||
Registers all supported locales explicitly with language_code,
|
||||
plus English as the default fallback (no language_code).
|
||||
Descriptions are read from desc_* template slots.
|
||||
"""
|
||||
"""Register enabled commands with Telegram BotFather API."""
|
||||
ctx_tuples, templates = await _resolve_command_context(bot)
|
||||
enabled, _, _, _ = _merge_command_context(ctx_tuples)
|
||||
|
||||
@@ -677,7 +340,6 @@ async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
|
||||
commands.append({"command": cmd, "description": desc})
|
||||
|
||||
# Register with explicit language_code
|
||||
payload: dict[str, Any] = {"commands": commands, "language_code": locale}
|
||||
try:
|
||||
async with http.post(url, json=payload) as resp:
|
||||
@@ -689,7 +351,6 @@ async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.error("Failed to register commands for locale '%s': %s", locale, err)
|
||||
|
||||
# Also register English as the default (no language_code) for unsupported langs
|
||||
en_commands = []
|
||||
for cmd in enabled:
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
|
||||
|
||||
@@ -0,0 +1,414 @@
|
||||
"""Immich-specific bot command handler."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random as rng
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
CommandConfig, CommandTracker, NotificationTarget,
|
||||
NotificationTracker, NotificationTrackerTarget,
|
||||
ServiceProvider, TelegramBot, TrackingConfig,
|
||||
)
|
||||
from ..services import make_immich_provider
|
||||
from .base import ProviderCommandHandler
|
||||
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_IMMICH_COMMANDS = {
|
||||
"status", "albums", "events", "people",
|
||||
"search", "find", "person", "place",
|
||||
"latest", "random", "favorites", "summary", "memory",
|
||||
}
|
||||
|
||||
|
||||
class ImmichCommandHandler(ProviderCommandHandler):
|
||||
"""Handles all Immich-specific bot commands."""
|
||||
|
||||
provider_type = "immich"
|
||||
|
||||
def get_provider_commands(self) -> set[str]:
|
||||
return _IMMICH_COMMANDS
|
||||
|
||||
def get_rate_categories(self) -> dict[str, str]:
|
||||
return {
|
||||
"search": "search", "find": "search", "person": "search",
|
||||
"place": "search", "favorites": "search", "people": "search",
|
||||
}
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
cmd: str,
|
||||
args: str,
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(bot, providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
if cmd == "albums":
|
||||
ctx = await _cmd_albums(bot, providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
|
||||
if cmd == "events":
|
||||
ctx = await _cmd_events(bot, providers_map, count, locale)
|
||||
return _render_cmd_template(cmd_templates, "events", locale, ctx)
|
||||
if cmd == "people":
|
||||
ctx = await _cmd_people(providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "people", locale, ctx)
|
||||
if cmd in ("search", "find", "person", "place", "latest",
|
||||
"random", "favorites", "summary", "memory"):
|
||||
return await _cmd_immich(
|
||||
bot, cmd, args, count, locale, response_mode,
|
||||
providers_map, cmd_templates,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# --- Immich command implementations (moved from handler.py) ---
|
||||
|
||||
|
||||
async def _cmd_status(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
|
||||
) -> dict[str, Any]:
|
||||
from ..database.models import EventLog
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
active = sum(1 for t in trackers if t.enabled)
|
||||
total = len(trackers)
|
||||
total_albums = sum(len(t.collection_ids or []) for t in trackers)
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
|
||||
return {
|
||||
"trackers_active": active, "trackers_total": total,
|
||||
"total_albums": total_albums, "last_event": last_str,
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_albums(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
if not trackers:
|
||||
return {"albums": []}
|
||||
|
||||
albums_data: list[dict] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
for album_id in (tracker.collection_ids or []):
|
||||
try:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums_data.append({
|
||||
"name": album.name, "asset_count": album.asset_count, "id": album_id,
|
||||
})
|
||||
except Exception:
|
||||
albums_data.append({
|
||||
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
|
||||
})
|
||||
|
||||
return {"albums": albums_data}
|
||||
|
||||
|
||||
async def _cmd_events(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
|
||||
count: int, locale: str,
|
||||
) -> dict[str, Any]:
|
||||
from ..database.models import EventLog
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if not tracker_ids:
|
||||
return {"events": []}
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc())
|
||||
.limit(count)
|
||||
)
|
||||
events = result.all()
|
||||
|
||||
events_data = [
|
||||
{"type": e.event_type, "album": e.collection_name,
|
||||
"count": e.assets_count, "date": e.created_at.strftime("%m/%d %H:%M")}
|
||||
for e in events
|
||||
]
|
||||
return {"events": events_data}
|
||||
|
||||
|
||||
async def _cmd_people(
|
||||
providers_map: dict[int, ServiceProvider], locale: str,
|
||||
) -> dict[str, Any]:
|
||||
all_people: dict[str, str] = {}
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider in providers_map.values():
|
||||
if provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
people = await immich.client.get_people()
|
||||
all_people.update(people)
|
||||
names = sorted(all_people.values())
|
||||
return {"people": names}
|
||||
|
||||
|
||||
async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
"""Check if any tracker-target linked to this bot uses native memory source."""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(NotificationTarget).where(
|
||||
NotificationTarget.type == "telegram",
|
||||
NotificationTarget.user_id == bot.user_id,
|
||||
)
|
||||
)
|
||||
targets = result.all()
|
||||
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
|
||||
if not bot_target_ids:
|
||||
return False
|
||||
tt_result = await session.exec(
|
||||
select(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.target_id.in_(bot_target_ids)
|
||||
)
|
||||
)
|
||||
for tt in tt_result.all():
|
||||
if tt.tracking_config_id:
|
||||
tc = await session.get(TrackingConfig, tt.tracking_config_id)
|
||||
if tc and tc.memory_source == "native":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _cmd_immich(
|
||||
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
|
||||
response_mode: str, providers_map: dict[int, ServiceProvider],
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Handle commands that need Immich API access and may return media."""
|
||||
if not providers_map:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
provider_ids = set(providers_map.keys())
|
||||
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
|
||||
all_album_ids: list[str] = []
|
||||
for t in notification_trackers:
|
||||
all_album_ids.extend(t.collection_ids or [])
|
||||
|
||||
provider: ServiceProvider | None = None
|
||||
for p in providers_map.values():
|
||||
if p.type == "immich":
|
||||
provider = p
|
||||
break
|
||||
if not provider:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
immich = make_immich_provider(http, provider)
|
||||
client = immich.client
|
||||
|
||||
if cmd == "search":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
|
||||
assets = await client.search_smart(args, album_ids=all_album_ids, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "find":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
|
||||
assets = await client.search_metadata(args, album_ids=all_album_ids, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "person":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
|
||||
people = await client.get_people()
|
||||
person_id = None
|
||||
for pid, pname in people.items():
|
||||
if args.lower() in pname.lower():
|
||||
person_id = pid
|
||||
break
|
||||
if not person_id:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": args})
|
||||
assets = await client.search_by_person(person_id, limit=count)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "place":
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
|
||||
assets = await client.search_smart(
|
||||
f"photos taken in {args}", album_ids=all_album_ids, limit=count
|
||||
)
|
||||
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "favorites":
|
||||
fav_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in list(album.assets.items())[:50]:
|
||||
if asset.is_favorite and len(fav_assets) < count:
|
||||
fav_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
if len(fav_assets) >= count:
|
||||
break
|
||||
return _format_assets(fav_assets, cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "latest":
|
||||
latest_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in list(album.assets.items())[:count]:
|
||||
latest_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type, "createdAt": asset.created_at,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
latest_assets.sort(key=lambda a: a.get("createdAt", ""), reverse=True)
|
||||
return _format_assets(latest_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "random":
|
||||
random_assets: list[dict[str, Any]] = []
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
asset_list = list(album.assets.values())
|
||||
sampled = rng.sample(asset_list, min(count, len(asset_list)))
|
||||
for asset in sampled:
|
||||
random_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
rng.shuffle(random_assets)
|
||||
return _format_assets(random_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "summary":
|
||||
albums_data: list[dict] = []
|
||||
for album_id in all_album_ids:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
albums_data.append({
|
||||
"name": album.name, "asset_count": album.asset_count, "id": album_id,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
|
||||
|
||||
if cmd == "memory":
|
||||
use_native = await _check_native_memory(bot)
|
||||
today = datetime.now(timezone.utc)
|
||||
memory_assets: list[dict[str, Any]] = []
|
||||
|
||||
if use_native:
|
||||
memories = await client.get_memories()
|
||||
tracked_ids = set(all_album_ids) if all_album_ids else None
|
||||
for mem in memories:
|
||||
year = mem.get("data", {}).get("year")
|
||||
for raw_asset in mem.get("assets", []):
|
||||
if tracked_ids:
|
||||
asset_albums = raw_asset.get("albums", [])
|
||||
if not any(a.get("id") in tracked_ids for a in asset_albums):
|
||||
continue
|
||||
memory_assets.append({
|
||||
"id": raw_asset.get("id", ""),
|
||||
"originalFileName": raw_asset.get("originalFileName", ""),
|
||||
"type": raw_asset.get("type", "IMAGE"),
|
||||
"createdAt": raw_asset.get("fileCreatedAt", raw_asset.get("createdAt", "")),
|
||||
"year": year,
|
||||
})
|
||||
else:
|
||||
month_day = (today.month, today.day)
|
||||
for album_id in all_album_ids[:10]:
|
||||
try:
|
||||
album = await client.get_album(album_id)
|
||||
if album:
|
||||
for aid, asset in album.assets.items():
|
||||
try:
|
||||
dt = datetime.fromisoformat(asset.created_at.replace("Z", "+00:00"))
|
||||
if (dt.month, dt.day) == month_day and dt.year != today.year:
|
||||
memory_assets.append({
|
||||
"id": asset.id, "originalFileName": asset.filename,
|
||||
"type": asset.type, "createdAt": asset.created_at,
|
||||
"year": dt.year,
|
||||
})
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
memory_assets = memory_assets[:count]
|
||||
if not memory_assets:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "memory", "query": ""})
|
||||
return _format_assets(memory_assets, cmd, "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _format_assets(
|
||||
assets: list[dict[str, Any]], cmd: str, query: str,
|
||||
locale: str, response_mode: str, client: Any,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Format asset results as text or media payload."""
|
||||
if not assets:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
|
||||
|
||||
if response_mode == "media":
|
||||
media_items = []
|
||||
for asset in assets:
|
||||
asset_id = asset.get("id", "")
|
||||
filename = asset.get("originalFileName", "")
|
||||
year = asset.get("year", "")
|
||||
caption = f"{filename} ({year})" if year else filename
|
||||
media_items.append({
|
||||
"type": "photo",
|
||||
"asset_id": asset_id,
|
||||
"caption": caption,
|
||||
"thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview",
|
||||
"api_key": client.api_key,
|
||||
})
|
||||
return media_items
|
||||
|
||||
slot_map = {"find": "search", "person": "search", "place": "search"}
|
||||
slot_name = slot_map.get(cmd, cmd)
|
||||
return _render_cmd_template(cmd_templates, slot_name, locale, {
|
||||
"assets": assets, "query": query, "command": cmd, "count": len(assets),
|
||||
})
|
||||
@@ -4,8 +4,11 @@ from __future__ import annotations
|
||||
|
||||
# Map commands to rate limit categories
|
||||
_RATE_CATEGORY: dict[str, str] = {
|
||||
# Immich
|
||||
"search": "search", "find": "search", "person": "search",
|
||||
"place": "search", "favorites": "search", "people": "search",
|
||||
# Gitea (API calls share a category)
|
||||
"repos": "api", "issues": "api", "prs": "api", "commits": "api",
|
||||
}
|
||||
|
||||
|
||||
@@ -17,5 +20,5 @@ DEFAULT_COMMANDS_CONFIG = {
|
||||
"enabled": ["help", "status", "albums", "events", "latest", "random", "favorites", "summary", "memory"],
|
||||
"response_mode": "media",
|
||||
"default_count": 5,
|
||||
"rate_limits": {"search": 30, "default": 10},
|
||||
"rate_limits": {"search": 30, "api": 15, "default": 10},
|
||||
}
|
||||
|
||||
@@ -53,6 +53,7 @@ async def lifespan(app: FastAPI):
|
||||
await _seed_default_templates()
|
||||
await _seed_default_command_templates()
|
||||
await _seed_default_tracking_configs()
|
||||
await _seed_default_command_configs()
|
||||
# Configure webhook secret from DB setting (falls back to env var)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession as _AS
|
||||
from .api.app_settings import get_setting as _get_setting
|
||||
@@ -393,7 +394,7 @@ async def _seed_default_command_templates():
|
||||
|
||||
# Upsert slots for each locale
|
||||
for locale in ("en", "ru"):
|
||||
slots = load_default_command_templates(locale)
|
||||
slots = load_default_command_templates(locale, provider_type="immich")
|
||||
if not slots:
|
||||
continue
|
||||
for slot_name, template_text in slots.items():
|
||||
@@ -416,6 +417,51 @@ async def _seed_default_command_templates():
|
||||
template=template_text,
|
||||
))
|
||||
|
||||
# --- Seed Gitea default command templates ---
|
||||
gitea_cmd_result = await session.exec(
|
||||
select(CommandTemplateConfig).where(
|
||||
CommandTemplateConfig.user_id == 0,
|
||||
CommandTemplateConfig.provider_type == "gitea",
|
||||
)
|
||||
)
|
||||
gitea_cmd_configs = gitea_cmd_result.all()
|
||||
|
||||
if not gitea_cmd_configs:
|
||||
gitea_cmd_config = CommandTemplateConfig(
|
||||
user_id=0,
|
||||
provider_type="gitea",
|
||||
name="Default Gitea Commands",
|
||||
description="Default Gitea command templates",
|
||||
)
|
||||
session.add(gitea_cmd_config)
|
||||
await session.flush()
|
||||
else:
|
||||
gitea_cmd_config = gitea_cmd_configs[0]
|
||||
|
||||
for locale in ("en", "ru"):
|
||||
gitea_cmd_slots = load_default_command_templates(locale, provider_type="gitea")
|
||||
if not gitea_cmd_slots:
|
||||
continue
|
||||
for slot_name, template_text in gitea_cmd_slots.items():
|
||||
slot_result = await session.exec(
|
||||
select(CommandTemplateSlot).where(
|
||||
CommandTemplateSlot.config_id == gitea_cmd_config.id,
|
||||
CommandTemplateSlot.slot_name == slot_name,
|
||||
CommandTemplateSlot.locale == locale,
|
||||
)
|
||||
)
|
||||
existing = slot_result.first()
|
||||
if existing:
|
||||
existing.template = template_text
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(CommandTemplateSlot(
|
||||
config_id=gitea_cmd_config.id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -464,6 +510,85 @@ async def _seed_default_tracking_configs():
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def _seed_default_command_configs():
|
||||
"""Seed system-owned default command configs for each provider type."""
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from .database.engine import get_engine
|
||||
from .database.models import CommandConfig, CommandTemplateConfig
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
# Find existing system-owned command configs
|
||||
result = await session.exec(
|
||||
select(CommandConfig).where(CommandConfig.user_id == 0)
|
||||
)
|
||||
existing = {c.provider_type: c for c in result.all()}
|
||||
|
||||
# Find system command template configs to link
|
||||
tmpl_result = await session.exec(
|
||||
select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == 0)
|
||||
)
|
||||
tmpl_by_type = {t.provider_type: t.id for t in tmpl_result.all()}
|
||||
|
||||
defaults = [
|
||||
{
|
||||
"provider_type": "immich",
|
||||
"name": "Default Immich",
|
||||
"enabled_commands": [
|
||||
"help", "status", "albums", "events", "latest",
|
||||
"random", "favorites", "summary", "memory",
|
||||
],
|
||||
"response_mode": "media",
|
||||
"default_count": 5,
|
||||
"rate_limits": {"search": 30, "default": 10},
|
||||
},
|
||||
{
|
||||
"provider_type": "gitea",
|
||||
"name": "Default Gitea",
|
||||
"enabled_commands": [
|
||||
"help", "status", "repos", "issues", "prs", "commits",
|
||||
],
|
||||
"response_mode": "text",
|
||||
"default_count": 10,
|
||||
"rate_limits": {"api": 15, "default": 10},
|
||||
},
|
||||
]
|
||||
|
||||
for cfg in defaults:
|
||||
ptype = cfg["provider_type"]
|
||||
if ptype in existing:
|
||||
continue
|
||||
cmd_tmpl_id = tmpl_by_type.get(ptype)
|
||||
# Use raw SQL to handle legacy NOT NULL columns
|
||||
import json as _json2
|
||||
from sqlalchemy import text as _text2
|
||||
from datetime import datetime as _dt3, timezone as _tz3
|
||||
await session.execute(
|
||||
_text2(
|
||||
"INSERT INTO command_config "
|
||||
"(user_id, provider_type, name, icon, enabled_commands, locale, "
|
||||
"response_mode, default_count, rate_limits, command_template_config_id, created_at) "
|
||||
"VALUES (:uid, :pt, :name, :icon, :cmds, :locale, :rm, :dc, :rl, :ctid, :ca)"
|
||||
),
|
||||
{
|
||||
"uid": 0,
|
||||
"pt": ptype,
|
||||
"name": cfg["name"],
|
||||
"icon": "",
|
||||
"cmds": _json2.dumps(cfg["enabled_commands"]),
|
||||
"locale": "en",
|
||||
"rm": cfg["response_mode"],
|
||||
"dc": cfg["default_count"],
|
||||
"rl": _json2.dumps(cfg["rate_limits"]),
|
||||
"ctid": cmd_tmpl_id,
|
||||
"ca": _dt3.now(_tz3.utc).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
def run():
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8420)
|
||||
|
||||
Reference in New Issue
Block a user