refactor: comprehensive codebase review — security, performance, quality, UX

Security:
- Fix NUT protocol command injection (validate names against safe regex)
- Enable Jinja2 autoescape=True to prevent HTML injection via external data
- Add WebhookProviderConfig validation model

Performance:
- Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions)
- Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast)
- asyncio.gather for Gitea commands and notification dispatcher
- Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener
- LRU cache for compiled Jinja2 templates
- Daily EventLog cleanup job (90-day retention)
- 30s HTTP timeout on all external calls
- GROUP BY for target type counts (replaces 7 sequential queries)

Code quality:
- Extract get_owned_entity() helper (replaces 11 duplicate functions)
- Extract slot_helpers.py (load_slots, save_slots, render_template_preview)
- Extract command_utils.py (tracker lookup, last event, collection IDs)
- Extract http_session.py (shared session lifecycle)
- Provider connection validation dedup (3x → 1 helper)
- Command dispatch tables replacing if/elif chains
- Album+links fetch helper (fetch_albums_with_links)
- Provider dispatch polymorphism (list_provider_collections)
- Immutable _enrich_assets (no longer mutates in-place)
- Fix _format_assets return type + handler unpacking

Frontend:
- Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru)
- Mobile "More" nav panel with provider filter and search
- Shared Button.svelte component (4 variants, 2 sizes)
- Shared ErrorBanner.svelte component (8 pages updated)
- SvelteKit goto() replacing window.location.href
- Dashboard grid fixed for 4 cards, paginator opacity consistency

Functionality:
- max_instances=1 on scheduler jobs (prevents duplicate events)
- Webhook provider in watcher (prevents error spam)
- Fix stale SQLModel reference in poller
- Gitea get_repo() direct API call
This commit is contained in:
2026-03-28 13:22:26 +03:00
parent 616b221c92
commit b803d004e1
65 changed files with 1934 additions and 1498 deletions
+34 -1
View File
@@ -43,6 +43,17 @@ jobs:
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
- name: Trigger Portainer redeploy
continue-on-error: true
run: |
if [ -n "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" ]; then
echo "Triggering Portainer redeploy..."
curl -sf -X POST "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" \
--max-time 30 || echo "::warning::Portainer webhook failed"
else
echo "DOCKER_REDEPLOY_WEBHOOK_URL not set — skipping auto-deploy"
fi
- name: Generate changelog - name: Generate changelog
id: changelog id: changelog
run: | run: |
@@ -56,7 +67,29 @@ jobs:
- name: Create Gitea Release - name: Create Gitea Release
run: | run: |
BODY=$(cat /tmp/changelog.txt | python3 -c "import sys,json; print(json.dumps(sys.stdin.read()))") if [ -f RELEASE_NOTES.md ]; then
export RELEASE_NOTES=$(cat RELEASE_NOTES.md)
echo "Found RELEASE_NOTES.md"
else
export RELEASE_NOTES=""
echo "No RELEASE_NOTES.md found"
fi
BODY=$(python3 -c "
import json, os, sys
release_notes = os.environ.get('RELEASE_NOTES', '')
changelog = open('/tmp/changelog.txt').read().strip()
sections = []
if release_notes.strip():
sections.append(release_notes.strip())
if changelog:
sections.append('## Changelog\n\n' + changelog)
print(json.dumps('\n\n'.join(sections)))
")
curl -s -X POST \ curl -s -X POST \
"https://${{ env.REGISTRY }}/api/v1/repos/${{ env.IMAGE_NAME }}/releases" \ "https://${{ env.REGISTRY }}/api/v1/repos/${{ env.IMAGE_NAME }}/releases" \
-H "Authorization: token ${{ secrets.RELEASE_TOKEN }}" \ -H "Authorization: token ${{ secrets.RELEASE_TOKEN }}" \
+85
View File
@@ -0,0 +1,85 @@
<script lang="ts">
import type { Snippet } from 'svelte';
let {
variant = 'primary',
size = 'md',
disabled = false,
type = 'button',
href,
onclick,
children,
class: extraClass = '',
}: {
variant?: 'primary' | 'secondary' | 'danger' | 'ghost';
size?: 'sm' | 'md';
disabled?: boolean;
type?: 'button' | 'submit';
href?: string;
onclick?: (e: MouseEvent) => void;
children: Snippet;
class?: string;
} = $props();
const baseClasses = 'inline-flex items-center justify-center gap-1.5 rounded-md text-sm font-medium transition-colors disabled:opacity-50';
const sizeClasses: Record<string, string> = {
sm: 'px-2.5 py-1 text-xs',
md: 'px-4 py-2',
};
const variantClasses: Record<string, string> = {
primary: 'btn-primary',
secondary: 'btn-secondary',
danger: 'btn-danger',
ghost: 'btn-ghost',
};
const classes = $derived(
`${baseClasses} ${sizeClasses[size]} ${variantClasses[variant]} ${extraClass}`.trim()
);
</script>
{#if href && !disabled}
<a {href} class={classes} onclick={onclick}>
{@render children()}
</a>
{:else}
<button {type} {disabled} class={classes} onclick={onclick}>
{@render children()}
</button>
{/if}
<style>
.btn-primary {
background: var(--color-primary);
color: var(--color-primary-foreground);
}
.btn-primary:hover:not(:disabled) {
opacity: 0.9;
}
.btn-secondary {
background: var(--color-muted);
color: var(--color-foreground);
border: 1px solid var(--color-border);
}
.btn-secondary:hover:not(:disabled) {
opacity: 0.8;
}
.btn-danger {
background: var(--color-error-fg);
color: white;
}
.btn-danger:hover:not(:disabled) {
opacity: 0.9;
}
.btn-ghost {
background: transparent;
color: var(--color-muted-foreground);
}
.btn-ghost:hover:not(:disabled) {
background: var(--color-muted);
color: var(--color-foreground);
}
</style>
@@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import MdiIcon from './MdiIcon.svelte'; import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface EntityItem { export interface EntityItem {
value: string | number; value: string | number;
@@ -142,7 +143,7 @@
<div class="ep-list" bind:this={listEl} role="listbox"> <div class="ep-list" bind:this={listEl} role="listbox">
{#if filtered.length === 0} {#if filtered.length === 0}
<div class="ep-empty">No matches</div> <div class="ep-empty">{t('common.noMatches')}</div>
{:else} {:else}
{#each filtered as item, i} {#each filtered as item, i}
<button <button
@@ -0,0 +1,13 @@
<script lang="ts">
interface Props {
message: string;
class?: string;
}
let { message, class: className = '' }: Props = $props();
</script>
{#if message}
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4 {className}">
{message}
</div>
{/if}
@@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import MdiIcon from './MdiIcon.svelte'; import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface GridItem { export interface GridItem {
value: string | number; value: string | number;
@@ -117,7 +118,7 @@
</button> </button>
{/each} {/each}
{#if filtered.length === 0} {#if filtered.length === 0}
<div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">No matches</div> <div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">{t('common.noMatches')}</div>
{/if} {/if}
</div> </div>
</div> </div>
+2 -1
View File
@@ -1,6 +1,7 @@
<script lang="ts"> <script lang="ts">
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import MdiIcon from './MdiIcon.svelte'; import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
let { open = false, title = '', onclose, children } = $props<{ let { open = false, title = '', onclose, children } = $props<{
open: boolean; open: boolean;
@@ -93,7 +94,7 @@
> >
<div style="display: flex; align-items: center; justify-content: space-between; padding: 1.5rem 1.5rem 1rem;"> <div style="display: flex; align-items: center; justify-content: space-between; padding: 1.5rem 1.5rem 1rem;">
<h3 id="modal-title-{uniqueId}" style="font-size: 1.125rem; font-weight: 600;">{title}</h3> <h3 id="modal-title-{uniqueId}" style="font-size: 1.125rem; font-weight: 600;">{title}</h3>
<button class="modal-close" onclick={onclose} aria-label="Close"> <button class="modal-close" onclick={onclose} aria-label={t('common.close')}>
<MdiIcon name="mdiClose" size={18} /> <MdiIcon name="mdiClose" size={18} />
</button> </button>
</div> </div>
@@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import MdiIcon from './MdiIcon.svelte'; import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface MultiEntityItem { export interface MultiEntityItem {
value: string; value: string;
@@ -132,7 +133,7 @@
<div class="mes-list" bind:this={listEl} role="listbox"> <div class="mes-list" bind:this={listEl} role="listbox">
{#if filtered.length === 0} {#if filtered.length === 0}
<div class="mes-empty">No matches</div> <div class="mes-empty">{t('common.noMatches')}</div>
{:else} {:else}
{#each filtered as item, i} {#each filtered as item, i}
{@const checked = (values || []).includes(item.value)} {@const checked = (values || []).includes(item.value)}
+1 -1
View File
@@ -56,7 +56,7 @@
{/if} {/if}
{/if} {/if}
</div> </div>
<button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label="Dismiss"> <button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label={t('common.dismiss')}>
<MdiIcon name="mdiClose" size={14} /> <MdiIcon name="mdiClose" size={14} />
</button> </button>
</div> </div>
+2
View File
@@ -110,6 +110,8 @@ export const previewTargetTypeItems = (): GridItem[] => [
{ value: 'email', icon: 'mdiEmailOutline', label: 'Email', desc: t('gridDesc.previewEmail') }, { value: 'email', icon: 'mdiEmailOutline', label: 'Email', desc: t('gridDesc.previewEmail') },
{ value: 'discord', icon: 'mdiChat', label: 'Discord', desc: t('gridDesc.previewDiscord') }, { value: 'discord', icon: 'mdiChat', label: 'Discord', desc: t('gridDesc.previewDiscord') },
{ value: 'slack', icon: 'mdiSlack', label: 'Slack', desc: t('gridDesc.previewSlack') }, { value: 'slack', icon: 'mdiSlack', label: 'Slack', desc: t('gridDesc.previewSlack') },
{ value: 'ntfy', icon: 'mdiBellOutline', label: 'ntfy', desc: t('gridDesc.previewNtfy') },
{ value: 'matrix', icon: 'mdiMatrix', label: 'Matrix', desc: t('gridDesc.previewMatrix') },
]; ];
// --- Provider type items (derived from descriptor registry) --- // --- Provider type items (derived from descriptor registry) ---
+28 -10
View File
@@ -36,7 +36,8 @@
"targetMatrix": "Matrix", "targetMatrix": "Matrix",
"targetBroadcast": "Broadcast", "targetBroadcast": "Broadcast",
"automation": "Automation", "automation": "Automation",
"actions": "Actions" "actions": "Actions",
"more": "More"
}, },
"auth": { "auth": {
"signIn": "Sign in", "signIn": "Sign in",
@@ -51,7 +52,9 @@
"creatingAccount": "Creating account...", "creatingAccount": "Creating account...",
"passwordMismatch": "Passwords do not match", "passwordMismatch": "Passwords do not match",
"passwordTooShort": "Password must be at least 8 characters", "passwordTooShort": "Password must be at least 8 characters",
"or": "or" "or": "or",
"loginFailed": "Login failed",
"setupFailed": "Setup failed"
}, },
"dashboard": { "dashboard": {
"title": "Dashboard", "title": "Dashboard",
@@ -150,7 +153,9 @@
"gpRefreshTokenHint": "Obtain from Google OAuth Playground (developers.google.com/oauthplayground) with the Photos Library API scope.", "gpRefreshTokenHint": "Obtain from Google OAuth Playground (developers.google.com/oauthplayground) with the Photos Library API scope.",
"gpAllFieldsRequired": "Client ID, Client Secret, and Refresh Token are all required", "gpAllFieldsRequired": "Client ID, Client Secret, and Refresh Token are all required",
"testAndSave": "Test & Save", "testAndSave": "Test & Save",
"saveWithoutTest": "Save without testing" "saveWithoutTest": "Save without testing",
"selectType": "Select a provider type",
"testFailed": "Connection test failed"
}, },
"notificationTracker": { "notificationTracker": {
"title": "Notification Trackers", "title": "Notification Trackers",
@@ -231,7 +236,8 @@
"noLink": "No Link", "noLink": "No Link",
"saveWithoutLinks": "Save without links", "saveWithoutLinks": "Save without links",
"createLinks": "Create {count} link(s)", "createLinks": "Create {count} link(s)",
"linksNote": "You can also create links manually in Immich." "linksNote": "You can also create links manually in Immich.",
"createdLinks": "Created {count} public link(s)"
}, },
"templates": { "templates": {
"title": "Templates", "title": "Templates",
@@ -409,7 +415,9 @@
"cacheTtl": "Media cache TTL (hours)", "cacheTtl": "Media cache TTL (hours)",
"cacheTtlHint": "How long to cache uploaded Telegram file_ids before re-uploading (default: 48h)", "cacheTtlHint": "How long to cache uploaded Telegram file_ids before re-uploading (default: 48h)",
"settingsSaved": "Settings saved", "settingsSaved": "Settings saved",
"noExternalDomain": "External domain URL not configured" "noExternalDomain": "External domain URL not configured",
"saveFailed": "Failed to save bot",
"webhookFailed": "Failed to register webhook"
}, },
"trackingConfig": { "trackingConfig": {
"title": "Tracking Configs", "title": "Tracking Configs",
@@ -584,7 +592,7 @@
"added_assets": "List of asset dicts (use {% for asset in added_assets %})", "added_assets": "List of asset dicts (use {% for asset in added_assets %})",
"removed_assets": "List of removed asset IDs (strings)", "removed_assets": "List of removed asset IDs (strings)",
"shared": "Whether album is shared (boolean)", "shared": "Whether album is shared (boolean)",
"target_type": "Target type: 'telegram' or 'webhook'", "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
"has_videos": "Whether added assets contain videos (boolean)", "has_videos": "Whether added assets contain videos (boolean)",
"has_photos": "Whether added assets contain photos (boolean)", "has_photos": "Whether added assets contain photos (boolean)",
"old_name": "Previous album name (rename events)", "old_name": "Previous album name (rename events)",
@@ -675,7 +683,8 @@
"displayName": "Display Name", "displayName": "Display Name",
"testConnection": "Test connection", "testConnection": "Test connection",
"noBots": "No Matrix bots yet.", "noBots": "No Matrix bots yet.",
"confirmDelete": "Delete this Matrix bot?" "confirmDelete": "Delete this Matrix bot?",
"operationFailed": "Operation failed"
}, },
"emailBot": { "emailBot": {
"title": "Email Bots", "title": "Email Bots",
@@ -693,7 +702,8 @@
"useTls": "Use TLS/SSL", "useTls": "Use TLS/SSL",
"testConnection": "Send test email", "testConnection": "Send test email",
"noBots": "No email bots yet.", "noBots": "No email bots yet.",
"confirmDelete": "Delete this email bot?" "confirmDelete": "Delete this email bot?",
"operationFailed": "Operation failed"
}, },
"cmdTemplateConfig": { "cmdTemplateConfig": {
"title": "Command Templates", "title": "Command Templates",
@@ -841,7 +851,12 @@
"allTypes": "All types", "allTypes": "All types",
"allProviders": "All providers", "allProviders": "All providers",
"noFilterResults": "No items match the current filter.", "noFilterResults": "No items match the current filter.",
"redirecting": "Redirecting..." "redirecting": "Redirecting...",
"noMatches": "No matches",
"saveFailed": "Save failed",
"loadFailed": "Failed to load data",
"dismiss": "Dismiss",
"systemSuffix": " (System)"
}, },
"templateSlot": { "templateSlot": {
"message_assets_added": "New assets added to album", "message_assets_added": "New assets added to album",
@@ -926,12 +941,15 @@
"previewEmail": "Preview with email HTML format", "previewEmail": "Preview with email HTML format",
"previewDiscord": "Preview with Discord markdown", "previewDiscord": "Preview with Discord markdown",
"previewSlack": "Preview with Slack markdown", "previewSlack": "Preview with Slack markdown",
"previewNtfy": "Preview as ntfy notification",
"previewMatrix": "Preview with Matrix HTML format",
"providerImmich": "Self-hosted photo server", "providerImmich": "Self-hosted photo server",
"providerGitea": "Self-hosted Git service", "providerGitea": "Self-hosted Git service",
"providerPlanka": "Self-hosted Kanban board", "providerPlanka": "Self-hosted Kanban board",
"providerScheduler": "Time-based scheduled messages", "providerScheduler": "Time-based scheduled messages",
"providerNut": "Network UPS monitoring", "providerNut": "Network UPS monitoring",
"providerGooglePhotos": "Google Photos albums & shared libraries" "providerGooglePhotos": "Google Photos albums & shared libraries",
"providerWebhook": "Receive events via HTTP POST"
}, },
"error": { "error": {
"notFound": "Page not found", "notFound": "Page not found",
+28 -10
View File
@@ -36,7 +36,8 @@
"targetMatrix": "Matrix", "targetMatrix": "Matrix",
"targetBroadcast": "Рассылка", "targetBroadcast": "Рассылка",
"automation": "Автоматизация", "automation": "Автоматизация",
"actions": "Действия" "actions": "Действия",
"more": "Ещё"
}, },
"auth": { "auth": {
"signIn": "Войти", "signIn": "Войти",
@@ -51,7 +52,9 @@
"creatingAccount": "Создание...", "creatingAccount": "Создание...",
"passwordMismatch": "Пароли не совпадают", "passwordMismatch": "Пароли не совпадают",
"passwordTooShort": "Пароль должен быть не менее 8 символов", "passwordTooShort": "Пароль должен быть не менее 8 символов",
"or": "или" "or": "или",
"loginFailed": "Ошибка входа",
"setupFailed": "Ошибка настройки"
}, },
"dashboard": { "dashboard": {
"title": "Главная", "title": "Главная",
@@ -150,7 +153,9 @@
"gpRefreshTokenHint": "Получите через Google OAuth Playground (developers.google.com/oauthplayground) с областью Photos Library API.", "gpRefreshTokenHint": "Получите через Google OAuth Playground (developers.google.com/oauthplayground) с областью Photos Library API.",
"gpAllFieldsRequired": "Client ID, Client Secret и Refresh Token обязательны", "gpAllFieldsRequired": "Client ID, Client Secret и Refresh Token обязательны",
"testAndSave": "Проверить и сохранить", "testAndSave": "Проверить и сохранить",
"saveWithoutTest": "Сохранить без проверки" "saveWithoutTest": "Сохранить без проверки",
"selectType": "Выберите тип провайдера",
"testFailed": "Ошибка проверки подключения"
}, },
"notificationTracker": { "notificationTracker": {
"title": "Трекеры уведомлений", "title": "Трекеры уведомлений",
@@ -231,7 +236,8 @@
"noLink": "Нет ссылки", "noLink": "Нет ссылки",
"saveWithoutLinks": "Сохранить без ссылок", "saveWithoutLinks": "Сохранить без ссылок",
"createLinks": "Создать {count} ссылку(и)", "createLinks": "Создать {count} ссылку(и)",
"linksNote": "Вы также можете создать ссылки вручную в Immich." "linksNote": "Вы также можете создать ссылки вручную в Immich.",
"createdLinks": "Создано публичных ссылок: {count}"
}, },
"templates": { "templates": {
"title": "Шаблоны", "title": "Шаблоны",
@@ -409,7 +415,9 @@
"cacheTtl": "TTL кэша медиа (часы)", "cacheTtl": "TTL кэша медиа (часы)",
"cacheTtlHint": "Сколько хранить кэш Telegram file_id перед повторной загрузкой (по умолчанию: 48ч)", "cacheTtlHint": "Сколько хранить кэш Telegram file_id перед повторной загрузкой (по умолчанию: 48ч)",
"settingsSaved": "Настройки сохранены", "settingsSaved": "Настройки сохранены",
"noExternalDomain": "Внешний URL домена не настроен" "noExternalDomain": "Внешний URL домена не настроен",
"saveFailed": "Не удалось сохранить бота",
"webhookFailed": "Не удалось зарегистрировать webhook"
}, },
"trackingConfig": { "trackingConfig": {
"title": "Конфигурации отслеживания", "title": "Конфигурации отслеживания",
@@ -584,7 +592,7 @@
"added_assets": "Список файлов ({% for asset in added_assets %})", "added_assets": "Список файлов ({% for asset in added_assets %})",
"removed_assets": "Список ID удалённых файлов (строки)", "removed_assets": "Список ID удалённых файлов (строки)",
"shared": "Общий альбом (boolean)", "shared": "Общий альбом (boolean)",
"target_type": "Тип получателя: 'telegram' или 'webhook'", "target_type": "Тип получателя: telegram, webhook, email, discord, slack, ntfy или matrix",
"has_videos": "Содержат ли добавленные файлы видео (boolean)", "has_videos": "Содержат ли добавленные файлы видео (boolean)",
"has_photos": "Содержат ли добавленные файлы фото (boolean)", "has_photos": "Содержат ли добавленные файлы фото (boolean)",
"old_name": "Прежнее название альбома (при переименовании)", "old_name": "Прежнее название альбома (при переименовании)",
@@ -675,7 +683,8 @@
"displayName": "Отображаемое имя", "displayName": "Отображаемое имя",
"testConnection": "Проверить подключение", "testConnection": "Проверить подключение",
"noBots": "Matrix ботов пока нет.", "noBots": "Matrix ботов пока нет.",
"confirmDelete": "Удалить этот Matrix бот?" "confirmDelete": "Удалить этот Matrix бот?",
"operationFailed": "Операция не удалась"
}, },
"emailBot": { "emailBot": {
"title": "Email боты", "title": "Email боты",
@@ -693,7 +702,8 @@
"useTls": "Использовать TLS/SSL", "useTls": "Использовать TLS/SSL",
"testConnection": "Отправить тестовое письмо", "testConnection": "Отправить тестовое письмо",
"noBots": "Email ботов пока нет.", "noBots": "Email ботов пока нет.",
"confirmDelete": "Удалить этот email бот?" "confirmDelete": "Удалить этот email бот?",
"operationFailed": "Операция не удалась"
}, },
"cmdTemplateConfig": { "cmdTemplateConfig": {
"title": "Шаблоны команд", "title": "Шаблоны команд",
@@ -841,7 +851,12 @@
"allTypes": "Все типы", "allTypes": "Все типы",
"allProviders": "Все провайдеры", "allProviders": "Все провайдеры",
"noFilterResults": "Нет элементов, соответствующих фильтру.", "noFilterResults": "Нет элементов, соответствующих фильтру.",
"redirecting": "Перенаправление..." "redirecting": "Перенаправление...",
"noMatches": "Ничего не найдено",
"saveFailed": "Не удалось сохранить",
"loadFailed": "Не удалось загрузить данные",
"dismiss": "Закрыть",
"systemSuffix": " (Системный)"
}, },
"templateSlot": { "templateSlot": {
"message_assets_added": "Новые файлы добавлены в альбом", "message_assets_added": "Новые файлы добавлены в альбом",
@@ -926,12 +941,15 @@
"previewEmail": "Предпросмотр в формате Email HTML", "previewEmail": "Предпросмотр в формате Email HTML",
"previewDiscord": "Предпросмотр в формате Discord", "previewDiscord": "Предпросмотр в формате Discord",
"previewSlack": "Предпросмотр в формате Slack", "previewSlack": "Предпросмотр в формате Slack",
"previewNtfy": "Предпросмотр уведомления ntfy",
"previewMatrix": "Предпросмотр в формате Matrix HTML",
"providerImmich": "Фотосервер для самостоятельного размещения", "providerImmich": "Фотосервер для самостоятельного размещения",
"providerGitea": "Git-сервер для самостоятельного размещения", "providerGitea": "Git-сервер для самостоятельного размещения",
"providerPlanka": "Канбан-доска для самостоятельного размещения", "providerPlanka": "Канбан-доска для самостоятельного размещения",
"providerScheduler": "Запланированные сообщения по расписанию", "providerScheduler": "Запланированные сообщения по расписанию",
"providerNut": "Мониторинг ИБП через NUT", "providerNut": "Мониторинг ИБП через NUT",
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото" "providerGooglePhotos": "Альбомы и общие библиотеки Google Фото",
"providerWebhook": "Приём событий через HTTP POST"
}, },
"error": { "error": {
"notFound": "Страница не найдена", "notFound": "Страница не найдена",
+63 -5
View File
@@ -215,15 +215,31 @@
}); });
} }
// Mobile: flatten nav for bottom bar // Mobile: flatten nav for bottom bar (first 4 + "More" button)
const mobileNavItems = $derived<NavItem[]>([ const mobileNavItems = $derived<NavItem[]>([
{ href: '/', key: 'nav.dashboard', icon: 'mdiViewDashboard' }, { href: '/', key: 'nav.dashboard', icon: 'mdiViewDashboard' },
{ href: '/notification-trackers', key: 'nav.notification', icon: 'mdiBellOutline' }, { href: '/notification-trackers', key: 'nav.notification', icon: 'mdiBellOutline' },
{ href: '/command-trackers', key: 'nav.commands', icon: 'mdiConsoleLine' }, { href: '/command-trackers', key: 'nav.commands', icon: 'mdiConsoleLine' },
{ href: '/targets', key: 'nav.targets', icon: 'mdiTarget' }, { href: '/targets', key: 'nav.targets', icon: 'mdiTarget' },
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
]); ]);
// "More" panel items — everything not in the bottom bar
const mobileMoreItems = $derived<NavItem[]>([
{ href: '/providers', key: 'nav.providers', icon: 'mdiServer' },
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
{ href: '/actions', key: 'nav.actions', icon: 'mdiPlayCircleOutline' },
{ href: '/tracking-configs', key: 'nav.configs', icon: 'mdiCog' },
{ href: '/template-configs', key: 'nav.templates', icon: 'mdiFileDocumentEdit' },
{ href: '/command-configs', key: 'nav.configs', icon: 'mdiConsoleLine' },
{ href: '/command-template-configs', key: 'nav.templates', icon: 'mdiCodeBracesBox' },
...(auth.isAdmin ? [
{ href: '/settings', key: 'nav.settings', icon: 'mdiCogOutline' },
{ href: '/users', key: 'nav.users', icon: 'mdiAccountGroup' },
] : []),
]);
let mobileMoreOpen = $state(false);
const isAuthPage = $derived( const isAuthPage = $derived(
page.url.pathname === '/login' || page.url.pathname === '/setup' page.url.pathname === '/login' || page.url.pathname === '/setup'
); );
@@ -526,12 +542,50 @@
<MdiIcon name={item.icon} size={20} /> <MdiIcon name={item.icon} size={20} />
</a> </a>
{/each} {/each}
<button onclick={logout} aria-label={t('nav.logout')} <button onclick={() => openSearch?.()} aria-label={t('searchPalette.placeholder')}
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs" style="color: var(--color-muted-foreground);"> class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
<MdiIcon name="mdiLogout" size={20} /> style="color: var(--color-muted-foreground);">
<MdiIcon name="mdiMagnify" size={20} />
</button>
<button onclick={() => mobileMoreOpen = !mobileMoreOpen} aria-label={t('nav.more')}
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
style="color: {mobileMoreOpen ? 'var(--color-primary)' : 'var(--color-muted-foreground)'};">
<MdiIcon name="mdiDotsHorizontal" size={20} />
</button> </button>
</nav> </nav>
<!-- Mobile "More" panel -->
{#if mobileMoreOpen}
<div class="mobile-more-backdrop" style="position: fixed; inset: 0; z-index: 49; background: rgba(0,0,0,0.4); backdrop-filter: blur(2px);"
onclick={() => mobileMoreOpen = false} role="presentation"></div>
<div class="mobile-more-panel" style="position: fixed; bottom: 3.25rem; left: 0; right: 0; z-index: 50; background: var(--color-sidebar); border-top: 1px solid var(--color-border); border-radius: 1rem 1rem 0 0; padding: 1rem; max-height: 60vh; overflow-y: auto;"
transition:slide={{ duration: 200, easing: cubicOut }}>
{#if allProviders.length > 1}
<div class="mb-3 pb-3" style="border-bottom: 1px solid var(--color-border);">
<IconGridSelect items={providerFilterItems} bind:value={providerFilterValue} columns={Math.min(providerFilterItems.length, 4)} compact />
</div>
{/if}
<div class="grid grid-cols-3 gap-2">
{#each mobileMoreItems as item}
<a href={item.href}
onclick={() => mobileMoreOpen = false}
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
style="color: {isActive(item.href) ? 'var(--color-primary)' : 'var(--color-muted-foreground)'}; background: {isActive(item.href) ? 'var(--color-sidebar-active)' : 'transparent'};"
>
<MdiIcon name={item.icon} size={20} />
<span class="text-xs text-center leading-tight">{t(item.key)}</span>
</a>
{/each}
<button onclick={() => { mobileMoreOpen = false; logout(); }}
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
style="color: var(--color-muted-foreground);">
<MdiIcon name="mdiLogout" size={20} />
<span class="text-xs text-center leading-tight">{t('nav.logout')}</span>
</button>
</div>
</div>
{/if}
<!-- Main content --> <!-- Main content -->
<main class="flex-1 overflow-auto pb-16 md:pb-0"> <main class="flex-1 overflow-auto pb-16 md:pb-0">
{#key page.url.pathname} {#key page.url.pathname}
@@ -579,6 +633,10 @@
<style> <style>
@media (max-width: 767px) { @media (max-width: 767px) {
.mobile-nav { display: flex !important; } .mobile-nav { display: flex !important; }
.mobile-more-panel a:hover,
.mobile-more-panel button:hover {
background: var(--color-muted);
}
} }
/* Provider filter chips */ /* Provider filter chips */
+3 -3
View File
@@ -231,7 +231,7 @@
</div> </div>
</Card> </Card>
{:else if status} {:else if status}
<div class="grid grid-cols-1 sm:grid-cols-3 gap-4 mb-8 stagger-children"> <div class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-4 mb-8 stagger-children">
{#each statCards as card, i} {#each statCards as card, i}
<div class="stat-card" style="--accent: {card.color};"> <div class="stat-card" style="--accent: {card.color};">
<div class="stat-card-inner"> <div class="stat-card-inner">
@@ -289,7 +289,7 @@
<div class="flex items-center justify-center gap-1"> <div class="flex items-center justify-center gap-1">
{#if totalPages > 1} {#if totalPages > 1}
<button onclick={() => goToPage(currentPage - 1)} disabled={currentPage <= 1} <button onclick={() => goToPage(currentPage - 1)} disabled={currentPage <= 1}
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default"> class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
<MdiIcon name="mdiChevronLeft" size={16} /> <MdiIcon name="mdiChevronLeft" size={16} />
</button> </button>
{#each Array.from({ length: totalPages }, (_, i) => i + 1) as page} {#each Array.from({ length: totalPages }, (_, i) => i + 1) as page}
@@ -305,7 +305,7 @@
{/if} {/if}
{/each} {/each}
<button onclick={() => goToPage(currentPage + 1)} disabled={currentPage >= totalPages} <button onclick={() => goToPage(currentPage + 1)} disabled={currentPage >= totalPages}
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default"> class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
<MdiIcon name="mdiChevronRight" size={16} /> <MdiIcon name="mdiChevronRight" size={16} />
</button> </button>
{/if} {/if}
+6 -5
View File
@@ -10,6 +10,8 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte'; import IconButton from '$lib/components/IconButton.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { EmailBot } from '$lib/types'; import type { EmailBot } from '$lib/types';
let { onreload }: { onreload: () => Promise<void> } = $props(); let { onreload }: { onreload: () => Promise<void> } = $props();
@@ -72,22 +74,21 @@
try { try {
const res = await api(`/email-bots/${botId}/test`, { method: 'POST' }); const res = await api(`/email-bots/${botId}/test`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.emailBotTestSent')); if (res.success) snackSuccess(t('snack.emailBotTestSent'));
else snackError(res.error || 'Failed'); else snackError(res.error || t('emailBot.operationFailed'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
emailTesting = { ...emailTesting, [botId]: false }; emailTesting = { ...emailTesting, [botId]: false };
} }
</script> </script>
<PageHeader title={t('emailBot.title')} description={t('emailBot.description')}> <PageHeader title={t('emailBot.title')} description={t('emailBot.description')}>
<button onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }} <Button size="sm" onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showEmailForm ? t('common.cancel') : t('emailBot.addBot')} {showEmailForm ? t('common.cancel') : t('emailBot.addBot')}
</button> </Button>
</PageHeader> </PageHeader>
{#if showEmailForm} {#if showEmailForm}
<Card class="mb-6"> <Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if} <ErrorBanner message={error} />
<form onsubmit={saveEmailBot} class="space-y-3"> <form onsubmit={saveEmailBot} class="space-y-3">
<div> <div>
<label for="ebot-name" class="block text-sm font-medium mb-1">{t('emailBot.name')}</label> <label for="ebot-name" class="block text-sm font-medium mb-1">{t('emailBot.name')}</label>
+6 -5
View File
@@ -10,6 +10,8 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte'; import IconButton from '$lib/components/IconButton.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { MatrixBot } from '$lib/types'; import type { MatrixBot } from '$lib/types';
let { onreload }: { onreload: () => Promise<void> } = $props(); let { onreload }: { onreload: () => Promise<void> } = $props();
@@ -70,22 +72,21 @@
try { try {
const res = await api(`/matrix-bots/${botId}/test`, { method: 'POST' }); const res = await api(`/matrix-bots/${botId}/test`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.matrixBotTestOk')); if (res.success) snackSuccess(t('snack.matrixBotTestOk'));
else snackError(res.error || 'Failed'); else snackError(res.error || t('matrixBot.operationFailed'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
matrixTesting = { ...matrixTesting, [botId]: false }; matrixTesting = { ...matrixTesting, [botId]: false };
} }
</script> </script>
<PageHeader title={t('matrixBot.title')} description={t('matrixBot.description')}> <PageHeader title={t('matrixBot.title')} description={t('matrixBot.description')}>
<button onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }} <Button size="sm" onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showMatrixForm ? t('common.cancel') : t('matrixBot.addBot')} {showMatrixForm ? t('common.cancel') : t('matrixBot.addBot')}
</button> </Button>
</PageHeader> </PageHeader>
{#if showMatrixForm} {#if showMatrixForm}
<Card class="mb-6"> <Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if} <ErrorBanner message={error} />
<form onsubmit={saveMatrixBot} class="space-y-3"> <form onsubmit={saveMatrixBot} class="space-y-3">
<div> <div>
<label for="mbot-name" class="block text-sm font-medium mb-1">{t('matrixBot.name')}</label> <label for="mbot-name" class="block text-sm font-medium mb-1">{t('matrixBot.name')}</label>
@@ -12,6 +12,8 @@
import IconButton from '$lib/components/IconButton.svelte'; import IconButton from '$lib/components/IconButton.svelte';
import EntitySelect from '$lib/components/EntitySelect.svelte'; import EntitySelect from '$lib/components/EntitySelect.svelte';
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { TelegramBot, TelegramChat } from '$lib/types'; import type { TelegramBot, TelegramChat } from '$lib/types';
interface CommandTrackerSummary { id: number; name: string; icon?: string; enabled: boolean } interface CommandTrackerSummary { id: number; name: string; icon?: string; enabled: boolean }
@@ -186,7 +188,7 @@
try { try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/sync-commands`, { method: 'POST' }); const res = await api<ApiResult>(`/telegram-bots/${botId}/sync-commands`, { method: 'POST' });
if (res.success) snackSuccess(t('telegramBot.commandsSynced')); if (res.success) snackSuccess(t('telegramBot.commandsSynced'));
else snackError(res.error || 'Failed'); else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false }; modeChanging = { ...modeChanging, [botId]: false };
} }
@@ -218,7 +220,7 @@
snackSuccess(res.verified ? t('telegramBot.webhookVerified') : t('telegramBot.webhookRegistered')); snackSuccess(res.verified ? t('telegramBot.webhookVerified') : t('telegramBot.webhookRegistered'));
await loadWebhookStatus(botId); await loadWebhookStatus(botId);
} else { } else {
snackError(res.error || 'Failed to register webhook'); snackError(res.error || t('telegramBot.webhookFailed'));
} }
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false }; modeChanging = { ...modeChanging, [botId]: false };
@@ -229,7 +231,7 @@
try { try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/webhook/unregister`, { method: 'POST' }); const res = await api<ApiResult>(`/telegram-bots/${botId}/webhook/unregister`, { method: 'POST' });
if (res.success) { snackSuccess(t('telegramBot.webhookUnregistered')); await loadWebhookStatus(botId); } if (res.success) { snackSuccess(t('telegramBot.webhookUnregistered')); await loadWebhookStatus(botId); }
else snackError(res.error || 'Failed'); else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false }; modeChanging = { ...modeChanging, [botId]: false };
} }
@@ -260,7 +262,7 @@
try { try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/chats/${chatId}/test?locale=${getLocale()}`, { method: 'POST' }); const res = await api<ApiResult>(`/telegram-bots/${botId}/chats/${chatId}/test?locale=${getLocale()}`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.targetTestSent')); if (res.success) snackSuccess(t('snack.targetTestSent'));
else snackError(res.error || 'Failed'); else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { snackError(err.message); }
chatTesting = { ...chatTesting, [key]: false }; chatTesting = { ...chatTesting, [key]: false };
} }
@@ -277,15 +279,14 @@
</script> </script>
<PageHeader title={t('telegramBot.title')} description={t('telegramBot.description')}> <PageHeader title={t('telegramBot.title')} description={t('telegramBot.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }} <Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showForm ? t('common.cancel') : t('telegramBot.addBot')} {showForm ? t('common.cancel') : t('telegramBot.addBot')}
</button> </Button>
</PageHeader> </PageHeader>
{#if showForm} {#if showForm}
<Card class="mb-6"> <Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if} <ErrorBanner message={error} />
<form onsubmit={saveBot} class="space-y-3"> <form onsubmit={saveBot} class="space-y-3">
<div> <div>
<label for="bot-name" class="block text-sm font-medium mb-1">{t('telegramBot.name')}</label> <label for="bot-name" class="block text-sm font-medium mb-1">{t('telegramBot.name')}</label>
@@ -15,6 +15,7 @@
import IconGridSelect from '$lib/components/IconGridSelect.svelte'; import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { providerTypeItems, providerTypeFilterItems, responseModeItems } from '$lib/grid-items'; import { providerTypeItems, providerTypeFilterItems, responseModeItems } from '$lib/grid-items';
import EntitySelect from '$lib/components/EntitySelect.svelte'; import EntitySelect from '$lib/components/EntitySelect.svelte';
import Button from '$lib/components/Button.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight'; import { highlightFromUrl } from '$lib/highlight';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
@@ -37,7 +38,7 @@
let cmdTemplateConfigs = $derived(commandTemplateConfigsCache.items); let cmdTemplateConfigs = $derived(commandTemplateConfigsCache.items);
const templateItems = $derived(cmdTemplateConfigs const templateItems = $derived(cmdTemplateConfigs
.filter((c) => c.provider_type === form.provider_type) .filter((c) => c.provider_type === form.provider_type)
.map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? ' (System)' : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type })) .map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? t('common.systemSuffix') : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type }))
); );
let loaded = $state(false); let loaded = $state(false);
let showForm = $state(false); let showForm = $state(false);
@@ -151,10 +152,9 @@
</script> </script>
<PageHeader title={t('commandConfig.title')} description={t('commandConfig.description')}> <PageHeader title={t('commandConfig.title')} description={t('commandConfig.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }} <Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showForm ? t('common.cancel') : t('commandConfig.newConfig')} {showForm ? t('common.cancel') : t('commandConfig.newConfig')}
</button> </Button>
</PageHeader> </PageHeader>
{#if !loaded}<Loading />{:else} {#if !loaded}<Loading />{:else}
+1 -1
View File
@@ -32,7 +32,7 @@
await login(username, password); await login(username, password);
window.location.href = '/'; window.location.href = '/';
} catch (err: any) { } catch (err: any) {
error = err.message || 'Login failed'; error = err.message || t('auth.loginFailed');
} }
submitting = false; submitting = false;
} }
@@ -17,6 +17,8 @@
import { providerDefaultIcon } from '$lib/grid-items'; import { providerDefaultIcon } from '$lib/grid-items';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { getDescriptor } from '$lib/providers'; import { getDescriptor } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { Tracker, TrackerTarget, TrackingConfig, TemplateConfig, NotificationTarget } from '$lib/types'; import type { Tracker, TrackerTarget, TrackingConfig, TemplateConfig, NotificationTarget } from '$lib/types';
import TrackerForm from './TrackerForm.svelte'; import TrackerForm from './TrackerForm.svelte';
@@ -119,7 +121,7 @@
capabilitiesCache.fetch(), capabilitiesCache.fetch(),
]); ]);
} catch (err: any) { } catch (err: any) {
loadError = err.message || 'Failed to load data'; loadError = err.message || t('common.loadFailed');
snackError(loadError); snackError(loadError);
} finally { loaded = true; highlightFromUrl(); } } finally { loaded = true; highlightFromUrl(); }
} }
@@ -212,7 +214,7 @@
} }
} }
} }
if (created > 0) snackSuccess(`Created ${created} public link(s)`); if (created > 0) snackSuccess(t('notificationTracker.createdLinks').replace('{count}', String(created)));
linkWarning = null; linkWarning = null;
linkCreating = false; linkCreating = false;
await doSave(); await doSave();
@@ -361,17 +363,16 @@
</script> </script>
<PageHeader title={t('notificationTracker.title')} description={t('notificationTracker.description')}> <PageHeader title={t('notificationTracker.title')} description={t('notificationTracker.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }} <Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showForm ? t('notificationTracker.cancel') : t('notificationTracker.newTracker')} {showForm ? t('notificationTracker.cancel') : t('notificationTracker.newTracker')}
</button> </Button>
</PageHeader> </PageHeader>
{#if !loaded} {#if !loaded}
<Loading /> <Loading />
{:else if loadError} {:else if loadError}
<Card> <Card>
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3">{loadError}</div> <ErrorBanner message={loadError} class="mb-0" />
</Card> </Card>
{:else if showForm} {:else if showForm}
<TrackerForm <TrackerForm
+7 -9
View File
@@ -12,12 +12,14 @@
import EmptyState from '$lib/components/EmptyState.svelte'; import EmptyState from '$lib/components/EmptyState.svelte';
import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte'; import IconButton from '$lib/components/IconButton.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte'; import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { providerTypeItems, providerDefaultIcon } from '$lib/grid-items'; import { providerTypeItems, providerDefaultIcon } from '$lib/grid-items';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight'; import { highlightFromUrl } from '$lib/highlight';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers'; import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import type { ServiceProvider } from '$lib/types'; import type { ServiceProvider } from '$lib/types';
let allProviders = $derived(providersCache.items); let allProviders = $derived(providersCache.items);
@@ -136,10 +138,9 @@
</script> </script>
<PageHeader title={t('providers.title')} description={t('providers.description')}> <PageHeader title={t('providers.title')} description={t('providers.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }} <Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
{showForm ? t('providers.cancel') : t('providers.addProvider')} {showForm ? t('providers.cancel') : t('providers.addProvider')}
</button> </Button>
</PageHeader> </PageHeader>
{#if !loaded} {#if !loaded}
@@ -158,9 +159,7 @@
{#if showForm} {#if showForm}
<div in:slide={{ duration: 200 }}> <div in:slide={{ duration: 200 }}>
<Card class="mb-6"> <Card class="mb-6">
{#if error} <ErrorBanner message={error} />
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>
{/if}
<form onsubmit={save} class="space-y-3"> <form onsubmit={save} class="space-y-3">
<div> <div>
<label class="block text-sm font-medium mb-1">{t('providers.type')}</label> <label class="block text-sm font-medium mb-1">{t('providers.type')}</label>
@@ -211,10 +210,9 @@
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p> <p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
</div> </div>
{/if} {/if}
<button type="submit" disabled={submitting} <Button type="submit" disabled={submitting}>
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
{submitting ? t('providers.connecting') : (editing ? t('common.save') : t('providers.addProvider'))} {submitting ? t('providers.connecting') : (editing ? t('common.save') : t('providers.addProvider'))}
</button> </Button>
</form> </form>
</Card> </Card>
</div> </div>
+17 -18
View File
@@ -5,8 +5,11 @@
import Card from '$lib/components/Card.svelte'; import Card from '$lib/components/Card.svelte';
import IconPicker from '$lib/components/IconPicker.svelte'; import IconPicker from '$lib/components/IconPicker.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte'; import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { goto } from '$app/navigation';
import { providerTypeItems } from '$lib/grid-items'; import { providerTypeItems } from '$lib/grid-items';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers'; import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
let form = $state(buildProviderFormDefaults()); let form = $state(buildProviderFormDefaults());
let error = $state(''); let error = $state('');
@@ -16,7 +19,7 @@
async function testAndSave() { async function testAndSave() {
const desc = descriptor; const desc = descriptor;
if (!desc) { error = 'Select a provider type'; return; } if (!desc) { error = t('providers.selectType'); return; }
const { config, error: buildError } = desc.buildConfig(form, false); const { config, error: buildError } = desc.buildConfig(form, false);
if (buildError) { error = t(buildError); snackError(error); return; } if (buildError) { error = t(buildError); snackError(error); return; }
@@ -32,22 +35,22 @@
if (!result.ok) { if (!result.ok) {
await api(`/providers/${provider.id}`, { method: 'DELETE' }).catch(() => {}); await api(`/providers/${provider.id}`, { method: 'DELETE' }).catch(() => {});
createdId = null; createdId = null;
error = result.message || 'Connection test failed'; error = result.message || t('providers.testFailed');
snackError(error); snackError(error);
} else { } else {
snackSuccess(t('snack.providerSaved')); snackSuccess(t('snack.providerSaved'));
window.location.href = '/providers'; goto('/providers');
} }
} catch (e: any) { } catch (e: any) {
if (createdId) await api(`/providers/${createdId}`, { method: 'DELETE' }).catch(() => {}); if (createdId) await api(`/providers/${createdId}`, { method: 'DELETE' }).catch(() => {});
error = e.message || 'Test failed'; snackError(error); error = e.message || t('providers.testFailed'); snackError(error);
} }
finally { testing = false; } finally { testing = false; }
} }
async function saveWithoutTest() { async function saveWithoutTest() {
const desc = descriptor; const desc = descriptor;
if (!desc) { error = 'Select a provider type'; return; } if (!desc) { error = t('providers.selectType'); return; }
const { config, error: buildError } = desc.buildConfig(form, false); const { config, error: buildError } = desc.buildConfig(form, false);
if (buildError) { error = t(buildError); snackError(error); return; } if (buildError) { error = t(buildError); snackError(error); return; }
@@ -58,8 +61,8 @@
body: JSON.stringify({ type: form.type, name: form.name || desc.defaultName, icon: form.icon, config }), body: JSON.stringify({ type: form.type, name: form.name || desc.defaultName, icon: form.icon, config }),
}); });
snackSuccess(t('snack.providerSaved')); snackSuccess(t('snack.providerSaved'));
window.location.href = '/providers'; goto('/providers');
} catch (e: any) { error = e.message || 'Save failed'; snackError(error); } } catch (e: any) { error = e.message || t('common.saveFailed'); snackError(error); }
finally { saving = false; } finally { saving = false; }
} }
</script> </script>
@@ -112,22 +115,18 @@
</div> </div>
{/each} {/each}
{#if error} <ErrorBanner message={error} />
<p class="text-sm text-[var(--color-error-fg)]">{error}</p>
{/if}
<div class="flex gap-3 pt-2"> <div class="flex gap-3 pt-2">
<button onclick={testAndSave} disabled={testing || saving} <Button onclick={testAndSave} disabled={testing || saving}>
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
{testing ? t('providers.connecting') : t('providers.testAndSave')} {testing ? t('providers.connecting') : t('providers.testAndSave')}
</button> </Button>
<button onclick={saveWithoutTest} disabled={testing || saving} <Button variant="secondary" onclick={saveWithoutTest} disabled={testing || saving}>
class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-foreground)] rounded-md text-sm font-medium hover:opacity-80 disabled:opacity-50">
{saving ? t('common.loading') : t('providers.saveWithoutTest')} {saving ? t('common.loading') : t('providers.saveWithoutTest')}
</button> </Button>
<a href="/providers" class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-muted-foreground)] rounded-md text-sm font-medium hover:opacity-80"> <Button variant="secondary" href="/providers">
{t('common.cancel')} {t('common.cancel')}
</a> </Button>
</div> </div>
</div> </div>
</Card> </Card>
+6 -3
View File
@@ -7,10 +7,12 @@
import Loading from '$lib/components/Loading.svelte'; import Loading from '$lib/components/Loading.svelte';
import MdiIcon from '$lib/components/MdiIcon.svelte'; import MdiIcon from '$lib/components/MdiIcon.svelte';
import Hint from '$lib/components/Hint.svelte'; import Hint from '$lib/components/Hint.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
let loaded = $state(false); let loaded = $state(false);
let saving = $state(false); let saving = $state(false);
let error = $state('');
let settings = $state({ let settings = $state({
external_url: '', external_url: '',
telegram_webhook_secret: '', telegram_webhook_secret: '',
@@ -20,16 +22,16 @@
onMount(async () => { onMount(async () => {
try { try {
settings = await api('/settings'); settings = await api('/settings');
} catch (err: any) { snackError(err.message); } } catch (err: any) { error = err.message; snackError(err.message); }
finally { loaded = true; } finally { loaded = true; }
}); });
async function save() { async function save() {
saving = true; saving = true; error = '';
try { try {
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) }); settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
snackSuccess(t('settings.saved')); snackSuccess(t('settings.saved'));
} catch (err: any) { snackError(err.message); } } catch (err: any) { error = err.message; snackError(err.message); }
saving = false; saving = false;
} }
</script> </script>
@@ -39,6 +41,7 @@
{#if !loaded} {#if !loaded}
<Loading /> <Loading />
{:else} {:else}
<ErrorBanner message={error} />
<div class="space-y-6"> <div class="space-y-6">
<!-- General section --> <!-- General section -->
<Card> <Card>
+1 -1
View File
@@ -25,7 +25,7 @@
try { try {
await setup(username, password); await setup(username, password);
window.location.href = '/'; window.location.href = '/';
} catch (err: any) { error = err.message || 'Setup failed'; } } catch (err: any) { error = err.message || t('auth.setupFailed'); }
submitting = false; submitting = false;
} }
</script> </script>
+2 -1
View File
@@ -15,6 +15,7 @@
import { chatActionItems } from '$lib/grid-items'; import { chatActionItems } from '$lib/grid-items';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight'; import { highlightFromUrl } from '$lib/highlight';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { NotificationTarget, TargetReceiver, TelegramChat } from '$lib/types'; import type { NotificationTarget, TargetReceiver, TelegramChat } from '$lib/types';
import TargetForm from './TargetForm.svelte'; import TargetForm from './TargetForm.svelte';
@@ -419,7 +420,7 @@
{#if !loaded}<Loading />{:else} {#if !loaded}<Loading />{:else}
{#if loadError} {#if loadError}
<div class="mb-4 p-3 rounded-md text-sm bg-[var(--color-error-bg)] text-[var(--color-error-fg)]">{loadError}</div> <ErrorBanner message={loadError} />
{/if} {/if}
{#if showForm} {#if showForm}
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -68,14 +69,17 @@ class NotificationDispatcher:
Returns list of results (one per target). Returns list of results (one per target).
""" """
raw_results = await asyncio.gather(
*[self._send_to_target(event, t) for t in targets],
return_exceptions=True,
)
results = [] results = []
for target in targets: for raw in raw_results:
try: if isinstance(raw, Exception):
result = await self._send_to_target(event, target) _LOGGER.error("Failed to dispatch to target: %s", raw)
results.append(result) results.append({"success": False, "error": str(raw)})
except Exception as e: else:
_LOGGER.error("Failed to dispatch to target: %s", e) results.append(raw)
results.append({"success": False, "error": str(e)})
return results return results
def _resolve_template( def _resolve_template(
@@ -85,6 +85,20 @@ class GiteaClient:
return repos return repos
async def get_repo(self, owner: str, repo: str) -> dict[str, Any] | None:
"""Fetch a single repository by owner/repo name."""
try:
async with self._session.get(
f"{self._url}/api/v1/repos/{owner}/{repo}",
headers=self._headers,
) as response:
if response.status == 200:
return await response.json()
_LOGGER.warning("Failed to fetch repo %s/%s: HTTP %s", owner, repo, response.status)
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, err)
return None
async def get_repo_issues( async def get_repo_issues(
self, owner: str, repo: str, state: str = "open", limit: int = 10, self, owner: str, repo: str, state: str = "open", limit: int = 10,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
@@ -14,12 +14,28 @@ _DEFAULT_PORT = 3493
_READ_TIMEOUT = 10.0 _READ_TIMEOUT = 10.0
_CONNECT_TIMEOUT = 5.0 _CONNECT_TIMEOUT = 5.0
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
# Prevents command injection via newlines or special characters.
_SAFE_NAME_RE = re.compile(r"^[\w.\-]+$")
# Regex to parse VAR lines: VAR <ups> <name> "<value>" # Regex to parse VAR lines: VAR <ups> <name> "<value>"
_VAR_RE = re.compile(r'^VAR\s+(\S+)\s+(\S+)\s+"(.*)"$') _VAR_RE = re.compile(r'^VAR\s+(\S+)\s+(\S+)\s+"(.*)"$')
# Regex to parse UPS lines: UPS <name> "<description>" # Regex to parse UPS lines: UPS <name> "<description>"
_UPS_RE = re.compile(r'^UPS\s+(\S+)\s+"(.*)"$') _UPS_RE = re.compile(r'^UPS\s+(\S+)\s+"(.*)"$')
def _validate_name(value: str, label: str) -> None:
"""Validate that *value* is a safe NUT protocol identifier.
Raises ``NutClientError`` if *value* contains characters outside
``[\\w.\\-]``, which could be used for protocol command injection.
"""
if not _SAFE_NAME_RE.match(value):
raise NutClientError(
f"Invalid {label}: {value!r} contains disallowed characters"
)
class NutClientError(Exception): class NutClientError(Exception):
"""Error communicating with NUT server.""" """Error communicating with NUT server."""
@@ -91,6 +107,7 @@ class NutClient:
async def list_var(self, ups_name: str) -> dict[str, str]: async def list_var(self, ups_name: str) -> dict[str, str]:
"""Get all variables for a UPS device.""" """Get all variables for a UPS device."""
_validate_name(ups_name, "UPS name")
lines = await self._list_command(f"LIST VAR {ups_name}") lines = await self._list_command(f"LIST VAR {ups_name}")
variables: dict[str, str] = {} variables: dict[str, str] = {}
for line in lines: for line in lines:
@@ -101,6 +118,8 @@ class NutClient:
async def get_var(self, ups_name: str, var_name: str) -> str: async def get_var(self, ups_name: str, var_name: str) -> str:
"""Get a single variable value.""" """Get a single variable value."""
_validate_name(ups_name, "UPS name")
_validate_name(var_name, "variable name")
response = await self._command(f"GET VAR {ups_name} {var_name}") response = await self._command(f"GET VAR {ups_name} {var_name}")
m = _VAR_RE.match(response) m = _VAR_RE.match(response)
if m: if m:
@@ -10,7 +10,7 @@ from jinja2.sandbox import SandboxedEnvironment
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_env = SandboxedEnvironment(autoescape=False) _env = SandboxedEnvironment(autoescape=True)
def render_template(template_str: str, context: dict[str, Any]) -> str: def render_template(template_str: str, context: dict[str, Any]) -> str:
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import Action, ActionRule, User from ..database.models import Action, ActionRule, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -59,10 +60,9 @@ def _rule_response(rule: ActionRule) -> dict:
async def _get_user_action( async def _get_user_action(
session: AsyncSession, action_id: int, user: User session: AsyncSession, action_id: int, user: User
) -> Action: ) -> Action:
action = await session.get(Action, action_id) return await get_owned_entity(
if not action or action.user_id != user.id: session, Action, action_id, user.id, not_found_msg="Action not found",
raise HTTPException(status_code=404, detail="Action not found") )
return action
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -12,12 +12,10 @@ from pydantic import BaseModel
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User
from .slot_helpers import load_slots, render_template_preview, save_slots
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -44,38 +42,11 @@ class CommandTemplateConfigUpdate(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]: async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
"""Load slots as {slot_name: {locale: template}}.""" return await load_slots(session, CommandTemplateSlot, config_id)
result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id)
)
nested: dict[str, dict[str, str]] = {}
for s in result.all():
nested.setdefault(s.slot_name, {})[s.locale] = s.template
return nested
async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None: async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None:
"""Save slots from {slot_name: {locale: template}} format.""" await save_slots(session, CommandTemplateSlot, config_id, slots)
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config_id,
CommandTemplateSlot.slot_name == slot_name,
CommandTemplateSlot.locale == locale,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]: async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]:
@@ -367,18 +338,4 @@ async def preview_raw(
"wait": 15, "wait": 15,
} }
try: return render_template_preview(body.template, sample_ctx)
env = SandboxedEnvironment(autoescape=False)
env.from_string(body.template)
except TemplateSyntaxError as e:
return {"rendered": None, "error": e.message, "error_line": e.lineno}
try:
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(body.template)
rendered = tmpl.render(**sample_ctx)
return {"rendered": rendered}
except UndefinedError as e:
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
@@ -17,6 +17,7 @@ from ..database.models import (
TelegramBot, TelegramBot,
User, User,
) )
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -401,7 +402,7 @@ async def _listener_response(session: AsyncSession, l: CommandTrackerListener) -
async def _get_user_tracker( async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int session: AsyncSession, tracker_id: int, user_id: int
) -> CommandTracker: ) -> CommandTracker:
tracker = await session.get(CommandTracker, tracker_id) return await get_owned_entity(
if not tracker or tracker.user_id != user_id: session, CommandTracker, tracker_id, user_id,
raise HTTPException(status_code=404, detail="Command tracker not found") not_found_msg="Command tracker not found",
return tracker )
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import EmailBot, User from ..database.models import EmailBot, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -156,7 +157,6 @@ def _response(bot: EmailBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot:
bot = await session.get(EmailBot, bot_id) return await get_owned_entity(
if not bot or bot.user_id != user_id: session, EmailBot, bot_id, user_id, not_found_msg="Email bot not found",
raise HTTPException(status_code=404, detail="Email bot not found") )
return bot
@@ -0,0 +1,25 @@
"""Shared helpers for API route modules."""
from typing import TypeVar
from fastapi import HTTPException
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar("T", bound=SQLModel)
async def get_owned_entity(
session: AsyncSession,
model: type[T],
entity_id: int,
user_id: int,
*,
owner_field: str = "user_id",
not_found_msg: str = "Not found",
) -> T:
"""Fetch an entity by PK and verify ownership, or raise 404."""
entity = await session.get(model, entity_id)
if not entity or getattr(entity, owner_field) != user_id:
raise HTTPException(status_code=404, detail=not_found_msg)
return entity
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import MatrixBot, User from ..database.models import MatrixBot, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -108,33 +109,34 @@ async def test_matrix_bot(
bot = await _get_user_bot(session, bot_id, user.id) bot = await _get_user_bot(session, bot_id, user.id)
import aiohttp import aiohttp
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
# Verify token with /whoami http = await get_http_session()
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami" # Verify token with /whoami
headers = {"Authorization": f"Bearer {bot.access_token}"} whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
try: headers = {"Authorization": f"Bearer {bot.access_token}"}
async with http.get(whoami_url, headers=headers) as resp: try:
if resp.status != 200: async with http.get(whoami_url, headers=headers) as resp:
body = await resp.text() if resp.status != 200:
return {"success": False, "error": f"Auth failed: HTTP {resp.status}{body[:200]}"} body = await resp.text()
whoami = await resp.json() return {"success": False, "error": f"Auth failed: HTTP {resp.status}{body[:200]}"}
except aiohttp.ClientError as e: whoami = await resp.json()
return {"success": False, "error": f"Connection failed: {e}"} except aiohttp.ClientError as e:
return {"success": False, "error": f"Connection failed: {e}"}
result = {"success": True, "user_id": whoami.get("user_id", "")} result = {"success": True, "user_id": whoami.get("user_id", "")}
# Optionally send a test message # Optionally send a test message
if room_id: if room_id:
from notify_bridge_core.notifications.matrix.client import MatrixClient from notify_bridge_core.notifications.matrix.client import MatrixClient
client = MatrixClient(http, bot.homeserver_url, bot.access_token) client = MatrixClient(http, bot.homeserver_url, bot.access_token)
send_result = await client.send_message( send_result = await client.send_message(
room_id, room_id,
"Test message from Notify Bridge", "Test message from Notify Bridge",
html_message="<b>Test message</b> from Notify Bridge", html_message="<b>Test message</b> from Notify Bridge",
) )
result["send_result"] = send_result result["send_result"] = send_result
return result return result
def _response(bot: MatrixBot) -> dict: def _response(bot: MatrixBot) -> dict:
@@ -150,7 +152,6 @@ def _response(bot: MatrixBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot:
bot = await session.get(MatrixBot, bot_id) return await get_owned_entity(
if not bot or bot.user_id != user_id: session, MatrixBot, bot_id, user_id, not_found_msg="Matrix bot not found",
raise HTTPException(status_code=404, detail="Matrix bot not found") )
return bot
@@ -23,6 +23,7 @@ from ..database.models import (
) )
from ..services.notifier import send_test_notification from ..services.notifier import send_test_notification
from ..services.test_dispatch import dispatch_test_notification from ..services.test_dispatch import dispatch_test_notification
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -277,7 +278,7 @@ async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) ->
async def _get_user_tracker( async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int session: AsyncSession, tracker_id: int, user_id: int
) -> NotificationTracker: ) -> NotificationTracker:
tracker = await session.get(NotificationTracker, tracker_id) return await get_owned_entity(
if not tracker or tracker.user_id != user_id: session, NotificationTracker, tracker_id, user_id,
raise HTTPException(status_code=404, detail="Tracker not found") not_found_msg="Tracker not found",
return tracker )
@@ -18,6 +18,7 @@ from ..database.models import (
User, User,
) )
from ..services.scheduler import schedule_tracker, unschedule_tracker from ..services.scheduler import schedule_tracker, unschedule_tracker
from .helpers import get_owned_entity
from .notification_tracker_targets import _tt_response from .notification_tracker_targets import _tt_response
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -205,7 +206,7 @@ async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> di
async def _get_user_tracker( async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int session: AsyncSession, tracker_id: int, user_id: int
) -> NotificationTracker: ) -> NotificationTracker:
tracker = await session.get(NotificationTracker, tracker_id) return await get_owned_entity(
if not tracker or tracker.user_id != user_id: session, NotificationTracker, tracker_id, user_id,
raise HTTPException(status_code=404, detail="Tracker not found") not_found_msg="Tracker not found",
return tracker )
@@ -13,7 +13,12 @@ import aiohttp
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import ServiceProvider, User from ..database.models import ServiceProvider, User
from ..services import make_immich_provider, make_gitea_provider, make_planka_provider, make_nut_provider, make_google_photos_provider from ..services import (
make_immich_provider, make_gitea_provider, make_planka_provider,
make_nut_provider, make_google_photos_provider, list_provider_collections,
)
from ..services.http_session import get_http_session
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -82,6 +87,20 @@ class GooglePhotosProviderConfig(BaseModel):
refresh_token: str refresh_token: str
class PayloadMapping(BaseModel):
variable: str
jsonpath: str
default: str | None = None
class WebhookProviderConfig(BaseModel):
auth_mode: str = "none"
webhook_secret: str | None = None
payload_mappings: list[PayloadMapping] = []
event_type_path: str | None = None
collection_path: str | None = None
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = { _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"immich": ImmichProviderConfig, "immich": ImmichProviderConfig,
"gitea": GiteaProviderConfig, "gitea": GiteaProviderConfig,
@@ -89,6 +108,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"scheduler": SchedulerProviderConfig, "scheduler": SchedulerProviderConfig,
"nut": NutProviderConfig, "nut": NutProviderConfig,
"google_photos": GooglePhotosProviderConfig, "google_photos": GooglePhotosProviderConfig,
"webhook": WebhookProviderConfig,
} }
@@ -106,6 +126,70 @@ def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> Non
) )
async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
"""Test provider connection and return the result dict.
For providers that lack optional credentials (gitea without api_token,
planka without api_key), returns a success stub.
"""
http_session = await get_http_session()
if provider.type == "immich":
immich = make_immich_provider(http_session, provider)
return await immich.test_connection()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
gitea = make_gitea_provider(http_session, provider)
return await gitea.test_connection()
if provider.type == "planka":
if not provider.config.get("api_key"):
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
planka = make_planka_provider(http_session, provider)
return await planka.test_connection()
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.test_connection()
if provider.type == "google_photos":
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
if provider.type in ("scheduler", "webhook"):
return {"ok": True, "message": "Virtual provider — always available"}
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
"""Test provider connection. Raise HTTPException on failure.
Returns the test_result dict on success (caller may inspect extra fields
like ``external_domain``).
"""
try:
test_result = await _test_provider_connection(provider)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
except OSError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
)
return test_result
@router.get("") @router.get("")
async def list_providers( async def list_providers(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
@@ -128,96 +212,15 @@ async def create_provider(
"""Add a new service provider (validates connection for known types).""" """Add a new service provider (validates connection for known types)."""
_validate_provider_config(body.type, body.config) _validate_provider_config(body.type, body.config)
# Validate connection for known provider types # Build a temporary ServiceProvider for connection testing
try: temp_provider = ServiceProvider(
if body.type == "immich": id=0, user_id=0, type=body.type, name=body.name, config=body.config,
from notify_bridge_core.providers.immich import ImmichServiceProvider )
config = body.config test_result = await _validate_provider_connection(temp_provider)
async with aiohttp.ClientSession() as http_session:
immich = ImmichServiceProvider(
http_session, config.get("url", ""), config.get("api_key", ""),
config.get("external_domain"), body.name,
)
test_result = await immich.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {body.type} provider"),
)
# Store external_domain from server config if available
if test_result.get("external_domain"):
config["external_domain"] = test_result["external_domain"]
elif body.type == "gitea": # Store external_domain from Immich server config if available
config = body.config if test_result.get("external_domain"):
# api_token is optional (webhook_secret is required, but token only for repo listing) body.config["external_domain"] = test_result["external_domain"]
if config.get("api_token"):
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.gitea import GiteaServiceProvider
gitea = GiteaServiceProvider(
http_session, config.get("url", ""), config.get("api_token", ""), body.name,
)
test_result = await gitea.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Gitea"),
)
elif body.type == "planka":
config = body.config
if config.get("api_key"):
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.planka import PlankaServiceProvider
planka = PlankaServiceProvider(
http_session, config.get("url", ""), config.get("api_key", ""), body.name,
)
test_result = await planka.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Planka"),
)
elif body.type == "nut":
nut = make_nut_provider(ServiceProvider(
id=0, user_id=0, type="nut", name=body.name, config=body.config,
))
test_result = await nut.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to NUT server"),
)
elif body.type == "google_photos":
config = body.config
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
gp = GooglePhotosServiceProvider(
http_session, config.get("client_id", ""), config.get("client_secret", ""),
config.get("refresh_token", ""), body.name,
)
test_result = await gp.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Google Photos"),
)
except HTTPException:
raise
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
except OSError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
# Scheduler: no validation needed (virtual provider)
provider = ServiceProvider( provider = ServiceProvider(
user_id=user.id, user_id=user.id,
@@ -307,78 +310,10 @@ async def update_provider(
provider.config = body.config provider.config = body.config
# Re-validate connection when config changes for known provider types # Re-validate connection when config changes for known provider types
if config_changed and provider.type == "immich": if config_changed:
try: test_result = await _validate_provider_connection(provider)
async with aiohttp.ClientSession() as http_session: if test_result.get("external_domain"):
immich = make_immich_provider(http_session, provider) provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
test_result = await immich.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
)
if test_result.get("external_domain"):
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "gitea":
if provider.config.get("api_token"):
try:
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
test_result = await gitea.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Gitea"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "planka":
if provider.config.get("api_key"):
try:
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
test_result = await planka.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Planka"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "nut":
nut = make_nut_provider(provider)
test_result = await nut.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to NUT server"),
)
elif config_changed and provider.type == "google_photos":
try:
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
test_result = await gp.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Google Photos"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
session.add(provider) session.add(provider)
await session.commit() await session.commit()
@@ -408,39 +343,7 @@ async def test_provider(
): ):
"""Check if a service provider is reachable.""" """Check if a service provider is reachable."""
provider = await _get_user_provider(session, provider_id, user.id) provider = await _get_user_provider(session, provider_id, user.id)
return await _test_provider_connection(provider)
if provider.type == "immich":
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
return await immich.test_connection()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
return await gitea.test_connection()
if provider.type == "planka":
if not provider.config.get("api_key"):
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
return await planka.test_connection()
if provider.type == "scheduler":
return {"ok": True, "message": "Virtual provider — always available"}
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.test_connection()
if provider.type == "google_photos":
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
@router.get("/{provider_id}/people") @router.get("/{provider_id}/people")
@@ -454,14 +357,14 @@ async def list_people(
if provider.type == "immich": if provider.type == "immich":
from notify_bridge_core.providers.immich.client import ImmichClient from notify_bridge_core.providers.immich.client import ImmichClient
async with aiohttp.ClientSession() as http_session: http_session = await get_http_session()
client = ImmichClient( client = ImmichClient(
http_session, http_session,
provider.config.get("url", ""), provider.config.get("url", ""),
provider.config.get("api_key", ""), provider.config.get("api_key", ""),
) )
people = await client.get_people() people = await client.get_people()
return [{"id": pid, "name": name} for pid, name in people.items()] return [{"id": pid, "name": name} for pid, name in people.items()]
return [] return []
@@ -475,35 +378,7 @@ async def list_collections(
"""Fetch collections from a service provider.""" """Fetch collections from a service provider."""
provider = await _get_user_provider(session, provider_id, user.id) provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich": return await list_provider_collections(provider)
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
return await immich.list_collections()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return []
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
return await gitea.list_collections()
if provider.type == "planka":
if not provider.config.get("api_key"):
return []
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
return await planka.list_collections()
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.list_collections()
if provider.type == "google_photos":
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
return await gp.list_collections()
return []
@router.get("/{provider_id}/albums/{album_id}/shared-links") @router.get("/{provider_id}/albums/{album_id}/shared-links")
@@ -517,19 +392,19 @@ async def get_album_shared_links(
provider = await _get_user_provider(session, provider_id, user.id) provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich": if provider.type == "immich":
async with aiohttp.ClientSession() as http_session: http_session = await get_http_session()
immich = make_immich_provider(http_session, provider) immich = make_immich_provider(http_session, provider)
links = await immich.client.get_shared_links(album_id) links = await immich.client.get_shared_links(album_id)
return [ return [
{ {
"id": link.id, "id": link.id,
"key": link.key, "key": link.key,
"has_password": link.has_password, "has_password": link.has_password,
"is_expired": link.is_expired, "is_expired": link.is_expired,
"is_accessible": link.is_accessible, "is_accessible": link.is_accessible,
} }
for link in links for link in links
] ]
return [] return []
@@ -545,15 +420,13 @@ async def create_album_shared_link(
provider = await _get_user_provider(session, provider_id, user.id) provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich": if provider.type == "immich":
async with aiohttp.ClientSession() as http_session: http_session = await get_http_session()
immich = make_immich_provider(http_session, provider) immich = make_immich_provider(http_session, provider)
success = await immich.client.create_shared_link(album_id) success = await immich.client.create_shared_link(album_id)
if success: if success:
return {"success": True} return {"success": True}
from fastapi import HTTPException raise HTTPException(status_code=400, detail="Failed to create shared link")
raise HTTPException(status_code=400, detail="Failed to create shared link")
from fastapi import HTTPException
raise HTTPException(status_code=400, detail="Provider type does not support shared links") raise HTTPException(status_code=400, detail="Provider type does not support shared links")
@@ -580,7 +453,7 @@ async def _get_user_provider(
session: AsyncSession, provider_id: int, user_id: int session: AsyncSession, provider_id: int, user_id: int
) -> ServiceProvider: ) -> ServiceProvider:
"""Get a provider owned by the user, or raise 404.""" """Get a provider owned by the user, or raise 404."""
provider = await session.get(ServiceProvider, provider_id) return await get_owned_entity(
if not provider or provider.user_id != user_id: session, ServiceProvider, provider_id, user_id,
raise HTTPException(status_code=404, detail="Provider not found") not_found_msg="Provider not found",
return provider )
@@ -0,0 +1,90 @@
"""Shared slot load/save and Jinja2 preview helpers for template config APIs."""
from typing import TypeVar
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
from jinja2.sandbox import SandboxedEnvironment
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
S = TypeVar("S", bound=SQLModel)
async def load_slots(
session: AsyncSession,
slot_model: type[S],
config_id: int,
) -> dict[str, dict[str, str]]:
"""Load all template slots for a config as {slot_name: {locale: template}}.
Works for both TemplateSlot and CommandTemplateSlot they share the same
column names (config_id, slot_name, locale, template).
"""
result = await session.exec(
select(slot_model).where(slot_model.config_id == config_id) # type: ignore[attr-defined]
)
slots: dict[str, dict[str, str]] = {}
for s in result.all():
slots.setdefault(s.slot_name, {})[s.locale] = s.template # type: ignore[attr-defined]
return slots
async def save_slots(
session: AsyncSession,
slot_model: type[S],
config_id: int,
slots: dict[str, dict[str, str]],
) -> None:
"""Create or update template slots for a config (locale-aware).
Works for both TemplateSlot and CommandTemplateSlot.
"""
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(slot_model).where(
slot_model.config_id == config_id, # type: ignore[attr-defined]
slot_model.slot_name == slot_name, # type: ignore[attr-defined]
slot_model.locale == locale, # type: ignore[attr-defined]
)
)
existing = result.first()
if existing:
existing.template = template_text # type: ignore[attr-defined]
session.add(existing)
else:
session.add(slot_model(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
def render_template_preview(template: str, context: dict) -> dict:
"""Two-pass Jinja2 render: syntax check, then strict render.
Returns a dict with either ``{"rendered": str}`` on success, or
``{"rendered": None, "error": str, ...}`` on failure.
"""
# Pass 1: syntax check (default Undefined — catches parse errors only)
try:
env = SandboxedEnvironment(autoescape=False)
env.from_string(template)
except TemplateSyntaxError as e:
return {
"rendered": None,
"error": e.message,
"error_line": e.lineno,
}
# Pass 2: render with StrictUndefined to catch unknown variables
try:
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(template)
rendered = tmpl.render(**context)
return {"rendered": rendered}
except UndefinedError as e:
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
@@ -112,8 +112,16 @@ async def get_nav_counts(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ):
"""Return entity counts for sidebar navigation badges.""" """Return entity counts for sidebar navigation badges.
counts = {}
Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
for concurrent use within a single session (no asyncio.gather). We
minimise round-trips by combining user + system counts and per-type
target counts into single aggregate queries where possible.
"""
counts: dict[str, int] = {}
# --- 1) User-owned entity counts (one query per model) ---
for model, key in [ for model, key in [
(ServiceProvider, "providers"), (ServiceProvider, "providers"),
(NotificationTracker, "notification_trackers"), (NotificationTracker, "notification_trackers"),
@@ -132,7 +140,7 @@ async def get_nav_counts(
)).one() )).one()
counts[key] = count counts[key] = count
# System-owned entities (user_id=0) count as well # --- 2) Add system-owned counts (user_id=0) for shared entities ---
for model, key in [ for model, key in [
(TemplateConfig, "template_configs"), (TemplateConfig, "template_configs"),
(CommandTemplateConfig, "command_template_configs"), (CommandTemplateConfig, "command_template_configs"),
@@ -144,15 +152,22 @@ async def get_nav_counts(
)).one() )).one()
counts[key] += system_count counts[key] += system_count
# Per-type target counts for nav badges # --- 3) Per-type target counts in a single query using conditional aggregation ---
for target_type in ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"): target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
type_count = (await session.exec( type_counts_result = (await session.exec(
select(func.count()).select_from(NotificationTarget).where( select(
NotificationTarget.user_id == user.id, NotificationTarget.type,
NotificationTarget.type == target_type, func.count(),
) )
)).one() .where(
counts[f"targets_{target_type}"] = type_count NotificationTarget.user_id == user.id,
NotificationTarget.type.in_(target_types),
)
.group_by(NotificationTarget.type)
)).all()
type_counts_map = dict(type_counts_result)
for target_type in target_types:
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
return counts return counts
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import NotificationTarget, TargetReceiver, User from ..database.models import NotificationTarget, TargetReceiver, User
from ..services.notifier import send_to_receiver from ..services.notifier import send_to_receiver
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -170,7 +171,7 @@ def _response(r: TargetReceiver) -> dict:
async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget: async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget:
target = await session.get(NotificationTarget, target_id) return await get_owned_entity(
if not target or target.user_id != user_id: session, NotificationTarget, target_id, user_id,
raise HTTPException(status_code=404, detail="Target not found") not_found_msg="Target not found",
return target )
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User
from ..services.notifier import send_test_notification from ..services.notifier import send_test_notification
from .helpers import get_owned_entity
from .target_receivers import _receiver_key from .target_receivers import _receiver_key
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -306,8 +307,15 @@ async def _validate_broadcast_children(
return return
if exclude_target_id and exclude_target_id in child_ids: if exclude_target_id and exclude_target_id in child_ids:
raise HTTPException(status_code=400, detail="A broadcast target cannot include itself") raise HTTPException(status_code=400, detail="A broadcast target cannot include itself")
# Batch-load all children in a single IN query instead of N+1 individual fetches
children = (await session.exec(
select(NotificationTarget).where(NotificationTarget.id.in_(child_ids))
)).all()
children_by_id = {c.id: c for c in children}
for child_id in child_ids: for child_id in child_ids:
child = await session.get(NotificationTarget, child_id) child = children_by_id.get(child_id)
if not child or child.user_id != user_id: if not child or child.user_id != user_id:
raise HTTPException(status_code=400, detail=f"Child target {child_id} not found") raise HTTPException(status_code=400, detail=f"Child target {child_id} not found")
if child.type == "broadcast": if child.type == "broadcast":
@@ -378,7 +386,7 @@ def _safe_config(target: NotificationTarget) -> dict:
async def _get_user_target( async def _get_user_target(
session: AsyncSession, target_id: int, user_id: int session: AsyncSession, target_id: int, user_id: int
) -> NotificationTarget: ) -> NotificationTarget:
target = await session.get(NotificationTarget, target_id) return await get_owned_entity(
if not target or target.user_id != user_id: session, NotificationTarget, target_id, user_id,
raise HTTPException(status_code=404, detail="Target not found") not_found_msg="Target not found",
return target )
@@ -7,8 +7,6 @@ from pydantic import BaseModel
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
import aiohttp
from notify_bridge_core.notifications.telegram.client import TelegramClient from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
@@ -19,6 +17,7 @@ from ..database.models import AppSetting, NotificationTarget, TargetReceiver, Te
from ..services.notifier import _get_test_message from ..services.notifier import _get_test_message
from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling
from .app_settings import get_setting from .app_settings import get_setting
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -290,10 +289,11 @@ async def test_chat(
): ):
"""Send a test message to a chat via the bot.""" """Send a test message to a chat via the bot."""
bot = await _get_user_bot(session, bot_id, user.id) bot = await _get_user_bot(session, bot_id, user.id)
from ..services.http_session import get_http_session
message = _get_test_message(locale, "telegram") message = _get_test_message(locale, "telegram")
async with aiohttp.ClientSession() as http: http = await get_http_session()
client = TelegramClient(http, bot.token) client = TelegramClient(http, bot.token)
return await client.send_message(chat_id, message) return await client.send_message(chat_id, message)
class ChatUpdate(BaseModel): class ChatUpdate(BaseModel):
@@ -344,41 +344,44 @@ async def delete_chat(
async def _get_webhook_info(token: str) -> dict | None: async def _get_webhook_info(token: str) -> dict | None:
"""Call Telegram getWebhookInfo via TelegramClient.""" """Call Telegram getWebhookInfo via TelegramClient."""
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, token) http = await get_http_session()
result = await client.get_webhook_info() client = TelegramClient(http, token)
return result.get("result") if result.get("success") else None result = await client.get_webhook_info()
return result.get("result") if result.get("success") else None
async def _get_me(token: str) -> dict | None: async def _get_me(token: str) -> dict | None:
"""Call Telegram getMe via TelegramClient.""" """Call Telegram getMe via TelegramClient."""
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, token) http = await get_http_session()
result = await client.get_me() client = TelegramClient(http, token)
return result.get("result") if result.get("success") else None result = await client.get_me()
return result.get("result") if result.get("success") else None
async def _fetch_chats_from_telegram(token: str) -> list[dict]: async def _fetch_chats_from_telegram(token: str) -> list[dict]:
"""Fetch chats from Telegram getUpdates via TelegramClient.""" """Fetch chats from Telegram getUpdates via TelegramClient."""
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, token) http = await get_http_session()
result = await client.get_updates(limit=100) client = TelegramClient(http, token)
if not result.get("success"): result = await client.get_updates(limit=100)
return [] if not result.get("success"):
return []
seen: dict[int, dict] = {} seen: dict[int, dict] = {}
for update in result.get("result", []): for update in result.get("result", []):
msg = update.get("message", {}) msg = update.get("message", {})
chat = msg.get("chat", {}) chat = msg.get("chat", {})
chat_id = chat.get("id") chat_id = chat.get("id")
if chat_id and chat_id not in seen: if chat_id and chat_id not in seen:
seen[chat_id] = { seen[chat_id] = {
"id": chat_id, "id": chat_id,
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()), "title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
"type": chat.get("type", "private"), "type": chat.get("type", "private"),
"username": chat.get("username", ""), "username": chat.get("username", ""),
} }
return list(seen.values()) return list(seen.values())
def _chat_response(c: TelegramChat) -> dict: def _chat_response(c: TelegramChat) -> dict:
@@ -410,10 +413,9 @@ def _bot_response(b: TelegramBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot:
bot = await session.get(TelegramBot, bot_id) return await get_owned_entity(
if not bot or bot.user_id != user_id: session, TelegramBot, bot_id, user_id, not_found_msg="Bot not found",
raise HTTPException(status_code=404, detail="Bot not found") )
return bot
@@ -13,12 +13,12 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user from ..auth.dependencies import get_current_user
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import TemplateConfig, TemplateSlot, User from ..database.models import TemplateConfig, TemplateSlot, User
from ..services.sample_context import _SAMPLE_CONTEXT from ..services.sample_context import _SAMPLE_CONTEXT
from .slot_helpers import load_slots, render_template_preview, save_slots
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -49,40 +49,13 @@ class TemplateConfigUpdate(BaseModel):
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]: async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
"""Load all template slots for a config as {slot_name: {locale: template}}.""" return await load_slots(session, TemplateSlot, config_id)
result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == config_id)
)
slots: dict[str, dict[str, str]] = {}
for s in result.all():
slots.setdefault(s.slot_name, {})[s.locale] = s.template
return slots
async def _save_slots( async def _save_slots(
session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]] session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]
) -> None: ) -> None:
"""Create or update template slots for a config (locale-aware).""" await save_slots(session, TemplateSlot, config_id, slots)
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config_id,
TemplateSlot.slot_name == slot_name,
TemplateSlot.locale == locale,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]: async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]:
@@ -155,7 +128,7 @@ async def get_template_variables(
"photo_count": "Total photo count in album", "photo_count": "Total photo count in album",
"video_count": "Total video count in album", "video_count": "Total video count in album",
"owner": "Album owner name", "owner": "Album owner name",
"target_type": "Target type: 'telegram' or 'webhook'", "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
"has_videos": "Whether added assets contain videos (boolean)", "has_videos": "Whether added assets contain videos (boolean)",
"has_photos": "Whether added assets contain photos (boolean)", "has_photos": "Whether added assets contain photos (boolean)",
"has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)", "has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)",
@@ -206,7 +179,7 @@ async def get_template_variables(
} }
scheduled_vars = { scheduled_vars = {
"date": "Current date string", "date": "Current date string",
"target_type": "Target type: 'telegram' or 'webhook'", "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
} }
return { return {
@@ -284,7 +257,7 @@ def _webhook_variables() -> dict:
"source_ip": "IP address of the webhook sender", "source_ip": "IP address of the webhook sender",
"raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)", "raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)",
"timestamp": "When the webhook was received", "timestamp": "When the webhook was received",
"target_type": "Target type: 'telegram' or 'webhook'", "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
}, },
}, },
} }
@@ -529,7 +502,7 @@ async def preview_date_format(
class PreviewRequest(BaseModel): class PreviewRequest(BaseModel):
template: str template: str
target_type: str = "telegram" # "telegram" or "webhook" target_type: str = "telegram" # telegram, webhook, email, discord, slack, ntfy, matrix
date_format: str = "%d.%m.%Y, %H:%M UTC" date_format: str = "%d.%m.%Y, %H:%M UTC"
date_only_format: str = "%d.%m.%Y" date_only_format: str = "%d.%m.%Y"
@@ -545,33 +518,12 @@ async def preview_raw(
1. Parse with default Undefined (catches syntax errors) 1. Parse with default Undefined (catches syntax errors)
2. Render with StrictUndefined (catches unknown variables like {{ asset.a }}) 2. Render with StrictUndefined (catches unknown variables like {{ asset.a }})
""" """
# Pass 1: syntax check from datetime import datetime
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
"date_format": body.date_format, "date_only_format": body.date_only_format}
# Format common_date using the provided date_only_format
try: try:
env = SandboxedEnvironment(autoescape=False) ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
env.from_string(body.template) except (ValueError, TypeError):
except TemplateSyntaxError as e: ctx["common_date"] = "19.03.2026"
return { return render_template_preview(body.template, ctx)
"rendered": None,
"error": e.message,
"error_line": e.lineno,
}
# Pass 2: render with strict undefined to catch unknown variables
try:
from datetime import datetime
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
"date_format": body.date_format, "date_only_format": body.date_only_format}
# Format common_date using the provided date_only_format
try:
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
except (ValueError, TypeError):
ctx["common_date"] = "19.03.2026"
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(body.template)
rendered = tmpl.render(**ctx)
return {"rendered": rendered}
except UndefinedError as e:
# Still a valid template syntactically, but references unknown variable
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
@@ -3,9 +3,18 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any from typing import Any
from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
@dataclass(frozen=True)
class CommandResponse:
"""A single response from one tracker's command execution."""
text: str | None = None
media: list[dict[str, Any]] = field(default_factory=list)
class ProviderCommandHandler(ABC): class ProviderCommandHandler(ABC):
@@ -14,6 +23,8 @@ class ProviderCommandHandler(ABC):
Each provider (Immich, Gitea, etc.) implements this interface to handle Each provider (Immich, Gitea, etc.) implements this interface to handle
its own set of commands. The dispatch layer routes commands to the its own set of commands. The dispatch layer routes commands to the
correct handler based on the provider type. correct handler based on the provider type.
Each handler call receives a single (tracker, config, provider) context.
""" """
provider_type: str provider_type: str
@@ -35,26 +46,28 @@ class ProviderCommandHandler(ABC):
count: int, count: int,
locale: str, locale: str,
response_mode: str, response_mode: str,
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], tracker: CommandTracker,
) -> str | list[dict[str, Any]] | None: config: CommandConfig,
"""Handle a provider-specific command. ) -> CommandResponse | None:
"""Handle a provider-specific command for a single tracker.
Args: Args:
cmd: The command name (without '/'). cmd: The command name (without '/').
args: Arguments after the command. args: Arguments after the command.
count: Number of results to return. count: Number of results to return.
locale: User's locale ('en', 'ru'). locale: User's locale ('en', 'ru').
response_mode: 'media' or 'text'. response_mode: 'media' or 'text' (from this tracker's config).
providers_map: Provider instances keyed by ID. provider: The service provider instance for this tracker.
cmd_templates: Template slots {slot_name: {locale: template}}. cmd_templates: Template slots for this tracker's command template config.
bot: The Telegram bot instance. bot: The Telegram bot instance.
ctx_tuples: Command context tuples for this provider type. tracker: The command tracker being dispatched.
config: The command config for this tracker.
Returns: Returns:
Text response, media list, or None if unhandled. A CommandResponse, or None if unhandled.
""" """
def get_rate_categories(self) -> dict[str, str]: def get_rate_categories(self) -> dict[str, str]:
@@ -0,0 +1,67 @@
"""Shared command handler utilities to reduce boilerplate across providers."""
from __future__ import annotations
import logging
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import EventLog, NotificationTracker, ServiceProvider
_LOGGER = logging.getLogger(__name__)
async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]:
"""Get notification trackers for a single provider."""
from .handler import _get_notification_trackers_for_providers
return await _get_notification_trackers_for_providers({provider_id})
async def get_last_event_str(tracker_ids: list[int]) -> str:
"""Get formatted timestamp of most recent event for given trackers.
Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist.
"""
if not tracker_ids:
return "-"
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc())
.limit(1)
)
last_event = result.first()
return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
def get_tracked_collection_ids(
provider: ServiceProvider,
trackers: list[NotificationTracker],
*,
max_items: int = 20,
) -> list[str]:
"""Get deduplicated collection IDs from trackers for a provider.
Iterates all trackers belonging to *provider*, collects IDs from both
``collection_ids`` and ``filters.collections``, deduplicates while
preserving order, and caps at *max_items*.
"""
seen: set[str] = set()
result: list[str] = []
for tracker in trackers:
if tracker.provider_id != provider.id:
continue
for cid in tracker.collection_ids or []:
if cid not in seen:
seen.add(cid)
result.append(cid)
for cid in (tracker.filters or {}).get("collections", []):
if cid not in seen:
seen.add(cid)
result.append(cid)
return result[:max_items]
@@ -2,27 +2,55 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import ( from ..database.models import (
CommandConfig, CommandTracker, EventLog, CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
NotificationTracker, ServiceProvider, TelegramBot,
) )
from ..services import make_gitea_provider from ..services import make_gitea_provider
from .base import ProviderCommandHandler from ..services.http_session import get_http_session
from .handler import _render_cmd_template, _get_notification_trackers_for_providers from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"} _GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"}
def _get_tracked_repos(
provider: ServiceProvider,
trackers: list,
) -> list[tuple[ServiceProvider, str, str]]:
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
if not provider.config.get("api_token"):
return []
collection_ids = get_tracked_collection_ids(provider, trackers)
repos: list[tuple[ServiceProvider, str, str]] = []
for full_name in collection_ids:
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
return repos
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class GiteaCommandHandler(ProviderCommandHandler): class GiteaCommandHandler(ProviderCommandHandler):
"""Handles Gitea-specific bot commands.""" """Handles Gitea-specific bot commands."""
@@ -44,91 +72,35 @@ class GiteaCommandHandler(ProviderCommandHandler):
count: int, count: int,
locale: str, locale: str,
response_mode: str, response_mode: str,
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], tracker: CommandTracker,
) -> str | list[dict[str, Any]] | None: config: CommandConfig,
if cmd == "status": ) -> CommandResponse | None:
ctx = await _cmd_status(providers_map) fn = _TEXT_COMMANDS.get(cmd)
return _render_cmd_template(cmd_templates, "status", locale, ctx) if fn is None:
if cmd == "repos": return None
ctx = await _cmd_repos(providers_map) ctx = await fn(provider, count)
return _render_cmd_template(cmd_templates, "repos", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
if cmd == "issues":
ctx = await _cmd_issues(providers_map, count)
return _render_cmd_template(cmd_templates, "issues", locale, ctx)
if cmd == "prs":
ctx = await _cmd_prs(providers_map, count)
return _render_cmd_template(cmd_templates, "prs", locale, ctx)
if cmd == "commits":
ctx = await _cmd_commits(providers_map, count)
return _render_cmd_template(cmd_templates, "commits", locale, ctx)
return None
def _get_tracked_repos( @_text_cmd
providers_map: dict[int, ServiceProvider], async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers: list[NotificationTracker], trackers = await get_trackers_for_provider(provider.id)
) -> list[tuple[ServiceProvider, str, str]]: tracked_repos = _get_tracked_repos(provider, trackers)
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
repos: list[tuple[ServiceProvider, str, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.collection_ids or []):
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
# Also check filters.collections
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.filters or {}).get("collections", []):
parts = full_name.split("/", 1)
if len(parts) == 2:
entry = (provider, parts[0], parts[1])
if entry not in repos:
repos.append(entry)
return repos[:20] # Cap to prevent API hammering
# Get server version
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
# Get server version from first Gitea provider with token
server_version = "unknown" server_version = "unknown"
async with aiohttp.ClientSession() as http: if provider.config.get("api_token"):
for provider in providers_map.values(): http = await get_http_session()
if provider.type == "gitea" and provider.config.get("api_token"): gitea = make_gitea_provider(http, provider)
gitea = make_gitea_provider(http, provider) version = await gitea.client.get_server_version()
version = await gitea.client.get_server_version() if version:
if version: server_version = version
server_version = version
break
# Last event tracker_ids = [t.id for t in trackers]
engine = get_engine() last_str = await get_last_event_str(tracker_ids)
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
return { return {
"repos_count": len(tracked_repos), "repos_count": len(tracked_repos),
@@ -137,116 +109,139 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
} }
async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
provider_ids = set(providers_map.keys()) async def _cmd_repos(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await _get_notification_trackers_for_providers(provider_ids) trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(providers_map, trackers) tracked_repos = _get_tracked_repos(provider, trackers)
repos_data: list[dict[str, Any]] = [] repos_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider) async def _fetch_repo(prov: ServiceProvider, owner: str, repo: str) -> dict[str, Any]:
try: gitea = make_gitea_provider(http, prov)
all_repos = await gitea.client.get_repos(limit=50) # Use direct get_repo endpoint instead of listing all repos
for r in all_repos: r = await gitea.client.get_repo(owner, repo)
if r.get("full_name") == f"{owner}/{repo}": if r:
repos_data.append({ return {
"full_name": r.get("full_name", ""), "full_name": r.get("full_name", ""),
"description": r.get("description", ""), "description": r.get("description", ""),
"stars": r.get("stars_count", 0), "stars": r.get("stars_count", 0),
"url": r.get("html_url", ""), "url": r.get("html_url", ""),
}) }
break return {
else: "full_name": f"{owner}/{repo}",
repos_data.append({ "description": "",
"full_name": f"{owner}/{repo}", "stars": 0,
"description": "", "url": "",
"stars": 0, }
"url": "",
}) tasks = [_fetch_repo(prov, owner, repo) for prov, owner, repo in tracked_repos]
except Exception: results = await asyncio.gather(*tasks, return_exceptions=True)
repos_data.append({ for (prov, owner, repo), result in zip(tracked_repos, results):
"full_name": f"{owner}/{repo}", if isinstance(result, Exception):
"description": "?", _LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, result)
"stars": 0, repos_data.append({
"url": "", "full_name": f"{owner}/{repo}",
}) "description": "?",
"stars": 0,
"url": "",
})
else:
repos_data.append(result)
return {"repos": repos_data} return {"repos": repos_data}
async def _cmd_issues( @_text_cmd
providers_map: dict[int, ServiceProvider], count: int, async def _cmd_issues(provider: ServiceProvider, count: int) -> dict[str, Any]:
) -> dict[str, Any]: trackers = await get_trackers_for_provider(provider.id)
provider_ids = set(providers_map.keys()) tracked_repos = _get_tracked_repos(provider, trackers)
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
all_issues: list[dict[str, Any]] = [] all_issues: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider) async def _fetch_issues(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
issues = await gitea.client.get_repo_issues(owner, repo, limit=count) gitea = make_gitea_provider(http, prov)
for issue in issues: return await gitea.client.get_repo_issues(owner, repo, limit=count)
all_issues.append({
"repo": f"{owner}/{repo}", tasks = [_fetch_issues(prov, owner, repo) for prov, owner, repo in tracked_repos]
"number": issue.get("number", 0), results = await asyncio.gather(*tasks, return_exceptions=True)
"title": issue.get("title", ""), for (prov, owner, repo), result in zip(tracked_repos, results):
"url": issue.get("html_url", ""), if isinstance(result, Exception):
"user": issue.get("user", {}).get("login", ""), _LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, result)
"state": issue.get("state", ""), continue
}) for issue in result:
all_issues.append({
"repo": f"{owner}/{repo}",
"number": issue.get("number", 0),
"title": issue.get("title", ""),
"url": issue.get("html_url", ""),
"user": issue.get("user", {}).get("login", ""),
"state": issue.get("state", ""),
})
all_issues.sort(key=lambda i: i.get("number", 0), reverse=True) all_issues.sort(key=lambda i: i.get("number", 0), reverse=True)
return {"issues": all_issues[:count]} return {"issues": all_issues[:count]}
async def _cmd_prs( @_text_cmd
providers_map: dict[int, ServiceProvider], count: int, async def _cmd_prs(provider: ServiceProvider, count: int) -> dict[str, Any]:
) -> dict[str, Any]: trackers = await get_trackers_for_provider(provider.id)
provider_ids = set(providers_map.keys()) tracked_repos = _get_tracked_repos(provider, trackers)
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
all_prs: list[dict[str, Any]] = [] all_prs: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider) async def _fetch_prs(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
prs = await gitea.client.get_repo_pulls(owner, repo, limit=count) gitea = make_gitea_provider(http, prov)
for pr in prs: return await gitea.client.get_repo_pulls(owner, repo, limit=count)
all_prs.append({
"repo": f"{owner}/{repo}", tasks = [_fetch_prs(prov, owner, repo) for prov, owner, repo in tracked_repos]
"number": pr.get("number", 0), results = await asyncio.gather(*tasks, return_exceptions=True)
"title": pr.get("title", ""), for (prov, owner, repo), result in zip(tracked_repos, results):
"url": pr.get("html_url", ""), if isinstance(result, Exception):
"user": pr.get("user", {}).get("login", ""), _LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, result)
"state": pr.get("state", ""), continue
}) for pr in result:
all_prs.append({
"repo": f"{owner}/{repo}",
"number": pr.get("number", 0),
"title": pr.get("title", ""),
"url": pr.get("html_url", ""),
"user": pr.get("user", {}).get("login", ""),
"state": pr.get("state", ""),
})
all_prs.sort(key=lambda p: p.get("number", 0), reverse=True) all_prs.sort(key=lambda p: p.get("number", 0), reverse=True)
return {"prs": all_prs[:count]} return {"prs": all_prs[:count]}
async def _cmd_commits( @_text_cmd
providers_map: dict[int, ServiceProvider], count: int, async def _cmd_commits(provider: ServiceProvider, count: int) -> dict[str, Any]:
) -> dict[str, Any]: trackers = await get_trackers_for_provider(provider.id)
provider_ids = set(providers_map.keys()) tracked_repos = _get_tracked_repos(provider, trackers)
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
all_commits: list[dict[str, Any]] = [] all_commits: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider) async def _fetch_commits(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
commits = await gitea.client.get_repo_commits(owner, repo, limit=count) gitea = make_gitea_provider(http, prov)
for c in commits: return await gitea.client.get_repo_commits(owner, repo, limit=count)
commit_data = c.get("commit", {})
all_commits.append({ tasks = [_fetch_commits(prov, owner, repo) for prov, owner, repo in tracked_repos]
"repo": f"{owner}/{repo}", results = await asyncio.gather(*tasks, return_exceptions=True)
"short_id": c.get("sha", "")[:7], for (prov, owner, repo), result in zip(tracked_repos, results):
"message": commit_data.get("message", "").split("\n")[0][:80], if isinstance(result, Exception):
"author": commit_data.get("author", {}).get("name", ""), _LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, result)
"url": c.get("html_url", ""), continue
}) for c in result:
commit_data = c.get("commit", {})
all_commits.append({
"repo": f"{owner}/{repo}",
"short_id": c.get("sha", "")[:7],
"message": commit_data.get("message", "").split("\n")[0][:80],
"author": commit_data.get("author", {}).get("name", ""),
"url": c.get("html_url", ""),
})
return {"commits": all_commits[:count]} return {"commits": all_commits[:count]}
@@ -4,6 +4,7 @@ from __future__ import annotations
import logging import logging
import time import time
from functools import lru_cache
from typing import Any from typing import Any
import aiohttp import aiohttp
@@ -25,17 +26,21 @@ from ..database.models import (
ServiceProvider, ServiceProvider,
TelegramBot, TelegramBot,
) )
from .base import CommandResponse
from .parser import parse_command from .parser import parse_command
from .registry import get_rate_category from .registry import get_rate_category
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Singleton Jinja2 environment for template rendering (Phase 4d) # Singleton Jinja2 environment for template rendering (Phase 4d)
_JINJA_ENV = SandboxedEnvironment(autoescape=False) _JINJA_ENV = SandboxedEnvironment(autoescape=True)
# Rate limit state with automatic TTL expiry (Phase 4e) # Rate limit state with automatic TTL expiry (Phase 4e)
_rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600) _rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
# Maximum responses per command to avoid Telegram rate limits
_MAX_RESPONSES_PER_COMMAND = 5
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None: def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
"""Check rate limit. Returns seconds to wait, or None if OK.""" """Check rate limit. Returns seconds to wait, or None if OK."""
@@ -60,6 +65,12 @@ def _resolve_template(
return locale_map.get(locale) or locale_map.get("en") return locale_map.get(locale) or locale_map.get("en")
@lru_cache(maxsize=256)
def _compile_template(template_str: str):
"""Cache compiled Jinja2 templates to avoid re-parsing identical strings."""
return _JINJA_ENV.from_string(template_str)
def _render_cmd_template( def _render_cmd_template(
templates: dict[str, dict[str, str]], slot_name: str, locale: str, templates: dict[str, dict[str, str]], slot_name: str, locale: str,
context: dict[str, Any], context: dict[str, Any],
@@ -70,20 +81,28 @@ def _render_cmd_template(
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale) _LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
return f"[No template: {slot_name}]" return f"[No template: {slot_name}]"
try: try:
tmpl = _JINJA_ENV.from_string(template_str) tmpl = _compile_template(template_str)
return tmpl.render(**context) return tmpl.render(**context)
except Exception as e: except Exception as e:
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e) _LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
return f"[Template error: {slot_name}]" return f"[Template error: {slot_name}]"
# ---------------------------------------------------------------------------
# Context resolution
# ---------------------------------------------------------------------------
async def _resolve_command_context( async def _resolve_command_context(
bot: TelegramBot, bot: TelegramBot,
) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, dict[str, str]]]: ) -> tuple[
list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
dict[int, dict[str, dict[str, str]]],
]:
"""Resolve all enabled command trackers, configs, and providers for a bot. """Resolve all enabled command trackers, configs, and providers for a bot.
Returns (context_tuples, cmd_template_slots). Returns:
cmd_template_slots is {slot_name: {locale: template}}. (context_tuples, templates_by_config_id)
templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}.
""" """
engine = get_engine() engine = get_engine()
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
@@ -142,8 +161,8 @@ async def _resolve_command_context(
continue continue
tuples.append((tracker, config, provider)) tuples.append((tracker, config, provider))
# Load command template slots — merge from all configs # Load command template slots per config (not merged)
cmd_template_slots: dict[str, dict[str, str]] = {} templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {}
seen_config_ids: set[int] = set() seen_config_ids: set[int] = set()
for _, config, _ in tuples: for _, config, _ in tuples:
cfg_id = config.command_template_config_id cfg_id = config.command_template_config_id
@@ -154,98 +173,136 @@ async def _resolve_command_context(
CommandTemplateSlot.config_id == cfg_id CommandTemplateSlot.config_id == cfg_id
) )
) )
slots: dict[str, dict[str, str]] = {}
for s in slot_result.all(): for s in slot_result.all():
cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template slots.setdefault(s.slot_name, {})[s.locale] = s.template
templates_by_config_id[cfg_id] = slots
return tuples, cmd_template_slots return tuples, templates_by_config_id
def _merge_command_context( def _templates_for_config(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
config: CommandConfig,
) -> dict[str, dict[str, str]]:
"""Get template slots for a specific command config."""
cfg_id = config.command_template_config_id
if cfg_id and cfg_id in templates_by_config_id:
return templates_by_config_id[cfg_id]
return {}
def _merge_all_templates(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
) -> dict[str, dict[str, str]]:
"""Merge all template config slots into one dict (for universal commands)."""
merged: dict[str, dict[str, str]] = {}
for slots in templates_by_config_id.values():
for slot_name, locale_map in slots.items():
merged.setdefault(slot_name, {}).update(locale_map)
return merged
def _merge_enabled_commands(
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> tuple[list[str], str, int, dict[str, Any]]: ) -> tuple[list[str], dict[str, Any]]:
"""Merge enabled_commands from all configs and pick defaults from first config.""" """Merge enabled_commands (union) and rate_limits from all configs.
Rate limits use the most restrictive (minimum) cooldown per category.
"""
if not ctx: if not ctx:
return [], "media", 5, {} return [], {}
enabled: set[str] = set() enabled: set[str] = set()
merged_limits: dict[str, int] = {}
for _, config, _ in ctx: for _, config, _ in ctx:
enabled.update(config.enabled_commands or []) enabled.update(config.enabled_commands or [])
for category, cooldown in (config.rate_limits or {}).items():
if category not in merged_limits:
merged_limits[category] = cooldown
else:
merged_limits[category] = min(merged_limits[category], cooldown)
first_config = ctx[0][1] return sorted(enabled), merged_limits
response_mode = first_config.response_mode or "media"
default_count = first_config.default_count or 5
rate_limits = first_config.rate_limits or {}
return sorted(enabled), response_mode, default_count, rate_limits
# ---------------------------------------------------------------------------
# Main dispatcher
# ---------------------------------------------------------------------------
async def handle_command( async def handle_command(
bot: TelegramBot, bot: TelegramBot,
chat_id: str, chat_id: str,
text: str, text: str,
language_code: str = "", language_code: str = "",
) -> str | list[dict[str, Any]] | None: ) -> list[CommandResponse] | None:
"""Handle a bot command. Routes to provider-specific handlers. """Handle a bot command. Routes to provider-specific handlers.
Returns text response, media list, or None. Returns a list of CommandResponse objects (one per tracker), or None.
Universal commands (/start, /help) return a single-element list.
Provider-specific commands dispatch per-tracker with per-tracker config.
""" """
cmd, args, count_override = parse_command(text) cmd, args, count_override = parse_command(text)
if not cmd: if not cmd:
return None return None
ctx_tuples, cmd_templates = await _resolve_command_context(bot) ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples) enabled, rate_limits = _merge_enabled_commands(ctx_tuples)
locale = language_code[:2].lower() if language_code else "en" locale = language_code[:2].lower() if language_code else "en"
if locale not in ("en", "ru"): if locale not in ("en", "ru"):
locale = "en" locale = "en"
# Merged templates for universal commands
merged_templates = _merge_all_templates(templates_by_config_id)
if cmd == "start": if cmd == "start":
return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name}) text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name})
return [CommandResponse(text=text_resp)]
if cmd not in enabled and cmd != "start": if cmd not in enabled and cmd != "start":
return None return None
# Rate limit check # Rate limit check (once per command, shared across all trackers)
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits) wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
if wait is not None: if wait is not None:
return _render_cmd_template(cmd_templates, "rate_limited", locale, {"wait": wait}) text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
return [CommandResponse(text=text_resp)]
count = min(count_override or default_count, 20) # Universal commands — single merged response
# Build providers map from command context
providers_map: dict[int, ServiceProvider] = {}
for _, _, provider in ctx_tuples:
providers_map[provider.id] = provider
# Universal commands
if cmd == "help": if cmd == "help":
ctx = _cmd_help(enabled, locale, cmd_templates) ctx = _cmd_help(enabled, locale, merged_templates)
return _render_cmd_template(cmd_templates, "help", locale, ctx) text_resp = _render_cmd_template(merged_templates, "help", locale, ctx)
return [CommandResponse(text=text_resp)]
# Provider-specific dispatch # Provider-specific dispatch — per-tracker
from .dispatch import get_handler from .dispatch import get_handler
# Group ctx_tuples by provider type responses: list[CommandResponse] = []
by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {} for tracker, config, provider in ctx_tuples:
for t in ctx_tuples: if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
ptype = t[2].type _LOGGER.warning(
by_type.setdefault(ptype, []).append(t) "Truncated command responses at %d for bot %d cmd /%s",
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
# Find which handler claims this command
for ptype, ptuples in by_type.items():
handler = get_handler(ptype)
if handler and cmd in handler.get_provider_commands():
# Build provider map filtered to this provider type
pmap = {p.id: p for _, _, p in ptuples}
result = await handler.handle(
cmd, args, count, locale, response_mode,
pmap, cmd_templates, bot, ptuples,
) )
if result is not None: break
return result
return None handler = get_handler(provider.type)
if not handler or cmd not in handler.get_provider_commands():
continue
tracker_templates = _templates_for_config(templates_by_config_id, config)
count = min(count_override or config.default_count or 5, 20)
response_mode = config.response_mode or "media"
result = await handler.handle(
cmd, args, count, locale, response_mode,
provider, tracker_templates, bot, tracker, config,
)
if result is not None:
responses.append(result)
return responses if responses else None
def _cmd_help( def _cmd_help(
@@ -283,17 +340,13 @@ async def send_reply(
session: aiohttp.ClientSession | None = None, session: aiohttp.ClientSession | None = None,
) -> None: ) -> None:
"""Send a text reply via TelegramClient.""" """Send a text reply via TelegramClient."""
async def _send(http: aiohttp.ClientSession) -> None: if session is None:
client = TelegramClient(http, bot_token) from ..services.http_session import get_http_session
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id) session = await get_http_session()
if not result.get("success"): client = TelegramClient(session, bot_token)
_LOGGER.warning("Telegram reply failed: %s", result.get("error")) result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
if session is not None: _LOGGER.warning("Telegram reply failed: %s", result.get("error"))
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
async def send_media_group( async def send_media_group(
@@ -319,52 +372,50 @@ async def send_media_group(
captions = [item.get("caption", "") for item in media_items if item.get("caption")] captions = [item.get("caption", "") for item in media_items if item.get("caption")]
caption = "\n".join(captions) if captions else None caption = "\n".join(captions) if captions else None
async def _send(http: aiohttp.ClientSession) -> None: if session is None:
client = TelegramClient(http, bot_token) from ..services.http_session import get_http_session
result = await client.send_notification( session = await get_http_session()
chat_id, assets=assets, caption=caption, client = TelegramClient(session, bot_token)
reply_to_message_id=reply_to_message_id, result = await client.send_notification(
chat_action=None, chat_id, assets=assets, caption=caption,
) reply_to_message_id=reply_to_message_id,
if not result.get("success"): chat_action=None,
_LOGGER.warning("Telegram media group failed: %s", result.get("error")) )
if not result.get("success"):
if session is not None: _LOGGER.warning("Telegram media group failed: %s", result.get("error"))
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
async def register_commands_with_telegram(bot: TelegramBot) -> bool: async def register_commands_with_telegram(bot: TelegramBot) -> bool:
"""Register enabled commands with Telegram BotFather API via TelegramClient.""" """Register enabled commands with Telegram BotFather API via TelegramClient."""
ctx_tuples, templates = await _resolve_command_context(bot) ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, _, _, _ = _merge_command_context(ctx_tuples) enabled, _ = _merge_enabled_commands(ctx_tuples)
templates = _merge_all_templates(templates_by_config_id)
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, bot.token) http = await get_http_session()
success = False client = TelegramClient(http, bot.token)
success = False
# Register per-locale commands # Register per-locale commands
for locale in ("en", "ru"): for locale in ("en", "ru"):
commands = [] commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
# Register default (no language_code) with EN descriptions
en_commands = []
for cmd in enabled: for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
en_commands.append({"command": cmd, "description": desc}) commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands) result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"): if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
return success # Register default (no language_code) with EN descriptions
en_commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
return success
@@ -6,70 +6,48 @@ import asyncio
import logging import logging
from typing import Any from typing import Any
import aiohttp from ...database.models import ServiceProvider
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ...database.models import ServiceProvider, TelegramBot
from ...services import make_immich_provider from ...services import make_immich_provider
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template from ...services.http_session import get_http_session
from .common import _format_assets, build_asset_dict from ..command_utils import get_trackers_for_provider
from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def _cmd_albums( async def _cmd_albums(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str, provider: ServiceProvider, locale: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
provider_ids = set(providers_map.keys()) trackers = await get_trackers_for_provider(provider.id)
trackers = await _get_notification_trackers_for_providers(provider_ids)
if not trackers: if not trackers:
return {"albums": []} return {"albums": []}
albums_data: list[dict] = [] # Deduplicate album IDs while preserving order
async with aiohttp.ClientSession() as http: seen: set[str] = set()
for tracker in trackers: album_ids: list[str] = []
provider = providers_map.get(tracker.provider_id) for tracker in trackers:
if not provider or provider.type != "immich": for aid in tracker.collection_ids or []:
continue if aid not in seen:
immich = make_immich_provider(http, provider) seen.add(aid)
album_ids = tracker.collection_ids or [] album_ids.append(aid)
if not album_ids: if not album_ids:
continue return {"albums": []}
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/") ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
album_results = await asyncio.gather( http = await get_http_session()
*[immich.client.get_album(aid) for aid in album_ids], immich = make_immich_provider(http, provider)
return_exceptions=True, albums_data = await fetch_albums_with_links(immich.client, album_ids, ext_domain)
)
link_results = await asyncio.gather(
*[immich.client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
})
elif result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
return {"albums": albums_data} return {"albums": albums_data}
async def cmd_favorites( async def cmd_favorites(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], providers_map: dict[int, ServiceProvider],
all_album_ids: list[str], count: int, locale: str, all_album_ids: list[str], count: int, locale: str,
response_mode: str, client: Any, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /favorites command with concurrent album fetching.""" """Handle /favorites command with concurrent album fetching."""
album_ids = all_album_ids[:10] album_ids = all_album_ids[:10]
if not album_ids: if not album_ids:
@@ -104,28 +82,6 @@ async def cmd_summary(
if not all_album_ids: if not all_album_ids:
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []}) return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []})
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in all_album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/") ext = external_domain.rstrip("/")
albums_data = await fetch_albums_with_links(client, all_album_ids, ext, include_failed=False)
albums_data: list[dict] = []
for album_id, result, links in zip(all_album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data}) return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
@@ -2,10 +2,12 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from typing import Any from typing import Any
from ...services import make_immich_provider from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _render_cmd_template from ..handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -17,6 +19,53 @@ _IMMICH_COMMANDS = {
} }
async def fetch_albums_with_links(
client: Any,
album_ids: list[str],
ext_domain: str,
*,
include_failed: bool = True,
) -> list[dict[str, Any]]:
"""Fetch albums and their shared links concurrently.
Returns a list of album data dicts with keys: name, asset_count, id,
public_url, and ``_album`` (the raw album object for callers that need
asset-level access).
When *include_failed* is True, albums that fail to fetch are included
with placeholder data (``"?"`` for counts). When False, they are
silently skipped.
"""
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
albums_data: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
if include_failed:
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?",
"id": album_id, "public_url": "", "_album": None,
})
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url, "_album": result,
})
return albums_data
def build_asset_dict( def build_asset_dict(
asset: Any, asset: Any,
*, *,
@@ -56,8 +105,14 @@ def _format_assets(
assets: list[dict[str, Any]], cmd: str, query: str, assets: list[dict[str, Any]], cmd: str, query: str,
locale: str, response_mode: str, client: Any, locale: str, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Format asset results as text or media payload.""" """Format asset results as text or a text-plus-media payload.
Returns:
str: rendered text when *response_mode* is ``"text"`` (or no assets).
dict: ``{"text": ..., "media": [...]}`` when *response_mode* is
``"media"`` and assets are present.
"""
if not assets: if not assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query}) return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
@@ -68,7 +123,7 @@ def _format_assets(
}) })
if response_mode == "media": if response_mode == "media":
media_items = [] media_items: list[dict[str, Any]] = []
for asset in assets: for asset in assets:
asset_id = asset.get("id", "") asset_id = asset.get("id", "")
media_items.append({ media_items.append({
@@ -13,23 +13,22 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine from ...database.engine import get_engine
from ...database.models import ( from ...database.models import (
EventLog, NotificationTarget, NotificationTrackerTarget, EventLog, NotificationTracker, NotificationTrackerTarget,
ServiceProvider, TelegramBot, TrackingConfig, ServiceProvider, TrackingConfig,
) )
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template from ..command_utils import get_trackers_for_provider
from .common import _format_assets, build_asset_dict from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def _cmd_events( async def _cmd_events(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
count: int, locale: str, count: int, locale: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
provider_ids = set(providers_map.keys()) trackers = await get_trackers_for_provider(provider.id)
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracker_ids = [t.id for t in trackers] tracker_ids = [t.id for t in trackers]
if not tracker_ids: if not tracker_ids:
return {"events": []} return {"events": []}
@@ -57,32 +56,21 @@ async def cmd_latest(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
external_domain: str = "", external_domain: str = "",
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /latest command with concurrent album fetching.""" """Handle /latest command with concurrent album fetching."""
album_ids = all_album_ids[:10] album_ids = all_album_ids[:10]
if not album_ids: if not album_ids:
return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates) return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/") ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
latest_assets: list[dict[str, Any]] = [] latest_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results): for album_data in fetched:
if isinstance(result, Exception): pub_url = album_data.get("public_url", "")
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result) album_obj = album_data.get("_album")
continue if album_obj:
if result: for aid, asset in list(album_obj.assets.items())[:count]:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
for aid, asset in list(result.assets.items())[:count]:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else "" asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
latest_assets.append(build_asset_dict(asset, public_url=asset_pub)) latest_assets.append(build_asset_dict(asset, public_url=asset_pub))
@@ -95,32 +83,21 @@ async def cmd_random(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
external_domain: str = "", external_domain: str = "",
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /random command with concurrent album fetching.""" """Handle /random command with concurrent album fetching."""
album_ids = all_album_ids[:10] album_ids = all_album_ids[:10]
if not album_ids: if not album_ids:
return _format_assets([], "random", "", locale, response_mode, client, cmd_templates) return _format_assets([], "random", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/") ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
random_assets: list[dict[str, Any]] = [] random_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results): for album_data in fetched:
if isinstance(result, Exception): pub_url = album_data.get("public_url", "")
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result) album_obj = album_data.get("_album")
continue if album_obj:
if result: asset_list = list(album_obj.assets.values())
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
asset_list = list(result.assets.values())
sampled = rng.sample(asset_list, min(count, len(asset_list))) sampled = rng.sample(asset_list, min(count, len(asset_list)))
for asset in sampled: for asset in sampled:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else "" asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
@@ -130,40 +107,40 @@ async def cmd_random(
return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates) return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates)
async def _check_native_memory(bot: TelegramBot) -> bool: async def _check_native_memory(provider_id: int) -> bool:
"""Check if any tracker-target linked to this bot uses native memory source.""" """Check if any notification tracker for this provider uses native memory source."""
engine = get_engine() engine = get_engine()
async with AsyncSession(engine) as session: async with AsyncSession(engine) as session:
result = await session.exec( tracker_result = await session.exec(
select(NotificationTarget).where( select(NotificationTracker).where(
NotificationTarget.type == "telegram", NotificationTracker.provider_id == provider_id,
NotificationTarget.user_id == bot.user_id,
) )
) )
targets = result.all() trackers = tracker_result.all()
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token} tracker_ids = [t.id for t in trackers]
if not bot_target_ids: if not tracker_ids:
return False return False
tt_result = await session.exec( tt_result = await session.exec(
select(NotificationTrackerTarget).where( select(NotificationTrackerTarget).where(
NotificationTrackerTarget.target_id.in_(bot_target_ids) NotificationTrackerTarget.tracker_id.in_(tracker_ids)
) )
) )
for tt in tt_result.all(): tc_ids = list({tt.tracking_config_id for tt in tt_result.all() if tt.tracking_config_id})
if tt.tracking_config_id: if not tc_ids:
tc = await session.get(TrackingConfig, tt.tracking_config_id) return False
if tc and tc.memory_source == "native": tc_result = await session.exec(
return True select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
return False )
return any(tc.memory_source == "native" for tc in tc_result.all())
async def cmd_memory( async def cmd_memory(
bot: TelegramBot, client: Any, all_album_ids: list[str], count: int, provider_id: int, client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /memory command with concurrent album fetching.""" """Handle /memory command with concurrent album fetching."""
use_native = await _check_native_memory(bot) use_native = await _check_native_memory(provider_id)
today = datetime.now(timezone.utc) today = datetime.now(timezone.utc)
memory_assets: list[dict[str, Any]] = [] memory_assets: list[dict[str, Any]] = []
@@ -2,26 +2,21 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import ( from ...database.models import (
CommandConfig, CommandTracker, EventLog, CommandConfig, CommandTracker,
ServiceProvider, TelegramBot, ServiceProvider, TelegramBot,
) )
from ...services import make_immich_provider from ...services import make_immich_provider
from ..base import ProviderCommandHandler from ...services.http_session import get_http_session
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template from ..base import CommandResponse, ProviderCommandHandler
from notify_bridge_core.providers.immich.asset_utils import get_public_url from ..command_utils import get_last_event_str, get_trackers_for_provider
from ..handler import _render_cmd_template
from .albums import _cmd_albums, cmd_favorites, cmd_summary from .albums import _cmd_albums, cmd_favorites, cmd_summary
from .common import _IMMICH_COMMANDS from .common import _IMMICH_COMMANDS, fetch_albums_with_links
from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random
from .search import cmd_find, cmd_person, cmd_place, cmd_search from .search import cmd_find, cmd_person, cmd_place, cmd_search
@@ -29,21 +24,15 @@ _LOGGER = logging.getLogger(__name__)
async def _cmd_status( async def _cmd_status(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str, provider: ServiceProvider, locale: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
provider_ids = set(providers_map.keys()) trackers = await get_trackers_for_provider(provider.id)
trackers = await _get_notification_trackers_for_providers(provider_ids)
active = sum(1 for t in trackers if t.enabled) active = sum(1 for t in trackers if t.enabled)
total = len(trackers) total = len(trackers)
total_albums = sum(len(t.collection_ids or []) for t in trackers) total_albums = sum(len(t.collection_ids or []) for t in trackers)
engine = get_engine() tracker_ids = [t.id for t in trackers]
async with AsyncSession(engine) as session: last_str = await get_last_event_str(tracker_ids)
result = await session.exec(
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
return { return {
"trackers_active": active, "trackers_total": total, "trackers_active": active, "trackers_total": total,
@@ -52,16 +41,13 @@ async def _cmd_status(
async def _cmd_people( async def _cmd_people(
providers_map: dict[int, ServiceProvider], locale: str, provider: ServiceProvider, locale: str,
) -> dict[str, Any]: ) -> dict[str, Any]:
all_people: dict[str, str] = {} all_people: dict[str, str] = {}
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider in providers_map.values(): immich = make_immich_provider(http, provider)
if provider.type != "immich": people = await immich.client.get_people()
continue all_people.update(people)
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
names = sorted(all_people.values()) names = sorted(all_people.values())
return {"people": names} return {"people": names}
@@ -87,106 +73,92 @@ class ImmichCommandHandler(ProviderCommandHandler):
count: int, count: int,
locale: str, locale: str,
response_mode: str, response_mode: str,
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], tracker: CommandTracker,
) -> str | list[dict[str, Any]] | None: config: CommandConfig,
) -> CommandResponse | None:
if cmd == "status": if cmd == "status":
ctx = await _cmd_status(bot, providers_map, locale) ctx = await _cmd_status(provider, locale)
return _render_cmd_template(cmd_templates, "status", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, "status", locale, ctx))
if cmd == "albums": if cmd == "albums":
ctx = await _cmd_albums(bot, providers_map, locale) ctx = await _cmd_albums(provider, locale)
return _render_cmd_template(cmd_templates, "albums", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, "albums", locale, ctx))
if cmd == "events": if cmd == "events":
ctx = await _cmd_events(bot, providers_map, count, locale) ctx = await _cmd_events(provider, count, locale)
return _render_cmd_template(cmd_templates, "events", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, "events", locale, ctx))
if cmd == "people": if cmd == "people":
ctx = await _cmd_people(providers_map, locale) ctx = await _cmd_people(provider, locale)
return _render_cmd_template(cmd_templates, "people", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, "people", locale, ctx))
if cmd in ("search", "find", "person", "place", "latest", if cmd in ("search", "find", "person", "place", "latest",
"random", "favorites", "summary", "memory"): "random", "favorites", "summary", "memory"):
return await _cmd_immich( return await _cmd_immich(
bot, cmd, args, count, locale, response_mode, cmd, args, count, locale, response_mode,
providers_map, cmd_templates, provider, cmd_templates,
) )
return None return None
async def _cmd_immich( async def _cmd_immich(
bot: TelegramBot, cmd: str, args: str, count: int, locale: str, cmd: str, args: str, count: int, locale: str,
response_mode: str, providers_map: dict[int, ServiceProvider], response_mode: str, provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]: ) -> CommandResponse | None:
"""Handle commands that need Immich API access and may return media.""" """Handle commands that need Immich API access and may return media."""
if not providers_map: notification_trackers = await get_trackers_for_provider(provider.id)
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
provider_ids = set(providers_map.keys())
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
all_album_ids: list[str] = [] all_album_ids: list[str] = []
for t in notification_trackers: for t in notification_trackers:
all_album_ids.extend(t.collection_ids or []) all_album_ids.extend(t.collection_ids or [])
provider: ServiceProvider | None = None
for p in providers_map.values():
if p.type == "immich":
provider = p
break
if not provider:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/") ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
async with aiohttp.ClientSession() as http: http = await get_http_session()
immich = make_immich_provider(http, provider) immich = make_immich_provider(http, provider)
client = immich.client client = immich.client
# Build asset_id → public_url map from tracked albums' shared links # Build asset_id → public_url map from tracked albums' shared links
asset_public_urls: dict[str, str] = {} asset_public_urls: dict[str, str] = {}
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"): if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
link_results = await asyncio.gather( fetched = await fetch_albums_with_links(client, all_album_ids, ext_domain, include_failed=False)
*[client.get_shared_links(aid) for aid in all_album_ids], for album_data in fetched:
return_exceptions=True, pub_url = album_data.get("public_url", "")
) album_obj = album_data.get("_album")
album_results = await asyncio.gather( if pub_url and album_obj:
*[client.get_album(aid) for aid in all_album_ids], for asset_id in album_obj.assets:
return_exceptions=True, asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
)
for album_id, links, album in zip(all_album_ids, link_results, album_results):
if isinstance(links, Exception) or isinstance(album, Exception):
continue
pub_url = get_public_url(ext_domain, links)
if pub_url and album:
for asset_id in album.assets:
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
if cmd == "search": # Wrap single-provider in a map for functions that still expect it
return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) providers_map = {provider.id: provider}
if cmd == "find": result: str | dict[str, Any] | None = None
return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
if cmd == "person": if cmd == "search":
return await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) result = await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "find":
result = await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "person":
result = await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "place":
result = await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "favorites":
result = await cmd_favorites(providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
elif cmd == "latest":
result = await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "random":
result = await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "summary":
result = await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
elif cmd == "memory":
result = await cmd_memory(provider.id, client, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "place": if result is None:
return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) return None
# _format_assets returns {"text": ..., "media": [...]} for media mode
if cmd == "favorites": if isinstance(result, dict):
return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates) return CommandResponse(
text=result.get("text"),
if cmd == "latest": media=result.get("media", []),
return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain) )
return CommandResponse(text=result)
if cmd == "random":
return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
if cmd == "summary":
return await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
if cmd == "memory":
return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates)
return None
@@ -9,14 +9,15 @@ from .common import _format_assets
def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]: def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]:
"""Add public_url to assets from the pre-built map.""" """Add public_url to assets from the pre-built map. Returns new list without mutating inputs."""
if not asset_public_urls: if not asset_public_urls:
return assets return assets
for asset in assets: return [
aid = asset.get("id", "") {**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")}
if aid and aid in asset_public_urls and not asset.get("public_url"): if asset.get("id", "") in asset_public_urls and not asset.get("public_url")
asset["public_url"] = asset_public_urls[aid] else asset
return assets for asset in assets
]
async def cmd_search( async def cmd_search(
@@ -24,7 +25,7 @@ async def cmd_search(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None, asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /search command.""" """Handle /search command."""
if not args: if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""}) return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""})
@@ -38,7 +39,7 @@ async def cmd_find(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None, asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /find command.""" """Handle /find command."""
if not args: if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""}) return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""})
@@ -52,7 +53,7 @@ async def cmd_person(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None, asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /person command.""" """Handle /person command."""
if not args: if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""}) return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
@@ -74,7 +75,7 @@ async def cmd_place(
locale: str, response_mode: str, locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None, asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]: ) -> str | dict[str, Any]:
"""Handle /place command.""" """Handle /place command."""
if not args: if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""}) return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
@@ -3,17 +3,31 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
from ..services import make_nut_provider from ..services import make_nut_provider
from .base import ProviderCommandHandler from .base import CommandResponse, ProviderCommandHandler
from .handler import _render_cmd_template from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_NUT_COMMANDS = {"status", "devices", "battery"} _NUT_COMMANDS = {"status", "devices", "battery"}
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class NutCommandHandler(ProviderCommandHandler): class NutCommandHandler(ProviderCommandHandler):
"""Handles NUT-specific bot commands.""" """Handles NUT-specific bot commands."""
@@ -33,80 +47,73 @@ class NutCommandHandler(ProviderCommandHandler):
count: int, count: int,
locale: str, locale: str,
response_mode: str, response_mode: str,
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], tracker: CommandTracker,
) -> str | list[dict[str, Any]] | None: config: CommandConfig,
if cmd == "status": ) -> CommandResponse | None:
ctx = await _cmd_status(providers_map) fn = _TEXT_COMMANDS.get(cmd)
return _render_cmd_template(cmd_templates, "status", locale, ctx) if fn is None:
if cmd == "devices": return None
ctx = await _cmd_devices(providers_map) ctx = await fn(provider, count)
return _render_cmd_template(cmd_templates, "devices", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
if cmd == "battery":
ctx = await _cmd_battery(providers_map)
return _render_cmd_template(cmd_templates, "battery", locale, ctx)
return None
async def _query_all_ups( async def _query_ups(
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Connect to all NUT providers and query UPS data.""" """Connect to a NUT provider and query UPS data."""
from notify_bridge_core.providers.nut.models import NutUpsData from notify_bridge_core.providers.nut.models import NutUpsData
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
for provider in providers_map.values(): nut = make_nut_provider(provider)
if provider.type != "nut": try:
continue client = nut._make_client()
nut = make_nut_provider(provider) await client.connect()
try: try:
client = nut._make_client() devices = await client.list_ups()
await client.connect() for dev in devices:
try: variables = await client.list_var(dev.name)
devices = await client.list_ups() data = NutUpsData.from_variables(dev.name, variables)
for dev in devices: results.append({
variables = await client.list_var(dev.name) "name": data.name,
data = NutUpsData.from_variables(dev.name, variables) "description": data.description,
results.append({ "model": data.model,
"name": data.name, "manufacturer": data.manufacturer,
"description": data.description, "status": data.status,
"model": data.model, "battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
"manufacturer": data.manufacturer, "battery_runtime": data.battery_runtime_formatted,
"status": data.status, "ups_load": int(data.ups_load) if data.ups_load is not None else None,
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None, "input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
"battery_runtime": data.battery_runtime_formatted, "output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
"ups_load": int(data.ups_load) if data.ups_load is not None else None, })
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None, finally:
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None, await client.disconnect()
}) except Exception as exc:
finally: _LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
await client.disconnect()
except Exception as exc:
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
return results return results
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
devices = await _query_all_ups(providers_map) async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices} return {"devices": devices}
async def _cmd_devices(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
async def _cmd_devices(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices: list[dict[str, Any]] = [] devices: list[dict[str, Any]] = []
for provider in providers_map.values(): nut = make_nut_provider(provider)
if provider.type != "nut": try:
continue device_list = await nut.list_collections()
nut = make_nut_provider(provider) devices.extend(device_list)
try: except Exception as exc:
device_list = await nut.list_collections() _LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
devices.extend(device_list)
except Exception as exc:
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
return {"devices": devices} return {"devices": devices}
async def _cmd_battery(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
devices = await _query_all_ups(providers_map) async def _cmd_battery(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices} return {"devices": devices}
@@ -3,26 +3,47 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import ( from ..database.models import (
CommandConfig, CommandTracker, EventLog, CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
NotificationTracker, ServiceProvider, TelegramBot,
) )
from ..services import make_planka_provider from ..services import make_planka_provider
from .base import ProviderCommandHandler from ..services.http_session import get_http_session
from .handler import _render_cmd_template, _get_notification_trackers_for_providers from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_PLANKA_COMMANDS = {"status", "boards", "cards", "lists"} _PLANKA_COMMANDS = {"status", "boards", "cards", "lists"}
def _get_tracked_board_ids(
provider: ServiceProvider,
trackers: list,
) -> list[str]:
"""Get board IDs from tracked collection_ids for this provider."""
if not provider.config.get("api_key"):
return []
return get_tracked_collection_ids(provider, trackers)
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class PlankaCommandHandler(ProviderCommandHandler): class PlankaCommandHandler(ProviderCommandHandler):
"""Handles Planka-specific bot commands.""" """Handles Planka-specific bot commands."""
@@ -43,69 +64,26 @@ class PlankaCommandHandler(ProviderCommandHandler):
count: int, count: int,
locale: str, locale: str,
response_mode: str, response_mode: str,
providers_map: dict[int, ServiceProvider], provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]], cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], tracker: CommandTracker,
) -> str | list[dict[str, Any]] | None: config: CommandConfig,
if cmd == "status": ) -> CommandResponse | None:
ctx = await _cmd_status(providers_map) fn = _TEXT_COMMANDS.get(cmd)
return _render_cmd_template(cmd_templates, "status", locale, ctx) if fn is None:
if cmd == "boards": return None
ctx = await _cmd_boards(providers_map) ctx = await fn(provider, count)
return _render_cmd_template(cmd_templates, "boards", locale, ctx) return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
if cmd == "cards":
ctx = await _cmd_cards(providers_map, count)
return _render_cmd_template(cmd_templates, "cards", locale, ctx)
if cmd == "lists":
ctx = await _cmd_lists(providers_map)
return _render_cmd_template(cmd_templates, "lists", locale, ctx)
return None
def _get_tracked_board_ids( @_text_cmd
providers_map: dict[int, ServiceProvider], async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers: list[NotificationTracker], trackers = await get_trackers_for_provider(provider.id)
) -> list[tuple[ServiceProvider, str]]: tracked_boards = _get_tracked_board_ids(provider, trackers)
"""Get (provider, board_id) tuples from tracked collection_ids."""
boards: list[tuple[ServiceProvider, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "planka":
continue
if not provider.config.get("api_key"):
continue
for board_id in (tracker.collection_ids or []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
# Also check filters.collections
for board_id in (tracker.filters or {}).get("collections", []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
return boards[:20]
tracker_ids = [t.id for t in trackers]
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: last_str = await get_last_event_str(tracker_ids)
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
# Last event
engine = get_engine()
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
return { return {
"boards_count": len(tracked_boards), "boards_count": len(tracked_boards),
@@ -113,81 +91,69 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
} }
async def _cmd_boards(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
provider_ids = set(providers_map.keys()) async def _cmd_boards(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await _get_notification_trackers_for_providers(provider_ids) trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(providers_map, trackers) tracked_boards = _get_tracked_board_ids(provider, trackers)
boards_data: list[dict[str, Any]] = [] boards_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, board_id in tracked_boards: planka = make_planka_provider(http, provider)
planka = make_planka_provider(http, provider) all_boards = await planka.client.get_boards()
all_boards = await planka.client.get_boards() board_names = {str(b.get("id", "")): b.get("name", "") for b in all_boards}
for b in all_boards: for board_id in tracked_boards:
if str(b.get("id", "")) == board_id: boards_data.append({"name": board_names.get(board_id, board_id)})
boards_data.append({"name": b.get("name", board_id)})
break
else:
boards_data.append({"name": board_id})
return {"boards": boards_data} return {"boards": boards_data}
async def _cmd_cards( @_text_cmd
providers_map: dict[int, ServiceProvider], count: int, async def _cmd_cards(provider: ServiceProvider, count: int) -> dict[str, Any]:
) -> dict[str, Any]: trackers = await get_trackers_for_provider(provider.id)
provider_ids = set(providers_map.keys()) tracked_boards = _get_tracked_board_ids(provider, trackers)
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
all_cards: list[dict[str, Any]] = [] all_cards: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, board_id in tracked_boards: planka = make_planka_provider(http, provider)
planka = make_planka_provider(http, provider) boards = await planka.client.get_boards()
cards = await planka.client.get_board_cards(board_id, limit=count) board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
lists = await planka.client.get_board_lists(board_id)
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
boards = await planka.client.get_boards() for board_id in tracked_boards:
board_name = board_id cards = await planka.client.get_board_cards(board_id, limit=count)
for b in boards: lists = await planka.client.get_board_lists(board_id)
if str(b.get("id", "")) == board_id: lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
board_name = b.get("name", board_id) board_name = board_names.get(board_id, board_id)
break
for card in cards: for card in cards:
list_id = str(card.get("listId", "")) list_id = str(card.get("listId", ""))
all_cards.append({ all_cards.append({
"name": card.get("name", ""), "name": card.get("name", ""),
"list_name": lists_by_id.get(list_id, ""), "list_name": lists_by_id.get(list_id, ""),
"board_name": board_name, "board_name": board_name,
}) })
return {"cards": all_cards[:count]} return {"cards": all_cards[:count]}
async def _cmd_lists(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: @_text_cmd
provider_ids = set(providers_map.keys()) async def _cmd_lists(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await _get_notification_trackers_for_providers(provider_ids) trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(providers_map, trackers) tracked_boards = _get_tracked_board_ids(provider, trackers)
all_lists: list[dict[str, Any]] = [] all_lists: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http: http = await get_http_session()
for provider, board_id in tracked_boards: planka = make_planka_provider(http, provider)
planka = make_planka_provider(http, provider) boards = await planka.client.get_boards()
lists = await planka.client.get_board_lists(board_id) board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
boards = await planka.client.get_boards() for board_id in tracked_boards:
board_name = board_id lists = await planka.client.get_board_lists(board_id)
for b in boards: board_name = board_names.get(board_id, board_id)
if str(b.get("id", "")) == board_id:
board_name = b.get("name", board_id)
break
for lst in lists: for lst in lists:
all_lists.append({ all_lists.append({
"name": lst.get("name", ""), "name": lst.get("name", ""),
"board_name": board_name, "board_name": board_name,
}) })
return {"lists": all_lists} return {"lists": all_lists}
@@ -6,7 +6,6 @@ import hmac
import logging import logging
from typing import Any from typing import Any
import aiohttp
from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi import APIRouter, Depends, Header, HTTPException, Request
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -16,6 +15,7 @@ from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..database.engine import get_session from ..database.engine import get_session
from ..database.models import TelegramBot, TelegramChat from ..database.models import TelegramBot, TelegramChat
from ..services.telegram import save_chat_from_webhook from ..services.telegram import save_chat_from_webhook
from .base import CommandResponse
from .handler import handle_command, send_media_group, send_reply from .handler import handle_command, send_media_group, send_reply
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -89,15 +89,13 @@ async def telegram_webhook(
return {"ok": True, "skipped": "commands_disabled"} return {"ok": True, "skipped": "commands_disabled"}
effective_lang = chat_row.language_override or msg_language effective_lang = chat_row.language_override or msg_language
message_id = message.get("message_id") message_id = message.get("message_id")
cmd_response = await handle_command(bot, chat_id, text, language_code=effective_lang) responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
if cmd_response is not None: if responses:
if isinstance(cmd_response, dict) and "media" in cmd_response: for resp in responses:
await send_reply(bot.token, chat_id, cmd_response["text"], reply_to_message_id=message_id) if resp.text:
await send_media_group(bot.token, chat_id, cmd_response["media"], reply_to_message_id=message_id) await send_reply(bot.token, chat_id, resp.text, reply_to_message_id=message_id)
elif isinstance(cmd_response, list): if resp.media:
await send_media_group(bot.token, chat_id, cmd_response, reply_to_message_id=message_id) await send_media_group(bot.token, chat_id, resp.media, reply_to_message_id=message_id)
else:
await send_reply(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
return {"ok": True} return {"ok": True}
return {"ok": True, "skipped": "not_a_command"} return {"ok": True, "skipped": "not_a_command"}
@@ -105,13 +103,15 @@ async def telegram_webhook(
async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict: async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
"""Register webhook URL with Telegram Bot API via TelegramClient.""" """Register webhook URL with Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, bot_token) http = await get_http_session()
return await client.set_webhook(webhook_url, secret=secret) client = TelegramClient(http, bot_token)
return await client.set_webhook(webhook_url, secret=secret)
async def unregister_webhook(bot_token: str) -> dict: async def unregister_webhook(bot_token: str) -> dict:
"""Remove webhook from Telegram Bot API via TelegramClient.""" """Remove webhook from Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http: from ..services.http_session import get_http_session
client = TelegramClient(http, bot_token) http = await get_http_session()
return await client.delete_webhook() client = TelegramClient(http, bot_token)
return await client.delete_webhook()
@@ -359,6 +359,7 @@ class NotificationTrackerState(SQLModel, table=True):
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id # Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
tracker_id: int = Field( tracker_id: int = Field(
foreign_key="notification_tracker.id", foreign_key="notification_tracker.id",
index=True,
sa_column_kwargs={"name": "notification_tracker_id"}, sa_column_kwargs={"name": "notification_tracker_id"},
) )
collection_id: str collection_id: str
@@ -458,7 +459,7 @@ class CommandTrackerListener(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
command_tracker_id: int = Field( command_tracker_id: int = Field(
foreign_key="command_tracker.id", foreign_key="command_tracker.id",
index=True,
) )
listener_type: str # e.g. "telegram_bot" listener_type: str # e.g. "telegram_bot"
@@ -73,6 +73,8 @@ async def lifespan(app: FastAPI):
await start_scheduler() await start_scheduler()
yield yield
# Graceful shutdown # Graceful shutdown
from .services.http_session import close_http_session
await close_http_session()
scheduler = get_scheduler() scheduler = get_scheduler()
if scheduler.running: if scheduler.running:
scheduler.shutdown() scheduler.shutdown()
@@ -1,5 +1,11 @@
"""Shared service utilities.""" """Shared service utilities."""
from __future__ import annotations
from typing import Any, Protocol
import aiohttp
from notify_bridge_core.providers.immich import ImmichServiceProvider from notify_bridge_core.providers.immich import ImmichServiceProvider
from notify_bridge_core.providers.gitea import GiteaServiceProvider from notify_bridge_core.providers.gitea import GiteaServiceProvider
from notify_bridge_core.providers.planka import PlankaServiceProvider from notify_bridge_core.providers.planka import PlankaServiceProvider
@@ -8,8 +14,23 @@ from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvid
from ..database.models import ServiceProvider from ..database.models import ServiceProvider
# Default timeout for all outgoing HTTP requests to external services.
DEFAULT_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServiceProvider:
class CollectionProvider(Protocol):
"""Protocol for providers that can list collections."""
async def list_collections(self) -> list[dict[str, Any]]: ...
class TestableProvider(Protocol):
"""Protocol for providers that support connection testing."""
async def test_connection(self) -> dict[str, Any]: ...
def make_immich_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> ImmichServiceProvider:
"""Create an ImmichServiceProvider from a DB provider model.""" """Create an ImmichServiceProvider from a DB provider model."""
config = provider.config or {} config = provider.config or {}
return ImmichServiceProvider( return ImmichServiceProvider(
@@ -21,7 +42,7 @@ def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServi
) )
def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaServiceProvider: def make_gitea_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GiteaServiceProvider:
"""Create a GiteaServiceProvider from a DB provider model.""" """Create a GiteaServiceProvider from a DB provider model."""
config = provider.config or {} config = provider.config or {}
return GiteaServiceProvider( return GiteaServiceProvider(
@@ -32,7 +53,7 @@ def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaService
) )
def make_planka_provider(http_session, provider: ServiceProvider) -> PlankaServiceProvider: def make_planka_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> PlankaServiceProvider:
"""Create a PlankaServiceProvider from a DB provider model.""" """Create a PlankaServiceProvider from a DB provider model."""
config = provider.config or {} config = provider.config or {}
return PlankaServiceProvider( return PlankaServiceProvider(
@@ -55,7 +76,7 @@ def make_nut_provider(provider: ServiceProvider) -> NutServiceProvider:
) )
def make_google_photos_provider(http_session, provider: ServiceProvider) -> GooglePhotosServiceProvider: def make_google_photos_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GooglePhotosServiceProvider:
"""Create a GooglePhotosServiceProvider from a DB provider model.""" """Create a GooglePhotosServiceProvider from a DB provider model."""
config = provider.config or {} config = provider.config or {}
return GooglePhotosServiceProvider( return GooglePhotosServiceProvider(
@@ -65,3 +86,61 @@ def make_google_photos_provider(http_session, provider: ServiceProvider) -> Goog
config.get("refresh_token", ""), config.get("refresh_token", ""),
provider.name, provider.name,
) )
# ---------------------------------------------------------------------------
# Provider factory registry — maps provider type strings to factory callables
# that create a provider with a ``list_collections`` method. Providers that
# require an API credential skip creation when the credential is missing
# (the factory returns None in that case).
# ---------------------------------------------------------------------------
def _make_collection_provider(
http_session: aiohttp.ClientSession,
provider: ServiceProvider,
) -> CollectionProvider | None:
"""Create a CollectionProvider for the given DB provider, or None if unsupported."""
ptype = provider.type
config = provider.config or {}
if ptype == "immich":
return make_immich_provider(http_session, provider)
if ptype == "gitea":
if not config.get("api_token"):
return None
return make_gitea_provider(http_session, provider)
if ptype == "planka":
if not config.get("api_key"):
return None
return make_planka_provider(http_session, provider)
if ptype == "google_photos":
return make_google_photos_provider(http_session, provider)
# NUT provider needs no http_session
if ptype == "nut":
return make_nut_provider(provider) # type: ignore[return-value]
return None
# Set of provider types that need an aiohttp session for collection listing.
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
"""List collections for any supported provider type.
Returns an empty list for providers that don't support collections or
are missing required credentials.
"""
if provider.type in _HTTP_COLLECTION_PROVIDERS:
from .http_session import get_http_session
http_session = await get_http_session()
svc = _make_collection_provider(http_session, provider)
if svc is None:
return []
return await svc.list_collections()
# Non-HTTP providers (e.g. NUT)
svc = _make_collection_provider(None, provider) # type: ignore[arg-type]
if svc is None:
return []
return await svc.list_collections()
@@ -6,7 +6,6 @@ import logging
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -159,27 +158,28 @@ async def _execute_with_provider(
) )
from notify_bridge_core.providers.immich.client import ImmichClient from notify_bridge_core.providers.immich.client import ImmichClient
async with aiohttp.ClientSession() as http_session: from .http_session import get_http_session
client = ImmichClient( http_session = await get_http_session()
http_session, client = ImmichClient(
provider_config.get("url", ""), http_session,
provider_config.get("api_key", ""), provider_config.get("url", ""),
provider_config.get("api_key", ""),
)
external_domain = provider_config.get("external_domain")
if external_domain:
client.external_domain = external_domain
# Verify connectivity
if not await client.ping():
return ActionResult(
success=False,
error=f"Cannot connect to Immich server ({provider_name})",
) )
external_domain = provider_config.get("external_domain")
if external_domain:
client.external_domain = external_domain
# Verify connectivity executor = ImmichActionExecutor(client)
if not await client.ping(): if dry_run:
return ActionResult( return await executor.dry_run(action_type, rule_configs, action_config)
success=False, return await executor.execute(action_type, rule_configs, action_config)
error=f"Cannot connect to Immich server ({provider_name})",
)
executor = ImmichActionExecutor(client)
if dry_run:
return await executor.dry_run(action_type, rule_configs, action_config)
return await executor.execute(action_type, rule_configs, action_config)
return ActionResult( return ActionResult(
success=False, success=False,
@@ -0,0 +1,33 @@
"""Application-level shared aiohttp.ClientSession.
All outgoing HTTP requests in the server package should use the shared
session returned by ``get_http_session()`` instead of creating
per-request ``aiohttp.ClientSession`` instances. This keeps a single
TCP connection pool alive for the lifetime of the process, avoiding
the overhead of pool creation/teardown on every request.
Call ``close_http_session()`` once during application shutdown.
"""
from __future__ import annotations
import aiohttp
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
_session: aiohttp.ClientSession | None = None
async def get_http_session() -> aiohttp.ClientSession:
"""Get or create the shared HTTP session."""
global _session
if _session is None or _session.closed:
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
return _session
async def close_http_session() -> None:
"""Close the shared HTTP session (call on app shutdown)."""
global _session
if _session is not None and not _session.closed:
await _session.close()
_session = None
@@ -3,8 +3,6 @@
import logging import logging
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -90,19 +88,21 @@ async def _send_telegram_broadcast(target: NotificationTarget, message: str, rec
if not receivers: if not receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = [] results: list[dict] = []
async with aiohttp.ClientSession() as session: client = TelegramClient(http, bot_token)
client = TelegramClient(session, bot_token) for recv in receivers:
for recv in receivers: chat_id = recv.get("chat_id")
chat_id = recv.get("chat_id") if not chat_id:
if not chat_id: continue
continue result = await client.send_message(
result = await client.send_message( chat_id=str(chat_id),
chat_id=str(chat_id), text=message,
text=message, disable_web_page_preview=bool(disable_preview),
disable_web_page_preview=bool(disable_preview), )
) results.append(result)
results.append(result)
return _aggregate(results) return _aggregate(results)
@@ -113,15 +113,17 @@ async def _send_webhook_broadcast(target: NotificationTarget, message: str, rece
if not receivers: if not receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = [] results: list[dict] = []
async with aiohttp.ClientSession() as session: for recv in receivers:
for recv in receivers: url = recv.get("url")
url = recv.get("url") headers = recv.get("headers", {})
headers = recv.get("headers", {}) if not url:
if not url: continue
continue client = WebhookClient(http, url, headers)
client = WebhookClient(session, url, headers) results.append(await client.send({"message": message, "event_type": "notification"}))
results.append(await client.send({"message": message, "event_type": "notification"}))
return _aggregate(results) return _aggregate(results)
@@ -178,22 +180,24 @@ async def _send_webhook_like_broadcast(target: NotificationTarget, message: str,
if not receivers: if not receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = [] results: list[dict] = []
async with aiohttp.ClientSession() as session: if target.type == "discord":
if target.type == "discord": from notify_bridge_core.notifications.discord.client import DiscordClient
from notify_bridge_core.notifications.discord.client import DiscordClient client = DiscordClient(http)
client = DiscordClient(session) for recv in receivers:
for recv in receivers: url = recv.get("webhook_url")
url = recv.get("webhook_url") if url:
if url: results.append(await client.send(url, message, username=target.config.get("username")))
results.append(await client.send(url, message, username=target.config.get("username"))) elif target.type == "slack":
elif target.type == "slack": from notify_bridge_core.notifications.slack.client import SlackClient
from notify_bridge_core.notifications.slack.client import SlackClient client = SlackClient(http)
client = SlackClient(session) for recv in receivers:
for recv in receivers: url = recv.get("webhook_url")
url = recv.get("webhook_url") if url:
if url: results.append(await client.send(url, message, username=target.config.get("username")))
results.append(await client.send(url, message, username=target.config.get("username")))
return _aggregate(results) return _aggregate(results)
@@ -207,18 +211,20 @@ async def _send_ntfy_broadcast(target: NotificationTarget, message: str, receive
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
from notify_bridge_core.notifications.ntfy.client import NtfyClient from notify_bridge_core.notifications.ntfy.client import NtfyClient
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = [] results: list[dict] = []
async with aiohttp.ClientSession() as session: client = NtfyClient(http)
client = NtfyClient(session) for recv in receivers:
for recv in receivers: topic = recv.get("topic")
topic = recv.get("topic") if topic:
if topic: results.append(await client.send(
results.append(await client.send( server_url, topic, message,
server_url, topic, message, title="Notify Bridge",
title="Notify Bridge", priority=recv.get("priority", 3),
priority=recv.get("priority", 3), auth_token=auth_token,
auth_token=auth_token, ))
))
return _aggregate(results) return _aggregate(results)
@@ -243,13 +249,15 @@ async def _send_matrix_broadcast(target: NotificationTarget, message: str, recei
if not receivers: if not receivers:
return {"success": False, "error": "No receivers configured"} return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = [] results: list[dict] = []
async with aiohttp.ClientSession() as http: client = MatrixClient(http, homeserver, access_token)
client = MatrixClient(http, homeserver, access_token) for recv in receivers:
for recv in receivers: room_id = recv.get("room_id")
room_id = recv.get("room_id") if room_id:
if room_id: results.append(await client.send_message(room_id, message, html_message=message))
results.append(await client.send_message(room_id, message, html_message=message))
return _aggregate(results) return _aggregate(results)
@@ -31,11 +31,50 @@ async def start_scheduler() -> None:
from .telegram_poller import start_command_listener_polling from .telegram_poller import start_command_listener_polling
await start_command_listener_polling() await start_command_listener_polling()
# Schedule daily cleanup of old event log entries
_schedule_event_cleanup()
# Start debounced command auto-sync scheduler # Start debounced command auto-sync scheduler
from .command_sync import start_sync_scheduler from .command_sync import start_sync_scheduler
start_sync_scheduler() start_sync_scheduler()
def _schedule_event_cleanup() -> None:
"""Schedule a daily job to delete EventLog entries older than 90 days."""
from apscheduler.triggers.cron import CronTrigger
scheduler = get_scheduler()
job_id = "cleanup_old_events"
if scheduler.get_job(job_id):
return
scheduler.add_job(
_cleanup_old_events,
CronTrigger(hour=3, minute=0),
id=job_id,
replace_existing=True,
max_instances=1,
)
_LOGGER.info("Scheduled daily event log cleanup at 03:00 UTC")
async def _cleanup_old_events() -> None:
"""Delete EventLog entries older than 90 days."""
from datetime import datetime, timedelta, timezone
from sqlmodel import delete
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import EventLog
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
engine = get_engine()
async with AsyncSession(engine) as session:
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
await session.commit()
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
async def _load_tracker_jobs() -> None: async def _load_tracker_jobs() -> None:
"""Load enabled trackers and schedule polling jobs.""" """Load enabled trackers and schedule polling jobs."""
from sqlmodel import select from sqlmodel import select
@@ -50,13 +89,16 @@ async def _load_tracker_jobs() -> None:
result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True)) result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True))
trackers = result.all() trackers = result.all()
# Pre-load provider types for scheduler detection # Batch-load provider types for scheduler detection
unique_provider_ids = list({t.provider_id for t in trackers})
provider_types: dict[int, str] = {} provider_types: dict[int, str] = {}
for tracker in trackers: if unique_provider_ids:
if tracker.provider_id not in provider_types: provider_result = await session.exec(
provider = await session.get(ServiceProviderModel, tracker.provider_id) select(ServiceProviderModel).where(
if provider: ServiceProviderModel.id.in_(unique_provider_ids)
provider_types[tracker.provider_id] = provider.type )
)
provider_types = {p.id: p.type for p in provider_result.all()}
for tracker in trackers: for tracker in trackers:
job_id = f"tracker_{tracker.id}" job_id = f"tracker_{tracker.id}"
@@ -86,6 +128,7 @@ async def _load_tracker_jobs() -> None:
id=job_id, id=job_id,
args=[tracker.id], args=[tracker.id],
replace_existing=True, replace_existing=True,
max_instances=1,
) )
_LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval) _LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval)
@@ -106,6 +149,7 @@ def _add_cron_job(
id=job_id, id=job_id,
args=[tracker_id], args=[tracker_id],
replace_existing=True, replace_existing=True,
max_instances=1,
) )
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression) _LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
@@ -13,7 +13,6 @@ from __future__ import annotations
import logging import logging
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -47,10 +46,18 @@ async def _get_bot_ids_with_active_listeners() -> set[int]:
listeners = result.all() listeners = result.all()
active_bot_ids: set[int] = set() active_bot_ids: set[int] = set()
for listener in listeners: tracker_ids = list({l.command_tracker_id for l in listeners})
tracker = await session.get(CommandTracker, listener.command_tracker_id) if tracker_ids:
if tracker and tracker.enabled: tracker_result = await session.exec(
active_bot_ids.add(listener.listener_id) select(CommandTracker).where(
CommandTracker.id.in_(tracker_ids),
CommandTracker.enabled == True, # noqa: E712
)
)
enabled_tracker_ids = {t.id for t in tracker_result.all()}
for listener in listeners:
if listener.command_tracker_id in enabled_tracker_ids:
active_bot_ids.add(listener.listener_id)
return active_bot_ids return active_bot_ids
@@ -145,21 +152,23 @@ async def _poll_bot(bot_id: int) -> None:
if not bot or bot.update_mode != "polling": if not bot or bot.update_mode != "polling":
unschedule_bot_polling(bot_id) unschedule_bot_polling(bot_id)
return return
# Extract what we need before closing session # Copy attributes before session closes to avoid detached-instance errors
from types import SimpleNamespace
bot_token = bot.token bot_token = bot.token
bot_obj = bot bot_obj = SimpleNamespace(id=bot.id, name=bot.name, token=bot.token)
offset = _last_update_id.get(bot_id, 0) offset = _last_update_id.get(bot_id, 0)
try: try:
async with aiohttp.ClientSession() as http: from .http_session import get_http_session
client = TelegramClient(http, bot_token) http = await get_http_session()
result = await client.get_updates( client = TelegramClient(http, bot_token)
offset=offset + 1 if offset else None, limit=50, result = await client.get_updates(
) offset=offset + 1 if offset else None, limit=50,
if not result.get("success"): )
return if not result.get("success"):
updates = result.get("result", []) return
updates = result.get("result", [])
except Exception as e: except Exception as e:
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e) _LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
return return
@@ -209,17 +218,13 @@ async def _poll_bot(bot_id: int) -> None:
continue continue
effective_lang = chat_row.language_override or msg_language effective_lang = chat_row.language_override or msg_language
message_id = message.get("message_id") message_id = message.get("message_id")
cmd_response = await handle_command(bot_obj, chat_id, text, language_code=effective_lang) responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
if cmd_response is not None: if responses:
if isinstance(cmd_response, dict) and "media" in cmd_response: for resp in responses:
# Text + media: send text first, media as reply if resp.text:
from ..commands.handler import send_reply as _reply await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
await _reply(bot_token, chat_id, cmd_response["text"], reply_to_message_id=message_id) if resp.media:
await send_media_group(bot_token, chat_id, cmd_response["media"], reply_to_message_id=message_id) await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
elif isinstance(cmd_response, list):
await send_media_group(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
else:
await send_reply(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
except Exception: except Exception:
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True) _LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
@@ -7,7 +7,6 @@ objects and dispatches through the same path the watcher uses.
import logging import logging
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -183,58 +182,59 @@ async def _build_immich_event(
memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums" memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums"
is_memory = test_type == "memory" is_memory = test_type == "memory"
async with aiohttp.ClientSession() as http_session: from .http_session import get_http_session
immich = ImmichServiceProvider( http_session = await get_http_session()
http_session, immich = ImmichServiceProvider(
provider_config.get("url", ""), http_session,
provider_config.get("api_key", ""), provider_config.get("url", ""),
provider_config.get("external_domain"), provider_config.get("api_key", ""),
provider_name, provider_config.get("external_domain"),
) provider_name,
if not await immich.connect(): )
return None if not await immich.connect():
return None
# Native Immich memories API path # Native Immich memories API path
if is_memory and memory_source == "native": if is_memory and memory_source == "native":
return await _build_native_memory_event( return await _build_native_memory_event(
immich, ext_domain, provider_name, tracker_name, immich, ext_domain, provider_name, tracker_name,
collection_ids, limit, asset_type, favorite_only, min_rating, collection_ids, limit, asset_type, favorite_only, min_rating,
)
# Album-based path: use shared collect_scheduled_assets
albums: dict[str, ImmichAlbumData] = {}
shared_links: dict[str, list[SharedLinkInfo]] = {}
for album_id in collection_ids:
album = await immich.client.get_album(album_id)
if album:
albums[album_id] = album
shared_links[album_id] = await immich.client.get_shared_links(album_id)
assets, collections_extra = collect_scheduled_assets(
albums, shared_links, ext_domain,
limit=limit,
asset_type=asset_type,
favorite_only=favorite_only,
min_rating=min_rating,
is_memory=is_memory,
) )
first_col = collections_extra[0] if collections_extra else {} # Album-based path: use shared collect_scheduled_assets
return ServiceEvent( albums: dict[str, ImmichAlbumData] = {}
event_type=EventType.SCHEDULED_MESSAGE, shared_links: dict[str, list[SharedLinkInfo]] = {}
provider_type=ServiceProviderType.IMMICH, for album_id in collection_ids:
provider_name=provider_name, album = await immich.client.get_album(album_id)
collection_id=collection_ids[0] if collection_ids else "", if album:
collection_name=first_col.get("name", tracker_name), albums[album_id] = album
timestamp=datetime.now(timezone.utc), shared_links[album_id] = await immich.client.get_shared_links(album_id)
added_assets=assets,
added_count=len(assets), assets, collections_extra = collect_scheduled_assets(
extra={ albums, shared_links, ext_domain,
"collections": collections_extra, limit=limit,
"albums": collections_extra, asset_type=asset_type,
**(first_col if first_col else {}), favorite_only=favorite_only,
}, min_rating=min_rating,
) is_memory=is_memory,
)
first_col = collections_extra[0] if collections_extra else {}
return ServiceEvent(
event_type=EventType.SCHEDULED_MESSAGE,
provider_type=ServiceProviderType.IMMICH,
provider_name=provider_name,
collection_id=collection_ids[0] if collection_ids else "",
collection_name=first_col.get("name", tracker_name),
timestamp=datetime.now(timezone.utc),
added_assets=assets,
added_count=len(assets),
extra={
"collections": collections_extra,
"albums": collections_extra,
**(first_col if first_col else {}),
},
)
async def _build_native_memory_event( async def _build_native_memory_event(
@@ -6,7 +6,6 @@ import asyncio
import logging import logging
from typing import Any from typing import Any
import aiohttp
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
@@ -102,19 +101,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
if provider_type == "immich": if provider_type == "immich":
from notify_bridge_core.providers.immich import ImmichServiceProvider from notify_bridge_core.providers.immich import ImmichServiceProvider
async with aiohttp.ClientSession() as http_session: from .http_session import get_http_session
immich = ImmichServiceProvider( http_session = await get_http_session()
http_session, immich = ImmichServiceProvider(
provider_config.get("url", ""), http_session,
provider_config.get("api_key", ""), provider_config.get("url", ""),
provider_config.get("external_domain"), provider_config.get("api_key", ""),
provider_name, provider_config.get("external_domain"),
) provider_name,
connected = await immich.connect() )
if not connected: connected = await immich.connect()
return {"status": "error", "reason": "failed to connect to provider"} if not connected:
return {"status": "error", "reason": "failed to connect to provider"}
events, new_state = await immich.poll(collection_ids, state_dict) events, new_state = await immich.poll(collection_ids, state_dict)
elif provider_type == "gitea": elif provider_type == "gitea":
# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint. # Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
# The scheduler still calls check_tracker but there's nothing to poll. # The scheduler still calls check_tracker but there's nothing to poll.
@@ -143,18 +143,22 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
events, new_state = await nut.poll(collection_ids, state_dict) events, new_state = await nut.poll(collection_ids, state_dict)
elif provider_type == "google_photos": elif provider_type == "google_photos":
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
async with aiohttp.ClientSession() as http_session: from .http_session import get_http_session
gp = GooglePhotosServiceProvider( http_session = await get_http_session()
http_session, gp = GooglePhotosServiceProvider(
provider_config.get("client_id", ""), http_session,
provider_config.get("client_secret", ""), provider_config.get("client_id", ""),
provider_config.get("refresh_token", ""), provider_config.get("client_secret", ""),
provider_name, provider_config.get("refresh_token", ""),
) provider_name,
connected = await gp.connect() )
if not connected: connected = await gp.connect()
return {"status": "error", "reason": "failed to connect to Google Photos"} if not connected:
events, new_state = await gp.poll(collection_ids, state_dict) return {"status": "error", "reason": "failed to connect to Google Photos"}
events, new_state = await gp.poll(collection_ids, state_dict)
elif provider_type == "webhook":
# Webhook providers receive events via inbound HTTP; no polling needed.
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
else: else:
return {"status": "error", "reason": f"unsupported provider type: {provider_type}"} return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}