refactor: comprehensive codebase review — security, performance, quality, UX

Security:
- Fix NUT protocol command injection (validate names against safe regex)
- Enable Jinja2 autoescape=True to prevent HTML injection via external data
- Add WebhookProviderConfig validation model

Performance:
- Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions)
- Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast)
- asyncio.gather for Gitea commands and notification dispatcher
- Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener
- LRU cache for compiled Jinja2 templates
- Daily EventLog cleanup job (90-day retention)
- 30s HTTP timeout on all external calls
- GROUP BY for target type counts (replaces 7 sequential queries)

Code quality:
- Extract get_owned_entity() helper (replaces 11 duplicate functions)
- Extract slot_helpers.py (load_slots, save_slots, render_template_preview)
- Extract command_utils.py (tracker lookup, last event, collection IDs)
- Extract http_session.py (shared session lifecycle)
- Provider connection validation dedup (3x → 1 helper)
- Command dispatch tables replacing if/elif chains
- Album+links fetch helper (fetch_albums_with_links)
- Provider dispatch polymorphism (list_provider_collections)
- Immutable _enrich_assets (no longer mutates in-place)
- Fix _format_assets return type + handler unpacking

Frontend:
- Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru)
- Mobile "More" nav panel with provider filter and search
- Shared Button.svelte component (4 variants, 2 sizes)
- Shared ErrorBanner.svelte component (8 pages updated)
- SvelteKit goto() replacing window.location.href
- Dashboard grid fixed for 4 cards, paginator opacity consistency

Functionality:
- max_instances=1 on scheduler jobs (prevents duplicate events)
- Webhook provider in watcher (prevents error spam)
- Fix stale SQLModel reference in poller
- Gitea get_repo() direct API call
This commit is contained in:
2026-03-28 13:22:26 +03:00
parent 616b221c92
commit b803d004e1
65 changed files with 1934 additions and 1498 deletions
+34 -1
View File
@@ -43,6 +43,17 @@ jobs:
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
- name: Trigger Portainer redeploy
continue-on-error: true
run: |
if [ -n "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" ]; then
echo "Triggering Portainer redeploy..."
curl -sf -X POST "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" \
--max-time 30 || echo "::warning::Portainer webhook failed"
else
echo "DOCKER_REDEPLOY_WEBHOOK_URL not set — skipping auto-deploy"
fi
- name: Generate changelog
id: changelog
run: |
@@ -56,7 +67,29 @@ jobs:
- name: Create Gitea Release
run: |
BODY=$(cat /tmp/changelog.txt | python3 -c "import sys,json; print(json.dumps(sys.stdin.read()))")
if [ -f RELEASE_NOTES.md ]; then
export RELEASE_NOTES=$(cat RELEASE_NOTES.md)
echo "Found RELEASE_NOTES.md"
else
export RELEASE_NOTES=""
echo "No RELEASE_NOTES.md found"
fi
BODY=$(python3 -c "
import json, os, sys
release_notes = os.environ.get('RELEASE_NOTES', '')
changelog = open('/tmp/changelog.txt').read().strip()
sections = []
if release_notes.strip():
sections.append(release_notes.strip())
if changelog:
sections.append('## Changelog\n\n' + changelog)
print(json.dumps('\n\n'.join(sections)))
")
curl -s -X POST \
"https://${{ env.REGISTRY }}/api/v1/repos/${{ env.IMAGE_NAME }}/releases" \
-H "Authorization: token ${{ secrets.RELEASE_TOKEN }}" \
+85
View File
@@ -0,0 +1,85 @@
<script lang="ts">
import type { Snippet } from 'svelte';
let {
variant = 'primary',
size = 'md',
disabled = false,
type = 'button',
href,
onclick,
children,
class: extraClass = '',
}: {
variant?: 'primary' | 'secondary' | 'danger' | 'ghost';
size?: 'sm' | 'md';
disabled?: boolean;
type?: 'button' | 'submit';
href?: string;
onclick?: (e: MouseEvent) => void;
children: Snippet;
class?: string;
} = $props();
const baseClasses = 'inline-flex items-center justify-center gap-1.5 rounded-md text-sm font-medium transition-colors disabled:opacity-50';
const sizeClasses: Record<string, string> = {
sm: 'px-2.5 py-1 text-xs',
md: 'px-4 py-2',
};
const variantClasses: Record<string, string> = {
primary: 'btn-primary',
secondary: 'btn-secondary',
danger: 'btn-danger',
ghost: 'btn-ghost',
};
const classes = $derived(
`${baseClasses} ${sizeClasses[size]} ${variantClasses[variant]} ${extraClass}`.trim()
);
</script>
{#if href && !disabled}
<a {href} class={classes} onclick={onclick}>
{@render children()}
</a>
{:else}
<button {type} {disabled} class={classes} onclick={onclick}>
{@render children()}
</button>
{/if}
<style>
.btn-primary {
background: var(--color-primary);
color: var(--color-primary-foreground);
}
.btn-primary:hover:not(:disabled) {
opacity: 0.9;
}
.btn-secondary {
background: var(--color-muted);
color: var(--color-foreground);
border: 1px solid var(--color-border);
}
.btn-secondary:hover:not(:disabled) {
opacity: 0.8;
}
.btn-danger {
background: var(--color-error-fg);
color: white;
}
.btn-danger:hover:not(:disabled) {
opacity: 0.9;
}
.btn-ghost {
background: transparent;
color: var(--color-muted-foreground);
}
.btn-ghost:hover:not(:disabled) {
background: var(--color-muted);
color: var(--color-foreground);
}
</style>
@@ -1,5 +1,6 @@
<script lang="ts">
import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface EntityItem {
value: string | number;
@@ -142,7 +143,7 @@
<div class="ep-list" bind:this={listEl} role="listbox">
{#if filtered.length === 0}
<div class="ep-empty">No matches</div>
<div class="ep-empty">{t('common.noMatches')}</div>
{:else}
{#each filtered as item, i}
<button
@@ -0,0 +1,13 @@
<script lang="ts">
interface Props {
message: string;
class?: string;
}
let { message, class: className = '' }: Props = $props();
</script>
{#if message}
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4 {className}">
{message}
</div>
{/if}
@@ -1,5 +1,6 @@
<script lang="ts">
import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface GridItem {
value: string | number;
@@ -117,7 +118,7 @@
</button>
{/each}
{#if filtered.length === 0}
<div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">No matches</div>
<div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">{t('common.noMatches')}</div>
{/if}
</div>
</div>
+2 -1
View File
@@ -1,6 +1,7 @@
<script lang="ts">
import { onMount } from 'svelte';
import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
let { open = false, title = '', onclose, children } = $props<{
open: boolean;
@@ -93,7 +94,7 @@
>
<div style="display: flex; align-items: center; justify-content: space-between; padding: 1.5rem 1.5rem 1rem;">
<h3 id="modal-title-{uniqueId}" style="font-size: 1.125rem; font-weight: 600;">{title}</h3>
<button class="modal-close" onclick={onclose} aria-label="Close">
<button class="modal-close" onclick={onclose} aria-label={t('common.close')}>
<MdiIcon name="mdiClose" size={18} />
</button>
</div>
@@ -1,5 +1,6 @@
<script lang="ts">
import MdiIcon from './MdiIcon.svelte';
import { t } from '$lib/i18n';
export interface MultiEntityItem {
value: string;
@@ -132,7 +133,7 @@
<div class="mes-list" bind:this={listEl} role="listbox">
{#if filtered.length === 0}
<div class="mes-empty">No matches</div>
<div class="mes-empty">{t('common.noMatches')}</div>
{:else}
{#each filtered as item, i}
{@const checked = (values || []).includes(item.value)}
+1 -1
View File
@@ -56,7 +56,7 @@
{/if}
{/if}
</div>
<button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label="Dismiss">
<button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label={t('common.dismiss')}>
<MdiIcon name="mdiClose" size={14} />
</button>
</div>
+2
View File
@@ -110,6 +110,8 @@ export const previewTargetTypeItems = (): GridItem[] => [
{ value: 'email', icon: 'mdiEmailOutline', label: 'Email', desc: t('gridDesc.previewEmail') },
{ value: 'discord', icon: 'mdiChat', label: 'Discord', desc: t('gridDesc.previewDiscord') },
{ value: 'slack', icon: 'mdiSlack', label: 'Slack', desc: t('gridDesc.previewSlack') },
{ value: 'ntfy', icon: 'mdiBellOutline', label: 'ntfy', desc: t('gridDesc.previewNtfy') },
{ value: 'matrix', icon: 'mdiMatrix', label: 'Matrix', desc: t('gridDesc.previewMatrix') },
];
// --- Provider type items (derived from descriptor registry) ---
+28 -10
View File
@@ -36,7 +36,8 @@
"targetMatrix": "Matrix",
"targetBroadcast": "Broadcast",
"automation": "Automation",
"actions": "Actions"
"actions": "Actions",
"more": "More"
},
"auth": {
"signIn": "Sign in",
@@ -51,7 +52,9 @@
"creatingAccount": "Creating account...",
"passwordMismatch": "Passwords do not match",
"passwordTooShort": "Password must be at least 8 characters",
"or": "or"
"or": "or",
"loginFailed": "Login failed",
"setupFailed": "Setup failed"
},
"dashboard": {
"title": "Dashboard",
@@ -150,7 +153,9 @@
"gpRefreshTokenHint": "Obtain from Google OAuth Playground (developers.google.com/oauthplayground) with the Photos Library API scope.",
"gpAllFieldsRequired": "Client ID, Client Secret, and Refresh Token are all required",
"testAndSave": "Test & Save",
"saveWithoutTest": "Save without testing"
"saveWithoutTest": "Save without testing",
"selectType": "Select a provider type",
"testFailed": "Connection test failed"
},
"notificationTracker": {
"title": "Notification Trackers",
@@ -231,7 +236,8 @@
"noLink": "No Link",
"saveWithoutLinks": "Save without links",
"createLinks": "Create {count} link(s)",
"linksNote": "You can also create links manually in Immich."
"linksNote": "You can also create links manually in Immich.",
"createdLinks": "Created {count} public link(s)"
},
"templates": {
"title": "Templates",
@@ -409,7 +415,9 @@
"cacheTtl": "Media cache TTL (hours)",
"cacheTtlHint": "How long to cache uploaded Telegram file_ids before re-uploading (default: 48h)",
"settingsSaved": "Settings saved",
"noExternalDomain": "External domain URL not configured"
"noExternalDomain": "External domain URL not configured",
"saveFailed": "Failed to save bot",
"webhookFailed": "Failed to register webhook"
},
"trackingConfig": {
"title": "Tracking Configs",
@@ -584,7 +592,7 @@
"added_assets": "List of asset dicts (use {% for asset in added_assets %})",
"removed_assets": "List of removed asset IDs (strings)",
"shared": "Whether album is shared (boolean)",
"target_type": "Target type: 'telegram' or 'webhook'",
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
"has_videos": "Whether added assets contain videos (boolean)",
"has_photos": "Whether added assets contain photos (boolean)",
"old_name": "Previous album name (rename events)",
@@ -675,7 +683,8 @@
"displayName": "Display Name",
"testConnection": "Test connection",
"noBots": "No Matrix bots yet.",
"confirmDelete": "Delete this Matrix bot?"
"confirmDelete": "Delete this Matrix bot?",
"operationFailed": "Operation failed"
},
"emailBot": {
"title": "Email Bots",
@@ -693,7 +702,8 @@
"useTls": "Use TLS/SSL",
"testConnection": "Send test email",
"noBots": "No email bots yet.",
"confirmDelete": "Delete this email bot?"
"confirmDelete": "Delete this email bot?",
"operationFailed": "Operation failed"
},
"cmdTemplateConfig": {
"title": "Command Templates",
@@ -841,7 +851,12 @@
"allTypes": "All types",
"allProviders": "All providers",
"noFilterResults": "No items match the current filter.",
"redirecting": "Redirecting..."
"redirecting": "Redirecting...",
"noMatches": "No matches",
"saveFailed": "Save failed",
"loadFailed": "Failed to load data",
"dismiss": "Dismiss",
"systemSuffix": " (System)"
},
"templateSlot": {
"message_assets_added": "New assets added to album",
@@ -926,12 +941,15 @@
"previewEmail": "Preview with email HTML format",
"previewDiscord": "Preview with Discord markdown",
"previewSlack": "Preview with Slack markdown",
"previewNtfy": "Preview as ntfy notification",
"previewMatrix": "Preview with Matrix HTML format",
"providerImmich": "Self-hosted photo server",
"providerGitea": "Self-hosted Git service",
"providerPlanka": "Self-hosted Kanban board",
"providerScheduler": "Time-based scheduled messages",
"providerNut": "Network UPS monitoring",
"providerGooglePhotos": "Google Photos albums & shared libraries"
"providerGooglePhotos": "Google Photos albums & shared libraries",
"providerWebhook": "Receive events via HTTP POST"
},
"error": {
"notFound": "Page not found",
+28 -10
View File
@@ -36,7 +36,8 @@
"targetMatrix": "Matrix",
"targetBroadcast": "Рассылка",
"automation": "Автоматизация",
"actions": "Действия"
"actions": "Действия",
"more": "Ещё"
},
"auth": {
"signIn": "Войти",
@@ -51,7 +52,9 @@
"creatingAccount": "Создание...",
"passwordMismatch": "Пароли не совпадают",
"passwordTooShort": "Пароль должен быть не менее 8 символов",
"or": "или"
"or": "или",
"loginFailed": "Ошибка входа",
"setupFailed": "Ошибка настройки"
},
"dashboard": {
"title": "Главная",
@@ -150,7 +153,9 @@
"gpRefreshTokenHint": "Получите через Google OAuth Playground (developers.google.com/oauthplayground) с областью Photos Library API.",
"gpAllFieldsRequired": "Client ID, Client Secret и Refresh Token обязательны",
"testAndSave": "Проверить и сохранить",
"saveWithoutTest": "Сохранить без проверки"
"saveWithoutTest": "Сохранить без проверки",
"selectType": "Выберите тип провайдера",
"testFailed": "Ошибка проверки подключения"
},
"notificationTracker": {
"title": "Трекеры уведомлений",
@@ -231,7 +236,8 @@
"noLink": "Нет ссылки",
"saveWithoutLinks": "Сохранить без ссылок",
"createLinks": "Создать {count} ссылку(и)",
"linksNote": "Вы также можете создать ссылки вручную в Immich."
"linksNote": "Вы также можете создать ссылки вручную в Immich.",
"createdLinks": "Создано публичных ссылок: {count}"
},
"templates": {
"title": "Шаблоны",
@@ -409,7 +415,9 @@
"cacheTtl": "TTL кэша медиа (часы)",
"cacheTtlHint": "Сколько хранить кэш Telegram file_id перед повторной загрузкой (по умолчанию: 48ч)",
"settingsSaved": "Настройки сохранены",
"noExternalDomain": "Внешний URL домена не настроен"
"noExternalDomain": "Внешний URL домена не настроен",
"saveFailed": "Не удалось сохранить бота",
"webhookFailed": "Не удалось зарегистрировать webhook"
},
"trackingConfig": {
"title": "Конфигурации отслеживания",
@@ -584,7 +592,7 @@
"added_assets": "Список файлов ({% for asset in added_assets %})",
"removed_assets": "Список ID удалённых файлов (строки)",
"shared": "Общий альбом (boolean)",
"target_type": "Тип получателя: 'telegram' или 'webhook'",
"target_type": "Тип получателя: telegram, webhook, email, discord, slack, ntfy или matrix",
"has_videos": "Содержат ли добавленные файлы видео (boolean)",
"has_photos": "Содержат ли добавленные файлы фото (boolean)",
"old_name": "Прежнее название альбома (при переименовании)",
@@ -675,7 +683,8 @@
"displayName": "Отображаемое имя",
"testConnection": "Проверить подключение",
"noBots": "Matrix ботов пока нет.",
"confirmDelete": "Удалить этот Matrix бот?"
"confirmDelete": "Удалить этот Matrix бот?",
"operationFailed": "Операция не удалась"
},
"emailBot": {
"title": "Email боты",
@@ -693,7 +702,8 @@
"useTls": "Использовать TLS/SSL",
"testConnection": "Отправить тестовое письмо",
"noBots": "Email ботов пока нет.",
"confirmDelete": "Удалить этот email бот?"
"confirmDelete": "Удалить этот email бот?",
"operationFailed": "Операция не удалась"
},
"cmdTemplateConfig": {
"title": "Шаблоны команд",
@@ -841,7 +851,12 @@
"allTypes": "Все типы",
"allProviders": "Все провайдеры",
"noFilterResults": "Нет элементов, соответствующих фильтру.",
"redirecting": "Перенаправление..."
"redirecting": "Перенаправление...",
"noMatches": "Ничего не найдено",
"saveFailed": "Не удалось сохранить",
"loadFailed": "Не удалось загрузить данные",
"dismiss": "Закрыть",
"systemSuffix": " (Системный)"
},
"templateSlot": {
"message_assets_added": "Новые файлы добавлены в альбом",
@@ -926,12 +941,15 @@
"previewEmail": "Предпросмотр в формате Email HTML",
"previewDiscord": "Предпросмотр в формате Discord",
"previewSlack": "Предпросмотр в формате Slack",
"previewNtfy": "Предпросмотр уведомления ntfy",
"previewMatrix": "Предпросмотр в формате Matrix HTML",
"providerImmich": "Фотосервер для самостоятельного размещения",
"providerGitea": "Git-сервер для самостоятельного размещения",
"providerPlanka": "Канбан-доска для самостоятельного размещения",
"providerScheduler": "Запланированные сообщения по расписанию",
"providerNut": "Мониторинг ИБП через NUT",
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото"
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото",
"providerWebhook": "Приём событий через HTTP POST"
},
"error": {
"notFound": "Страница не найдена",
+63 -5
View File
@@ -215,15 +215,31 @@
});
}
// Mobile: flatten nav for bottom bar
// Mobile: flatten nav for bottom bar (first 4 + "More" button)
const mobileNavItems = $derived<NavItem[]>([
{ href: '/', key: 'nav.dashboard', icon: 'mdiViewDashboard' },
{ href: '/notification-trackers', key: 'nav.notification', icon: 'mdiBellOutline' },
{ href: '/command-trackers', key: 'nav.commands', icon: 'mdiConsoleLine' },
{ href: '/targets', key: 'nav.targets', icon: 'mdiTarget' },
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
]);
// "More" panel items — everything not in the bottom bar
const mobileMoreItems = $derived<NavItem[]>([
{ href: '/providers', key: 'nav.providers', icon: 'mdiServer' },
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
{ href: '/actions', key: 'nav.actions', icon: 'mdiPlayCircleOutline' },
{ href: '/tracking-configs', key: 'nav.configs', icon: 'mdiCog' },
{ href: '/template-configs', key: 'nav.templates', icon: 'mdiFileDocumentEdit' },
{ href: '/command-configs', key: 'nav.configs', icon: 'mdiConsoleLine' },
{ href: '/command-template-configs', key: 'nav.templates', icon: 'mdiCodeBracesBox' },
...(auth.isAdmin ? [
{ href: '/settings', key: 'nav.settings', icon: 'mdiCogOutline' },
{ href: '/users', key: 'nav.users', icon: 'mdiAccountGroup' },
] : []),
]);
let mobileMoreOpen = $state(false);
const isAuthPage = $derived(
page.url.pathname === '/login' || page.url.pathname === '/setup'
);
@@ -526,12 +542,50 @@
<MdiIcon name={item.icon} size={20} />
</a>
{/each}
<button onclick={logout} aria-label={t('nav.logout')}
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs" style="color: var(--color-muted-foreground);">
<MdiIcon name="mdiLogout" size={20} />
<button onclick={() => openSearch?.()} aria-label={t('searchPalette.placeholder')}
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
style="color: var(--color-muted-foreground);">
<MdiIcon name="mdiMagnify" size={20} />
</button>
<button onclick={() => mobileMoreOpen = !mobileMoreOpen} aria-label={t('nav.more')}
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
style="color: {mobileMoreOpen ? 'var(--color-primary)' : 'var(--color-muted-foreground)'};">
<MdiIcon name="mdiDotsHorizontal" size={20} />
</button>
</nav>
<!-- Mobile "More" panel -->
{#if mobileMoreOpen}
<div class="mobile-more-backdrop" style="position: fixed; inset: 0; z-index: 49; background: rgba(0,0,0,0.4); backdrop-filter: blur(2px);"
onclick={() => mobileMoreOpen = false} role="presentation"></div>
<div class="mobile-more-panel" style="position: fixed; bottom: 3.25rem; left: 0; right: 0; z-index: 50; background: var(--color-sidebar); border-top: 1px solid var(--color-border); border-radius: 1rem 1rem 0 0; padding: 1rem; max-height: 60vh; overflow-y: auto;"
transition:slide={{ duration: 200, easing: cubicOut }}>
{#if allProviders.length > 1}
<div class="mb-3 pb-3" style="border-bottom: 1px solid var(--color-border);">
<IconGridSelect items={providerFilterItems} bind:value={providerFilterValue} columns={Math.min(providerFilterItems.length, 4)} compact />
</div>
{/if}
<div class="grid grid-cols-3 gap-2">
{#each mobileMoreItems as item}
<a href={item.href}
onclick={() => mobileMoreOpen = false}
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
style="color: {isActive(item.href) ? 'var(--color-primary)' : 'var(--color-muted-foreground)'}; background: {isActive(item.href) ? 'var(--color-sidebar-active)' : 'transparent'};"
>
<MdiIcon name={item.icon} size={20} />
<span class="text-xs text-center leading-tight">{t(item.key)}</span>
</a>
{/each}
<button onclick={() => { mobileMoreOpen = false; logout(); }}
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
style="color: var(--color-muted-foreground);">
<MdiIcon name="mdiLogout" size={20} />
<span class="text-xs text-center leading-tight">{t('nav.logout')}</span>
</button>
</div>
</div>
{/if}
<!-- Main content -->
<main class="flex-1 overflow-auto pb-16 md:pb-0">
{#key page.url.pathname}
@@ -579,6 +633,10 @@
<style>
@media (max-width: 767px) {
.mobile-nav { display: flex !important; }
.mobile-more-panel a:hover,
.mobile-more-panel button:hover {
background: var(--color-muted);
}
}
/* Provider filter chips */
+3 -3
View File
@@ -231,7 +231,7 @@
</div>
</Card>
{:else if status}
<div class="grid grid-cols-1 sm:grid-cols-3 gap-4 mb-8 stagger-children">
<div class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-4 mb-8 stagger-children">
{#each statCards as card, i}
<div class="stat-card" style="--accent: {card.color};">
<div class="stat-card-inner">
@@ -289,7 +289,7 @@
<div class="flex items-center justify-center gap-1">
{#if totalPages > 1}
<button onclick={() => goToPage(currentPage - 1)} disabled={currentPage <= 1}
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default">
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
<MdiIcon name="mdiChevronLeft" size={16} />
</button>
{#each Array.from({ length: totalPages }, (_, i) => i + 1) as page}
@@ -305,7 +305,7 @@
{/if}
{/each}
<button onclick={() => goToPage(currentPage + 1)} disabled={currentPage >= totalPages}
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default">
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
<MdiIcon name="mdiChevronRight" size={16} />
</button>
{/if}
+6 -5
View File
@@ -10,6 +10,8 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { EmailBot } from '$lib/types';
let { onreload }: { onreload: () => Promise<void> } = $props();
@@ -72,22 +74,21 @@
try {
const res = await api(`/email-bots/${botId}/test`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.emailBotTestSent'));
else snackError(res.error || 'Failed');
else snackError(res.error || t('emailBot.operationFailed'));
} catch (err: any) { snackError(err.message); }
emailTesting = { ...emailTesting, [botId]: false };
}
</script>
<PageHeader title={t('emailBot.title')} description={t('emailBot.description')}>
<button onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }}>
{showEmailForm ? t('common.cancel') : t('emailBot.addBot')}
</button>
</Button>
</PageHeader>
{#if showEmailForm}
<Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
<ErrorBanner message={error} />
<form onsubmit={saveEmailBot} class="space-y-3">
<div>
<label for="ebot-name" class="block text-sm font-medium mb-1">{t('emailBot.name')}</label>
+6 -5
View File
@@ -10,6 +10,8 @@
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { MatrixBot } from '$lib/types';
let { onreload }: { onreload: () => Promise<void> } = $props();
@@ -70,22 +72,21 @@
try {
const res = await api(`/matrix-bots/${botId}/test`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.matrixBotTestOk'));
else snackError(res.error || 'Failed');
else snackError(res.error || t('matrixBot.operationFailed'));
} catch (err: any) { snackError(err.message); }
matrixTesting = { ...matrixTesting, [botId]: false };
}
</script>
<PageHeader title={t('matrixBot.title')} description={t('matrixBot.description')}>
<button onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }}>
{showMatrixForm ? t('common.cancel') : t('matrixBot.addBot')}
</button>
</Button>
</PageHeader>
{#if showMatrixForm}
<Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
<ErrorBanner message={error} />
<form onsubmit={saveMatrixBot} class="space-y-3">
<div>
<label for="mbot-name" class="block text-sm font-medium mb-1">{t('matrixBot.name')}</label>
@@ -12,6 +12,8 @@
import IconButton from '$lib/components/IconButton.svelte';
import EntitySelect from '$lib/components/EntitySelect.svelte';
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { TelegramBot, TelegramChat } from '$lib/types';
interface CommandTrackerSummary { id: number; name: string; icon?: string; enabled: boolean }
@@ -186,7 +188,7 @@
try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/sync-commands`, { method: 'POST' });
if (res.success) snackSuccess(t('telegramBot.commandsSynced'));
else snackError(res.error || 'Failed');
else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false };
}
@@ -218,7 +220,7 @@
snackSuccess(res.verified ? t('telegramBot.webhookVerified') : t('telegramBot.webhookRegistered'));
await loadWebhookStatus(botId);
} else {
snackError(res.error || 'Failed to register webhook');
snackError(res.error || t('telegramBot.webhookFailed'));
}
} catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false };
@@ -229,7 +231,7 @@
try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/webhook/unregister`, { method: 'POST' });
if (res.success) { snackSuccess(t('telegramBot.webhookUnregistered')); await loadWebhookStatus(botId); }
else snackError(res.error || 'Failed');
else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); }
modeChanging = { ...modeChanging, [botId]: false };
}
@@ -260,7 +262,7 @@
try {
const res = await api<ApiResult>(`/telegram-bots/${botId}/chats/${chatId}/test?locale=${getLocale()}`, { method: 'POST' });
if (res.success) snackSuccess(t('snack.targetTestSent'));
else snackError(res.error || 'Failed');
else snackError(res.error || t('telegramBot.saveFailed'));
} catch (err: any) { snackError(err.message); }
chatTesting = { ...chatTesting, [key]: false };
}
@@ -277,15 +279,14 @@
</script>
<PageHeader title={t('telegramBot.title')} description={t('telegramBot.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
{showForm ? t('common.cancel') : t('telegramBot.addBot')}
</button>
</Button>
</PageHeader>
{#if showForm}
<Card class="mb-6">
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
<ErrorBanner message={error} />
<form onsubmit={saveBot} class="space-y-3">
<div>
<label for="bot-name" class="block text-sm font-medium mb-1">{t('telegramBot.name')}</label>
@@ -15,6 +15,7 @@
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { providerTypeItems, providerTypeFilterItems, responseModeItems } from '$lib/grid-items';
import EntitySelect from '$lib/components/EntitySelect.svelte';
import Button from '$lib/components/Button.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
@@ -37,7 +38,7 @@
let cmdTemplateConfigs = $derived(commandTemplateConfigsCache.items);
const templateItems = $derived(cmdTemplateConfigs
.filter((c) => c.provider_type === form.provider_type)
.map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? ' (System)' : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type }))
.map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? t('common.systemSuffix') : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type }))
);
let loaded = $state(false);
let showForm = $state(false);
@@ -151,10 +152,9 @@
</script>
<PageHeader title={t('commandConfig.title')} description={t('commandConfig.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
{showForm ? t('common.cancel') : t('commandConfig.newConfig')}
</button>
</Button>
</PageHeader>
{#if !loaded}<Loading />{:else}
+1 -1
View File
@@ -32,7 +32,7 @@
await login(username, password);
window.location.href = '/';
} catch (err: any) {
error = err.message || 'Login failed';
error = err.message || t('auth.loginFailed');
}
submitting = false;
}
@@ -17,6 +17,8 @@
import { providerDefaultIcon } from '$lib/grid-items';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { getDescriptor } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { Tracker, TrackerTarget, TrackingConfig, TemplateConfig, NotificationTarget } from '$lib/types';
import TrackerForm from './TrackerForm.svelte';
@@ -119,7 +121,7 @@
capabilitiesCache.fetch(),
]);
} catch (err: any) {
loadError = err.message || 'Failed to load data';
loadError = err.message || t('common.loadFailed');
snackError(loadError);
} finally { loaded = true; highlightFromUrl(); }
}
@@ -212,7 +214,7 @@
}
}
}
if (created > 0) snackSuccess(`Created ${created} public link(s)`);
if (created > 0) snackSuccess(t('notificationTracker.createdLinks').replace('{count}', String(created)));
linkWarning = null;
linkCreating = false;
await doSave();
@@ -361,17 +363,16 @@
</script>
<PageHeader title={t('notificationTracker.title')} description={t('notificationTracker.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
{showForm ? t('notificationTracker.cancel') : t('notificationTracker.newTracker')}
</button>
</Button>
</PageHeader>
{#if !loaded}
<Loading />
{:else if loadError}
<Card>
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3">{loadError}</div>
<ErrorBanner message={loadError} class="mb-0" />
</Card>
{:else if showForm}
<TrackerForm
+7 -9
View File
@@ -12,12 +12,14 @@
import EmptyState from '$lib/components/EmptyState.svelte';
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
import IconButton from '$lib/components/IconButton.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { providerTypeItems, providerDefaultIcon } from '$lib/grid-items';
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import type { ServiceProvider } from '$lib/types';
let allProviders = $derived(providersCache.items);
@@ -136,10 +138,9 @@
</script>
<PageHeader title={t('providers.title')} description={t('providers.description')}>
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
{showForm ? t('providers.cancel') : t('providers.addProvider')}
</button>
</Button>
</PageHeader>
{#if !loaded}
@@ -158,9 +159,7 @@
{#if showForm}
<div in:slide={{ duration: 200 }}>
<Card class="mb-6">
{#if error}
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>
{/if}
<ErrorBanner message={error} />
<form onsubmit={save} class="space-y-3">
<div>
<label class="block text-sm font-medium mb-1">{t('providers.type')}</label>
@@ -211,10 +210,9 @@
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
</div>
{/if}
<button type="submit" disabled={submitting}
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
<Button type="submit" disabled={submitting}>
{submitting ? t('providers.connecting') : (editing ? t('common.save') : t('providers.addProvider'))}
</button>
</Button>
</form>
</Card>
</div>
+17 -18
View File
@@ -5,8 +5,11 @@
import Card from '$lib/components/Card.svelte';
import IconPicker from '$lib/components/IconPicker.svelte';
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
import { goto } from '$app/navigation';
import { providerTypeItems } from '$lib/grid-items';
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
import Button from '$lib/components/Button.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
let form = $state(buildProviderFormDefaults());
let error = $state('');
@@ -16,7 +19,7 @@
async function testAndSave() {
const desc = descriptor;
if (!desc) { error = 'Select a provider type'; return; }
if (!desc) { error = t('providers.selectType'); return; }
const { config, error: buildError } = desc.buildConfig(form, false);
if (buildError) { error = t(buildError); snackError(error); return; }
@@ -32,22 +35,22 @@
if (!result.ok) {
await api(`/providers/${provider.id}`, { method: 'DELETE' }).catch(() => {});
createdId = null;
error = result.message || 'Connection test failed';
error = result.message || t('providers.testFailed');
snackError(error);
} else {
snackSuccess(t('snack.providerSaved'));
window.location.href = '/providers';
goto('/providers');
}
} catch (e: any) {
if (createdId) await api(`/providers/${createdId}`, { method: 'DELETE' }).catch(() => {});
error = e.message || 'Test failed'; snackError(error);
error = e.message || t('providers.testFailed'); snackError(error);
}
finally { testing = false; }
}
async function saveWithoutTest() {
const desc = descriptor;
if (!desc) { error = 'Select a provider type'; return; }
if (!desc) { error = t('providers.selectType'); return; }
const { config, error: buildError } = desc.buildConfig(form, false);
if (buildError) { error = t(buildError); snackError(error); return; }
@@ -58,8 +61,8 @@
body: JSON.stringify({ type: form.type, name: form.name || desc.defaultName, icon: form.icon, config }),
});
snackSuccess(t('snack.providerSaved'));
window.location.href = '/providers';
} catch (e: any) { error = e.message || 'Save failed'; snackError(error); }
goto('/providers');
} catch (e: any) { error = e.message || t('common.saveFailed'); snackError(error); }
finally { saving = false; }
}
</script>
@@ -112,22 +115,18 @@
</div>
{/each}
{#if error}
<p class="text-sm text-[var(--color-error-fg)]">{error}</p>
{/if}
<ErrorBanner message={error} />
<div class="flex gap-3 pt-2">
<button onclick={testAndSave} disabled={testing || saving}
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
<Button onclick={testAndSave} disabled={testing || saving}>
{testing ? t('providers.connecting') : t('providers.testAndSave')}
</button>
<button onclick={saveWithoutTest} disabled={testing || saving}
class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-foreground)] rounded-md text-sm font-medium hover:opacity-80 disabled:opacity-50">
</Button>
<Button variant="secondary" onclick={saveWithoutTest} disabled={testing || saving}>
{saving ? t('common.loading') : t('providers.saveWithoutTest')}
</button>
<a href="/providers" class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-muted-foreground)] rounded-md text-sm font-medium hover:opacity-80">
</Button>
<Button variant="secondary" href="/providers">
{t('common.cancel')}
</a>
</Button>
</div>
</div>
</Card>
+6 -3
View File
@@ -7,10 +7,12 @@
import Loading from '$lib/components/Loading.svelte';
import MdiIcon from '$lib/components/MdiIcon.svelte';
import Hint from '$lib/components/Hint.svelte';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
let loaded = $state(false);
let saving = $state(false);
let error = $state('');
let settings = $state({
external_url: '',
telegram_webhook_secret: '',
@@ -20,16 +22,16 @@
onMount(async () => {
try {
settings = await api('/settings');
} catch (err: any) { snackError(err.message); }
} catch (err: any) { error = err.message; snackError(err.message); }
finally { loaded = true; }
});
async function save() {
saving = true;
saving = true; error = '';
try {
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
snackSuccess(t('settings.saved'));
} catch (err: any) { snackError(err.message); }
} catch (err: any) { error = err.message; snackError(err.message); }
saving = false;
}
</script>
@@ -39,6 +41,7 @@
{#if !loaded}
<Loading />
{:else}
<ErrorBanner message={error} />
<div class="space-y-6">
<!-- General section -->
<Card>
+1 -1
View File
@@ -25,7 +25,7 @@
try {
await setup(username, password);
window.location.href = '/';
} catch (err: any) { error = err.message || 'Setup failed'; }
} catch (err: any) { error = err.message || t('auth.setupFailed'); }
submitting = false;
}
</script>
+2 -1
View File
@@ -15,6 +15,7 @@
import { chatActionItems } from '$lib/grid-items';
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
import { highlightFromUrl } from '$lib/highlight';
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
import type { NotificationTarget, TargetReceiver, TelegramChat } from '$lib/types';
import TargetForm from './TargetForm.svelte';
@@ -419,7 +420,7 @@
{#if !loaded}<Loading />{:else}
{#if loadError}
<div class="mb-4 p-3 rounded-md text-sm bg-[var(--color-error-bg)] text-[var(--color-error-fg)]">{loadError}</div>
<ErrorBanner message={loadError} />
{/if}
{#if showForm}
@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any
@@ -68,14 +69,17 @@ class NotificationDispatcher:
Returns list of results (one per target).
"""
raw_results = await asyncio.gather(
*[self._send_to_target(event, t) for t in targets],
return_exceptions=True,
)
results = []
for target in targets:
try:
result = await self._send_to_target(event, target)
results.append(result)
except Exception as e:
_LOGGER.error("Failed to dispatch to target: %s", e)
results.append({"success": False, "error": str(e)})
for raw in raw_results:
if isinstance(raw, Exception):
_LOGGER.error("Failed to dispatch to target: %s", raw)
results.append({"success": False, "error": str(raw)})
else:
results.append(raw)
return results
def _resolve_template(
@@ -85,6 +85,20 @@ class GiteaClient:
return repos
async def get_repo(self, owner: str, repo: str) -> dict[str, Any] | None:
"""Fetch a single repository by owner/repo name."""
try:
async with self._session.get(
f"{self._url}/api/v1/repos/{owner}/{repo}",
headers=self._headers,
) as response:
if response.status == 200:
return await response.json()
_LOGGER.warning("Failed to fetch repo %s/%s: HTTP %s", owner, repo, response.status)
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, err)
return None
async def get_repo_issues(
self, owner: str, repo: str, state: str = "open", limit: int = 10,
) -> list[dict[str, Any]]:
@@ -14,12 +14,28 @@ _DEFAULT_PORT = 3493
_READ_TIMEOUT = 10.0
_CONNECT_TIMEOUT = 5.0
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
# Prevents command injection via newlines or special characters.
_SAFE_NAME_RE = re.compile(r"^[\w.\-]+$")
# Regex to parse VAR lines: VAR <ups> <name> "<value>"
_VAR_RE = re.compile(r'^VAR\s+(\S+)\s+(\S+)\s+"(.*)"$')
# Regex to parse UPS lines: UPS <name> "<description>"
_UPS_RE = re.compile(r'^UPS\s+(\S+)\s+"(.*)"$')
def _validate_name(value: str, label: str) -> None:
"""Validate that *value* is a safe NUT protocol identifier.
Raises ``NutClientError`` if *value* contains characters outside
``[\\w.\\-]``, which could be used for protocol command injection.
"""
if not _SAFE_NAME_RE.match(value):
raise NutClientError(
f"Invalid {label}: {value!r} contains disallowed characters"
)
class NutClientError(Exception):
"""Error communicating with NUT server."""
@@ -91,6 +107,7 @@ class NutClient:
async def list_var(self, ups_name: str) -> dict[str, str]:
"""Get all variables for a UPS device."""
_validate_name(ups_name, "UPS name")
lines = await self._list_command(f"LIST VAR {ups_name}")
variables: dict[str, str] = {}
for line in lines:
@@ -101,6 +118,8 @@ class NutClient:
async def get_var(self, ups_name: str, var_name: str) -> str:
"""Get a single variable value."""
_validate_name(ups_name, "UPS name")
_validate_name(var_name, "variable name")
response = await self._command(f"GET VAR {ups_name} {var_name}")
m = _VAR_RE.match(response)
if m:
@@ -10,7 +10,7 @@ from jinja2.sandbox import SandboxedEnvironment
_LOGGER = logging.getLogger(__name__)
_env = SandboxedEnvironment(autoescape=False)
_env = SandboxedEnvironment(autoescape=True)
def render_template(template_str: str, context: dict[str, Any]) -> str:
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import Action, ActionRule, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -59,10 +60,9 @@ def _rule_response(rule: ActionRule) -> dict:
async def _get_user_action(
session: AsyncSession, action_id: int, user: User
) -> Action:
action = await session.get(Action, action_id)
if not action or action.user_id != user.id:
raise HTTPException(status_code=404, detail="Action not found")
return action
return await get_owned_entity(
session, Action, action_id, user.id, not_found_msg="Action not found",
)
# ---------------------------------------------------------------------------
@@ -12,12 +12,10 @@ from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User
from .slot_helpers import load_slots, render_template_preview, save_slots
_LOGGER = logging.getLogger(__name__)
@@ -44,38 +42,11 @@ class CommandTemplateConfigUpdate(BaseModel):
# ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
"""Load slots as {slot_name: {locale: template}}."""
result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id)
)
nested: dict[str, dict[str, str]] = {}
for s in result.all():
nested.setdefault(s.slot_name, {})[s.locale] = s.template
return nested
return await load_slots(session, CommandTemplateSlot, config_id)
async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None:
"""Save slots from {slot_name: {locale: template}} format."""
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config_id,
CommandTemplateSlot.slot_name == slot_name,
CommandTemplateSlot.locale == locale,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
await save_slots(session, CommandTemplateSlot, config_id, slots)
async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]:
@@ -367,18 +338,4 @@ async def preview_raw(
"wait": 15,
}
try:
env = SandboxedEnvironment(autoescape=False)
env.from_string(body.template)
except TemplateSyntaxError as e:
return {"rendered": None, "error": e.message, "error_line": e.lineno}
try:
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(body.template)
rendered = tmpl.render(**sample_ctx)
return {"rendered": rendered}
except UndefinedError as e:
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
return render_template_preview(body.template, sample_ctx)
@@ -17,6 +17,7 @@ from ..database.models import (
TelegramBot,
User,
)
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -401,7 +402,7 @@ async def _listener_response(session: AsyncSession, l: CommandTrackerListener) -
async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int
) -> CommandTracker:
tracker = await session.get(CommandTracker, tracker_id)
if not tracker or tracker.user_id != user_id:
raise HTTPException(status_code=404, detail="Command tracker not found")
return tracker
return await get_owned_entity(
session, CommandTracker, tracker_id, user_id,
not_found_msg="Command tracker not found",
)
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import EmailBot, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -156,7 +157,6 @@ def _response(bot: EmailBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot:
bot = await session.get(EmailBot, bot_id)
if not bot or bot.user_id != user_id:
raise HTTPException(status_code=404, detail="Email bot not found")
return bot
return await get_owned_entity(
session, EmailBot, bot_id, user_id, not_found_msg="Email bot not found",
)
@@ -0,0 +1,25 @@
"""Shared helpers for API route modules."""
from typing import TypeVar
from fastapi import HTTPException
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
T = TypeVar("T", bound=SQLModel)
async def get_owned_entity(
session: AsyncSession,
model: type[T],
entity_id: int,
user_id: int,
*,
owner_field: str = "user_id",
not_found_msg: str = "Not found",
) -> T:
"""Fetch an entity by PK and verify ownership, or raise 404."""
entity = await session.get(model, entity_id)
if not entity or getattr(entity, owner_field) != user_id:
raise HTTPException(status_code=404, detail=not_found_msg)
return entity
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import MatrixBot, User
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -108,33 +109,34 @@ async def test_matrix_bot(
bot = await _get_user_bot(session, bot_id, user.id)
import aiohttp
async with aiohttp.ClientSession() as http:
# Verify token with /whoami
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
headers = {"Authorization": f"Bearer {bot.access_token}"}
try:
async with http.get(whoami_url, headers=headers) as resp:
if resp.status != 200:
body = await resp.text()
return {"success": False, "error": f"Auth failed: HTTP {resp.status}{body[:200]}"}
whoami = await resp.json()
except aiohttp.ClientError as e:
return {"success": False, "error": f"Connection failed: {e}"}
from ..services.http_session import get_http_session
http = await get_http_session()
# Verify token with /whoami
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
headers = {"Authorization": f"Bearer {bot.access_token}"}
try:
async with http.get(whoami_url, headers=headers) as resp:
if resp.status != 200:
body = await resp.text()
return {"success": False, "error": f"Auth failed: HTTP {resp.status}{body[:200]}"}
whoami = await resp.json()
except aiohttp.ClientError as e:
return {"success": False, "error": f"Connection failed: {e}"}
result = {"success": True, "user_id": whoami.get("user_id", "")}
result = {"success": True, "user_id": whoami.get("user_id", "")}
# Optionally send a test message
if room_id:
from notify_bridge_core.notifications.matrix.client import MatrixClient
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
send_result = await client.send_message(
room_id,
"Test message from Notify Bridge",
html_message="<b>Test message</b> from Notify Bridge",
)
result["send_result"] = send_result
# Optionally send a test message
if room_id:
from notify_bridge_core.notifications.matrix.client import MatrixClient
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
send_result = await client.send_message(
room_id,
"Test message from Notify Bridge",
html_message="<b>Test message</b> from Notify Bridge",
)
result["send_result"] = send_result
return result
return result
def _response(bot: MatrixBot) -> dict:
@@ -150,7 +152,6 @@ def _response(bot: MatrixBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot:
bot = await session.get(MatrixBot, bot_id)
if not bot or bot.user_id != user_id:
raise HTTPException(status_code=404, detail="Matrix bot not found")
return bot
return await get_owned_entity(
session, MatrixBot, bot_id, user_id, not_found_msg="Matrix bot not found",
)
@@ -23,6 +23,7 @@ from ..database.models import (
)
from ..services.notifier import send_test_notification
from ..services.test_dispatch import dispatch_test_notification
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -277,7 +278,7 @@ async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) ->
async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int
) -> NotificationTracker:
tracker = await session.get(NotificationTracker, tracker_id)
if not tracker or tracker.user_id != user_id:
raise HTTPException(status_code=404, detail="Tracker not found")
return tracker
return await get_owned_entity(
session, NotificationTracker, tracker_id, user_id,
not_found_msg="Tracker not found",
)
@@ -18,6 +18,7 @@ from ..database.models import (
User,
)
from ..services.scheduler import schedule_tracker, unschedule_tracker
from .helpers import get_owned_entity
from .notification_tracker_targets import _tt_response
_LOGGER = logging.getLogger(__name__)
@@ -205,7 +206,7 @@ async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> di
async def _get_user_tracker(
session: AsyncSession, tracker_id: int, user_id: int
) -> NotificationTracker:
tracker = await session.get(NotificationTracker, tracker_id)
if not tracker or tracker.user_id != user_id:
raise HTTPException(status_code=404, detail="Tracker not found")
return tracker
return await get_owned_entity(
session, NotificationTracker, tracker_id, user_id,
not_found_msg="Tracker not found",
)
@@ -13,7 +13,12 @@ import aiohttp
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import ServiceProvider, User
from ..services import make_immich_provider, make_gitea_provider, make_planka_provider, make_nut_provider, make_google_photos_provider
from ..services import (
make_immich_provider, make_gitea_provider, make_planka_provider,
make_nut_provider, make_google_photos_provider, list_provider_collections,
)
from ..services.http_session import get_http_session
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -82,6 +87,20 @@ class GooglePhotosProviderConfig(BaseModel):
refresh_token: str
class PayloadMapping(BaseModel):
variable: str
jsonpath: str
default: str | None = None
class WebhookProviderConfig(BaseModel):
auth_mode: str = "none"
webhook_secret: str | None = None
payload_mappings: list[PayloadMapping] = []
event_type_path: str | None = None
collection_path: str | None = None
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"immich": ImmichProviderConfig,
"gitea": GiteaProviderConfig,
@@ -89,6 +108,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"scheduler": SchedulerProviderConfig,
"nut": NutProviderConfig,
"google_photos": GooglePhotosProviderConfig,
"webhook": WebhookProviderConfig,
}
@@ -106,6 +126,70 @@ def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> Non
)
async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
"""Test provider connection and return the result dict.
For providers that lack optional credentials (gitea without api_token,
planka without api_key), returns a success stub.
"""
http_session = await get_http_session()
if provider.type == "immich":
immich = make_immich_provider(http_session, provider)
return await immich.test_connection()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
gitea = make_gitea_provider(http_session, provider)
return await gitea.test_connection()
if provider.type == "planka":
if not provider.config.get("api_key"):
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
planka = make_planka_provider(http_session, provider)
return await planka.test_connection()
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.test_connection()
if provider.type == "google_photos":
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
if provider.type in ("scheduler", "webhook"):
return {"ok": True, "message": "Virtual provider — always available"}
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
"""Test provider connection. Raise HTTPException on failure.
Returns the test_result dict on success (caller may inspect extra fields
like ``external_domain``).
"""
try:
test_result = await _test_provider_connection(provider)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
except OSError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
)
return test_result
@router.get("")
async def list_providers(
user: User = Depends(get_current_user),
@@ -128,96 +212,15 @@ async def create_provider(
"""Add a new service provider (validates connection for known types)."""
_validate_provider_config(body.type, body.config)
# Validate connection for known provider types
try:
if body.type == "immich":
from notify_bridge_core.providers.immich import ImmichServiceProvider
config = body.config
async with aiohttp.ClientSession() as http_session:
immich = ImmichServiceProvider(
http_session, config.get("url", ""), config.get("api_key", ""),
config.get("external_domain"), body.name,
)
test_result = await immich.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {body.type} provider"),
)
# Store external_domain from server config if available
if test_result.get("external_domain"):
config["external_domain"] = test_result["external_domain"]
# Build a temporary ServiceProvider for connection testing
temp_provider = ServiceProvider(
id=0, user_id=0, type=body.type, name=body.name, config=body.config,
)
test_result = await _validate_provider_connection(temp_provider)
elif body.type == "gitea":
config = body.config
# api_token is optional (webhook_secret is required, but token only for repo listing)
if config.get("api_token"):
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.gitea import GiteaServiceProvider
gitea = GiteaServiceProvider(
http_session, config.get("url", ""), config.get("api_token", ""), body.name,
)
test_result = await gitea.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Gitea"),
)
elif body.type == "planka":
config = body.config
if config.get("api_key"):
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.planka import PlankaServiceProvider
planka = PlankaServiceProvider(
http_session, config.get("url", ""), config.get("api_key", ""), body.name,
)
test_result = await planka.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Planka"),
)
elif body.type == "nut":
nut = make_nut_provider(ServiceProvider(
id=0, user_id=0, type="nut", name=body.name, config=body.config,
))
test_result = await nut.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to NUT server"),
)
elif body.type == "google_photos":
config = body.config
async with aiohttp.ClientSession() as http_session:
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
gp = GooglePhotosServiceProvider(
http_session, config.get("client_id", ""), config.get("client_secret", ""),
config.get("refresh_token", ""), body.name,
)
test_result = await gp.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Google Photos"),
)
except HTTPException:
raise
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
except OSError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
# Scheduler: no validation needed (virtual provider)
# Store external_domain from Immich server config if available
if test_result.get("external_domain"):
body.config["external_domain"] = test_result["external_domain"]
provider = ServiceProvider(
user_id=user.id,
@@ -307,78 +310,10 @@ async def update_provider(
provider.config = body.config
# Re-validate connection when config changes for known provider types
if config_changed and provider.type == "immich":
try:
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
test_result = await immich.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
)
if test_result.get("external_domain"):
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "gitea":
if provider.config.get("api_token"):
try:
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
test_result = await gitea.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Gitea"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "planka":
if provider.config.get("api_key"):
try:
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
test_result = await planka.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Planka"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
elif config_changed and provider.type == "nut":
nut = make_nut_provider(provider)
test_result = await nut.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to NUT server"),
)
elif config_changed and provider.type == "google_photos":
try:
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
test_result = await gp.test_connection()
if not test_result.get("ok"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=test_result.get("message", "Cannot connect to Google Photos"),
)
except aiohttp.ClientError as err:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Connection error: {err}",
)
if config_changed:
test_result = await _validate_provider_connection(provider)
if test_result.get("external_domain"):
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
session.add(provider)
await session.commit()
@@ -408,39 +343,7 @@ async def test_provider(
):
"""Check if a service provider is reachable."""
provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich":
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
return await immich.test_connection()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
return await gitea.test_connection()
if provider.type == "planka":
if not provider.config.get("api_key"):
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
return await planka.test_connection()
if provider.type == "scheduler":
return {"ok": True, "message": "Virtual provider — always available"}
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.test_connection()
if provider.type == "google_photos":
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
return await _test_provider_connection(provider)
@router.get("/{provider_id}/people")
@@ -454,14 +357,14 @@ async def list_people(
if provider.type == "immich":
from notify_bridge_core.providers.immich.client import ImmichClient
async with aiohttp.ClientSession() as http_session:
client = ImmichClient(
http_session,
provider.config.get("url", ""),
provider.config.get("api_key", ""),
)
people = await client.get_people()
return [{"id": pid, "name": name} for pid, name in people.items()]
http_session = await get_http_session()
client = ImmichClient(
http_session,
provider.config.get("url", ""),
provider.config.get("api_key", ""),
)
people = await client.get_people()
return [{"id": pid, "name": name} for pid, name in people.items()]
return []
@@ -475,35 +378,7 @@ async def list_collections(
"""Fetch collections from a service provider."""
provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich":
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
return await immich.list_collections()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return []
async with aiohttp.ClientSession() as http_session:
gitea = make_gitea_provider(http_session, provider)
return await gitea.list_collections()
if provider.type == "planka":
if not provider.config.get("api_key"):
return []
async with aiohttp.ClientSession() as http_session:
planka = make_planka_provider(http_session, provider)
return await planka.list_collections()
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.list_collections()
if provider.type == "google_photos":
async with aiohttp.ClientSession() as http_session:
gp = make_google_photos_provider(http_session, provider)
return await gp.list_collections()
return []
return await list_provider_collections(provider)
@router.get("/{provider_id}/albums/{album_id}/shared-links")
@@ -517,19 +392,19 @@ async def get_album_shared_links(
provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich":
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
links = await immich.client.get_shared_links(album_id)
return [
{
"id": link.id,
"key": link.key,
"has_password": link.has_password,
"is_expired": link.is_expired,
"is_accessible": link.is_accessible,
}
for link in links
]
http_session = await get_http_session()
immich = make_immich_provider(http_session, provider)
links = await immich.client.get_shared_links(album_id)
return [
{
"id": link.id,
"key": link.key,
"has_password": link.has_password,
"is_expired": link.is_expired,
"is_accessible": link.is_accessible,
}
for link in links
]
return []
@@ -545,15 +420,13 @@ async def create_album_shared_link(
provider = await _get_user_provider(session, provider_id, user.id)
if provider.type == "immich":
async with aiohttp.ClientSession() as http_session:
immich = make_immich_provider(http_session, provider)
success = await immich.client.create_shared_link(album_id)
if success:
return {"success": True}
from fastapi import HTTPException
raise HTTPException(status_code=400, detail="Failed to create shared link")
http_session = await get_http_session()
immich = make_immich_provider(http_session, provider)
success = await immich.client.create_shared_link(album_id)
if success:
return {"success": True}
raise HTTPException(status_code=400, detail="Failed to create shared link")
from fastapi import HTTPException
raise HTTPException(status_code=400, detail="Provider type does not support shared links")
@@ -580,7 +453,7 @@ async def _get_user_provider(
session: AsyncSession, provider_id: int, user_id: int
) -> ServiceProvider:
"""Get a provider owned by the user, or raise 404."""
provider = await session.get(ServiceProvider, provider_id)
if not provider or provider.user_id != user_id:
raise HTTPException(status_code=404, detail="Provider not found")
return provider
return await get_owned_entity(
session, ServiceProvider, provider_id, user_id,
not_found_msg="Provider not found",
)
@@ -0,0 +1,90 @@
"""Shared slot load/save and Jinja2 preview helpers for template config APIs."""
from typing import TypeVar
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
from jinja2.sandbox import SandboxedEnvironment
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
S = TypeVar("S", bound=SQLModel)
async def load_slots(
session: AsyncSession,
slot_model: type[S],
config_id: int,
) -> dict[str, dict[str, str]]:
"""Load all template slots for a config as {slot_name: {locale: template}}.
Works for both TemplateSlot and CommandTemplateSlot they share the same
column names (config_id, slot_name, locale, template).
"""
result = await session.exec(
select(slot_model).where(slot_model.config_id == config_id) # type: ignore[attr-defined]
)
slots: dict[str, dict[str, str]] = {}
for s in result.all():
slots.setdefault(s.slot_name, {})[s.locale] = s.template # type: ignore[attr-defined]
return slots
async def save_slots(
session: AsyncSession,
slot_model: type[S],
config_id: int,
slots: dict[str, dict[str, str]],
) -> None:
"""Create or update template slots for a config (locale-aware).
Works for both TemplateSlot and CommandTemplateSlot.
"""
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(slot_model).where(
slot_model.config_id == config_id, # type: ignore[attr-defined]
slot_model.slot_name == slot_name, # type: ignore[attr-defined]
slot_model.locale == locale, # type: ignore[attr-defined]
)
)
existing = result.first()
if existing:
existing.template = template_text # type: ignore[attr-defined]
session.add(existing)
else:
session.add(slot_model(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
def render_template_preview(template: str, context: dict) -> dict:
"""Two-pass Jinja2 render: syntax check, then strict render.
Returns a dict with either ``{"rendered": str}`` on success, or
``{"rendered": None, "error": str, ...}`` on failure.
"""
# Pass 1: syntax check (default Undefined — catches parse errors only)
try:
env = SandboxedEnvironment(autoescape=False)
env.from_string(template)
except TemplateSyntaxError as e:
return {
"rendered": None,
"error": e.message,
"error_line": e.lineno,
}
# Pass 2: render with StrictUndefined to catch unknown variables
try:
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(template)
rendered = tmpl.render(**context)
return {"rendered": rendered}
except UndefinedError as e:
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
@@ -112,8 +112,16 @@ async def get_nav_counts(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
"""Return entity counts for sidebar navigation badges."""
counts = {}
"""Return entity counts for sidebar navigation badges.
Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
for concurrent use within a single session (no asyncio.gather). We
minimise round-trips by combining user + system counts and per-type
target counts into single aggregate queries where possible.
"""
counts: dict[str, int] = {}
# --- 1) User-owned entity counts (one query per model) ---
for model, key in [
(ServiceProvider, "providers"),
(NotificationTracker, "notification_trackers"),
@@ -132,7 +140,7 @@ async def get_nav_counts(
)).one()
counts[key] = count
# System-owned entities (user_id=0) count as well
# --- 2) Add system-owned counts (user_id=0) for shared entities ---
for model, key in [
(TemplateConfig, "template_configs"),
(CommandTemplateConfig, "command_template_configs"),
@@ -144,15 +152,22 @@ async def get_nav_counts(
)).one()
counts[key] += system_count
# Per-type target counts for nav badges
for target_type in ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"):
type_count = (await session.exec(
select(func.count()).select_from(NotificationTarget).where(
NotificationTarget.user_id == user.id,
NotificationTarget.type == target_type,
)
)).one()
counts[f"targets_{target_type}"] = type_count
# --- 3) Per-type target counts in a single query using conditional aggregation ---
target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
type_counts_result = (await session.exec(
select(
NotificationTarget.type,
func.count(),
)
.where(
NotificationTarget.user_id == user.id,
NotificationTarget.type.in_(target_types),
)
.group_by(NotificationTarget.type)
)).all()
type_counts_map = dict(type_counts_result)
for target_type in target_types:
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
return counts
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import NotificationTarget, TargetReceiver, User
from ..services.notifier import send_to_receiver
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -170,7 +171,7 @@ def _response(r: TargetReceiver) -> dict:
async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget:
target = await session.get(NotificationTarget, target_id)
if not target or target.user_id != user_id:
raise HTTPException(status_code=404, detail="Target not found")
return target
return await get_owned_entity(
session, NotificationTarget, target_id, user_id,
not_found_msg="Target not found",
)
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User
from ..services.notifier import send_test_notification
from .helpers import get_owned_entity
from .target_receivers import _receiver_key
_LOGGER = logging.getLogger(__name__)
@@ -306,8 +307,15 @@ async def _validate_broadcast_children(
return
if exclude_target_id and exclude_target_id in child_ids:
raise HTTPException(status_code=400, detail="A broadcast target cannot include itself")
# Batch-load all children in a single IN query instead of N+1 individual fetches
children = (await session.exec(
select(NotificationTarget).where(NotificationTarget.id.in_(child_ids))
)).all()
children_by_id = {c.id: c for c in children}
for child_id in child_ids:
child = await session.get(NotificationTarget, child_id)
child = children_by_id.get(child_id)
if not child or child.user_id != user_id:
raise HTTPException(status_code=400, detail=f"Child target {child_id} not found")
if child.type == "broadcast":
@@ -378,7 +386,7 @@ def _safe_config(target: NotificationTarget) -> dict:
async def _get_user_target(
session: AsyncSession, target_id: int, user_id: int
) -> NotificationTarget:
target = await session.get(NotificationTarget, target_id)
if not target or target.user_id != user_id:
raise HTTPException(status_code=404, detail="Target not found")
return target
return await get_owned_entity(
session, NotificationTarget, target_id, user_id,
not_found_msg="Target not found",
)
@@ -7,8 +7,6 @@ from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
import aiohttp
from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..auth.dependencies import get_current_user
@@ -19,6 +17,7 @@ from ..database.models import AppSetting, NotificationTarget, TargetReceiver, Te
from ..services.notifier import _get_test_message
from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling
from .app_settings import get_setting
from .helpers import get_owned_entity
_LOGGER = logging.getLogger(__name__)
@@ -290,10 +289,11 @@ async def test_chat(
):
"""Send a test message to a chat via the bot."""
bot = await _get_user_bot(session, bot_id, user.id)
from ..services.http_session import get_http_session
message = _get_test_message(locale, "telegram")
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot.token)
return await client.send_message(chat_id, message)
http = await get_http_session()
client = TelegramClient(http, bot.token)
return await client.send_message(chat_id, message)
class ChatUpdate(BaseModel):
@@ -344,41 +344,44 @@ async def delete_chat(
async def _get_webhook_info(token: str) -> dict | None:
"""Call Telegram getWebhookInfo via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, token)
result = await client.get_webhook_info()
return result.get("result") if result.get("success") else None
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, token)
result = await client.get_webhook_info()
return result.get("result") if result.get("success") else None
async def _get_me(token: str) -> dict | None:
"""Call Telegram getMe via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, token)
result = await client.get_me()
return result.get("result") if result.get("success") else None
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, token)
result = await client.get_me()
return result.get("result") if result.get("success") else None
async def _fetch_chats_from_telegram(token: str) -> list[dict]:
"""Fetch chats from Telegram getUpdates via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, token)
result = await client.get_updates(limit=100)
if not result.get("success"):
return []
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, token)
result = await client.get_updates(limit=100)
if not result.get("success"):
return []
seen: dict[int, dict] = {}
for update in result.get("result", []):
msg = update.get("message", {})
chat = msg.get("chat", {})
chat_id = chat.get("id")
if chat_id and chat_id not in seen:
seen[chat_id] = {
"id": chat_id,
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
"type": chat.get("type", "private"),
"username": chat.get("username", ""),
}
return list(seen.values())
seen: dict[int, dict] = {}
for update in result.get("result", []):
msg = update.get("message", {})
chat = msg.get("chat", {})
chat_id = chat.get("id")
if chat_id and chat_id not in seen:
seen[chat_id] = {
"id": chat_id,
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
"type": chat.get("type", "private"),
"username": chat.get("username", ""),
}
return list(seen.values())
def _chat_response(c: TelegramChat) -> dict:
@@ -410,10 +413,9 @@ def _bot_response(b: TelegramBot) -> dict:
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot:
bot = await session.get(TelegramBot, bot_id)
if not bot or bot.user_id != user_id:
raise HTTPException(status_code=404, detail="Bot not found")
return bot
return await get_owned_entity(
session, TelegramBot, bot_id, user_id, not_found_msg="Bot not found",
)
@@ -13,12 +13,12 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import TemplateConfig, TemplateSlot, User
from ..services.sample_context import _SAMPLE_CONTEXT
from .slot_helpers import load_slots, render_template_preview, save_slots
_LOGGER = logging.getLogger(__name__)
@@ -49,40 +49,13 @@ class TemplateConfigUpdate(BaseModel):
# ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
"""Load all template slots for a config as {slot_name: {locale: template}}."""
result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == config_id)
)
slots: dict[str, dict[str, str]] = {}
for s in result.all():
slots.setdefault(s.slot_name, {})[s.locale] = s.template
return slots
return await load_slots(session, TemplateSlot, config_id)
async def _save_slots(
session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]
) -> None:
"""Create or update template slots for a config (locale-aware)."""
for slot_name, locale_map in slots.items():
for locale, template_text in locale_map.items():
result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config_id,
TemplateSlot.slot_name == slot_name,
TemplateSlot.locale == locale,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config_id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
await save_slots(session, TemplateSlot, config_id, slots)
async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]:
@@ -155,7 +128,7 @@ async def get_template_variables(
"photo_count": "Total photo count in album",
"video_count": "Total video count in album",
"owner": "Album owner name",
"target_type": "Target type: 'telegram' or 'webhook'",
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
"has_videos": "Whether added assets contain videos (boolean)",
"has_photos": "Whether added assets contain photos (boolean)",
"has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)",
@@ -206,7 +179,7 @@ async def get_template_variables(
}
scheduled_vars = {
"date": "Current date string",
"target_type": "Target type: 'telegram' or 'webhook'",
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
}
return {
@@ -284,7 +257,7 @@ def _webhook_variables() -> dict:
"source_ip": "IP address of the webhook sender",
"raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)",
"timestamp": "When the webhook was received",
"target_type": "Target type: 'telegram' or 'webhook'",
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
},
},
}
@@ -529,7 +502,7 @@ async def preview_date_format(
class PreviewRequest(BaseModel):
template: str
target_type: str = "telegram" # "telegram" or "webhook"
target_type: str = "telegram" # telegram, webhook, email, discord, slack, ntfy, matrix
date_format: str = "%d.%m.%Y, %H:%M UTC"
date_only_format: str = "%d.%m.%Y"
@@ -545,33 +518,12 @@ async def preview_raw(
1. Parse with default Undefined (catches syntax errors)
2. Render with StrictUndefined (catches unknown variables like {{ asset.a }})
"""
# Pass 1: syntax check
from datetime import datetime
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
"date_format": body.date_format, "date_only_format": body.date_only_format}
# Format common_date using the provided date_only_format
try:
env = SandboxedEnvironment(autoescape=False)
env.from_string(body.template)
except TemplateSyntaxError as e:
return {
"rendered": None,
"error": e.message,
"error_line": e.lineno,
}
# Pass 2: render with strict undefined to catch unknown variables
try:
from datetime import datetime
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
"date_format": body.date_format, "date_only_format": body.date_only_format}
# Format common_date using the provided date_only_format
try:
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
except (ValueError, TypeError):
ctx["common_date"] = "19.03.2026"
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(body.template)
rendered = tmpl.render(**ctx)
return {"rendered": rendered}
except UndefinedError as e:
# Still a valid template syntactically, but references unknown variable
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
except (ValueError, TypeError):
ctx["common_date"] = "19.03.2026"
return render_template_preview(body.template, ctx)
@@ -3,9 +3,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
@dataclass(frozen=True)
class CommandResponse:
"""A single response from one tracker's command execution."""
text: str | None = None
media: list[dict[str, Any]] = field(default_factory=list)
class ProviderCommandHandler(ABC):
@@ -14,6 +23,8 @@ class ProviderCommandHandler(ABC):
Each provider (Immich, Gitea, etc.) implements this interface to handle
its own set of commands. The dispatch layer routes commands to the
correct handler based on the provider type.
Each handler call receives a single (tracker, config, provider) context.
"""
provider_type: str
@@ -35,26 +46,28 @@ class ProviderCommandHandler(ABC):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
"""Handle a provider-specific command.
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
"""Handle a provider-specific command for a single tracker.
Args:
cmd: The command name (without '/').
args: Arguments after the command.
count: Number of results to return.
locale: User's locale ('en', 'ru').
response_mode: 'media' or 'text'.
providers_map: Provider instances keyed by ID.
cmd_templates: Template slots {slot_name: {locale: template}}.
response_mode: 'media' or 'text' (from this tracker's config).
provider: The service provider instance for this tracker.
cmd_templates: Template slots for this tracker's command template config.
bot: The Telegram bot instance.
ctx_tuples: Command context tuples for this provider type.
tracker: The command tracker being dispatched.
config: The command config for this tracker.
Returns:
Text response, media list, or None if unhandled.
A CommandResponse, or None if unhandled.
"""
def get_rate_categories(self) -> dict[str, str]:
@@ -0,0 +1,67 @@
"""Shared command handler utilities to reduce boilerplate across providers."""
from __future__ import annotations
import logging
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import EventLog, NotificationTracker, ServiceProvider
_LOGGER = logging.getLogger(__name__)
async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]:
"""Get notification trackers for a single provider."""
from .handler import _get_notification_trackers_for_providers
return await _get_notification_trackers_for_providers({provider_id})
async def get_last_event_str(tracker_ids: list[int]) -> str:
"""Get formatted timestamp of most recent event for given trackers.
Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist.
"""
if not tracker_ids:
return "-"
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc())
.limit(1)
)
last_event = result.first()
return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
def get_tracked_collection_ids(
provider: ServiceProvider,
trackers: list[NotificationTracker],
*,
max_items: int = 20,
) -> list[str]:
"""Get deduplicated collection IDs from trackers for a provider.
Iterates all trackers belonging to *provider*, collects IDs from both
``collection_ids`` and ``filters.collections``, deduplicates while
preserving order, and caps at *max_items*.
"""
seen: set[str] = set()
result: list[str] = []
for tracker in trackers:
if tracker.provider_id != provider.id:
continue
for cid in tracker.collection_ids or []:
if cid not in seen:
seen.add(cid)
result.append(cid)
for cid in (tracker.filters or {}).get("collections", []):
if cid not in seen:
seen.add(cid)
result.append(cid)
return result[:max_items]
@@ -2,27 +2,55 @@
from __future__ import annotations
import asyncio
import logging
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import (
CommandConfig, CommandTracker, EventLog,
NotificationTracker, ServiceProvider, TelegramBot,
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
)
from ..services import make_gitea_provider
from .base import ProviderCommandHandler
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"}
def _get_tracked_repos(
provider: ServiceProvider,
trackers: list,
) -> list[tuple[ServiceProvider, str, str]]:
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
if not provider.config.get("api_token"):
return []
collection_ids = get_tracked_collection_ids(provider, trackers)
repos: list[tuple[ServiceProvider, str, str]] = []
for full_name in collection_ids:
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
return repos
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class GiteaCommandHandler(ProviderCommandHandler):
"""Handles Gitea-specific bot commands."""
@@ -44,91 +72,35 @@ class GiteaCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "repos":
ctx = await _cmd_repos(providers_map)
return _render_cmd_template(cmd_templates, "repos", locale, ctx)
if cmd == "issues":
ctx = await _cmd_issues(providers_map, count)
return _render_cmd_template(cmd_templates, "issues", locale, ctx)
if cmd == "prs":
ctx = await _cmd_prs(providers_map, count)
return _render_cmd_template(cmd_templates, "prs", locale, ctx)
if cmd == "commits":
ctx = await _cmd_commits(providers_map, count)
return _render_cmd_template(cmd_templates, "commits", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
def _get_tracked_repos(
providers_map: dict[int, ServiceProvider],
trackers: list[NotificationTracker],
) -> list[tuple[ServiceProvider, str, str]]:
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
repos: list[tuple[ServiceProvider, str, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.collection_ids or []):
parts = full_name.split("/", 1)
if len(parts) == 2:
repos.append((provider, parts[0], parts[1]))
# Also check filters.collections
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "gitea":
continue
if not provider.config.get("api_token"):
continue
for full_name in (tracker.filters or {}).get("collections", []):
parts = full_name.split("/", 1)
if len(parts) == 2:
entry = (provider, parts[0], parts[1])
if entry not in repos:
repos.append(entry)
return repos[:20] # Cap to prevent API hammering
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
# Get server version from first Gitea provider with token
# Get server version
server_version = "unknown"
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type == "gitea" and provider.config.get("api_token"):
gitea = make_gitea_provider(http, provider)
version = await gitea.client.get_server_version()
if version:
server_version = version
break
if provider.config.get("api_token"):
http = await get_http_session()
gitea = make_gitea_provider(http, provider)
version = await gitea.client.get_server_version()
if version:
server_version = version
# Last event
engine = get_engine()
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"repos_count": len(tracked_repos),
@@ -137,116 +109,139 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
}
async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_repos(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
repos_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
try:
all_repos = await gitea.client.get_repos(limit=50)
for r in all_repos:
if r.get("full_name") == f"{owner}/{repo}":
repos_data.append({
"full_name": r.get("full_name", ""),
"description": r.get("description", ""),
"stars": r.get("stars_count", 0),
"url": r.get("html_url", ""),
})
break
else:
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "",
"stars": 0,
"url": "",
})
except Exception:
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "?",
"stars": 0,
"url": "",
})
http = await get_http_session()
async def _fetch_repo(prov: ServiceProvider, owner: str, repo: str) -> dict[str, Any]:
gitea = make_gitea_provider(http, prov)
# Use direct get_repo endpoint instead of listing all repos
r = await gitea.client.get_repo(owner, repo)
if r:
return {
"full_name": r.get("full_name", ""),
"description": r.get("description", ""),
"stars": r.get("stars_count", 0),
"url": r.get("html_url", ""),
}
return {
"full_name": f"{owner}/{repo}",
"description": "",
"stars": 0,
"url": "",
}
tasks = [_fetch_repo(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, result)
repos_data.append({
"full_name": f"{owner}/{repo}",
"description": "?",
"stars": 0,
"url": "",
})
else:
repos_data.append(result)
return {"repos": repos_data}
async def _cmd_issues(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_issues(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_issues: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
issues = await gitea.client.get_repo_issues(owner, repo, limit=count)
for issue in issues:
all_issues.append({
"repo": f"{owner}/{repo}",
"number": issue.get("number", 0),
"title": issue.get("title", ""),
"url": issue.get("html_url", ""),
"user": issue.get("user", {}).get("login", ""),
"state": issue.get("state", ""),
})
http = await get_http_session()
async def _fetch_issues(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_issues(owner, repo, limit=count)
tasks = [_fetch_issues(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, result)
continue
for issue in result:
all_issues.append({
"repo": f"{owner}/{repo}",
"number": issue.get("number", 0),
"title": issue.get("title", ""),
"url": issue.get("html_url", ""),
"user": issue.get("user", {}).get("login", ""),
"state": issue.get("state", ""),
})
all_issues.sort(key=lambda i: i.get("number", 0), reverse=True)
return {"issues": all_issues[:count]}
async def _cmd_prs(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_prs(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_prs: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
prs = await gitea.client.get_repo_pulls(owner, repo, limit=count)
for pr in prs:
all_prs.append({
"repo": f"{owner}/{repo}",
"number": pr.get("number", 0),
"title": pr.get("title", ""),
"url": pr.get("html_url", ""),
"user": pr.get("user", {}).get("login", ""),
"state": pr.get("state", ""),
})
http = await get_http_session()
async def _fetch_prs(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_pulls(owner, repo, limit=count)
tasks = [_fetch_prs(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, result)
continue
for pr in result:
all_prs.append({
"repo": f"{owner}/{repo}",
"number": pr.get("number", 0),
"title": pr.get("title", ""),
"url": pr.get("html_url", ""),
"user": pr.get("user", {}).get("login", ""),
"state": pr.get("state", ""),
})
all_prs.sort(key=lambda p: p.get("number", 0), reverse=True)
return {"prs": all_prs[:count]}
async def _cmd_commits(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_repos = _get_tracked_repos(providers_map, trackers)
@_text_cmd
async def _cmd_commits(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_repos = _get_tracked_repos(provider, trackers)
all_commits: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, owner, repo in tracked_repos:
gitea = make_gitea_provider(http, provider)
commits = await gitea.client.get_repo_commits(owner, repo, limit=count)
for c in commits:
commit_data = c.get("commit", {})
all_commits.append({
"repo": f"{owner}/{repo}",
"short_id": c.get("sha", "")[:7],
"message": commit_data.get("message", "").split("\n")[0][:80],
"author": commit_data.get("author", {}).get("name", ""),
"url": c.get("html_url", ""),
})
http = await get_http_session()
async def _fetch_commits(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
gitea = make_gitea_provider(http, prov)
return await gitea.client.get_repo_commits(owner, repo, limit=count)
tasks = [_fetch_commits(prov, owner, repo) for prov, owner, repo in tracked_repos]
results = await asyncio.gather(*tasks, return_exceptions=True)
for (prov, owner, repo), result in zip(tracked_repos, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, result)
continue
for c in result:
commit_data = c.get("commit", {})
all_commits.append({
"repo": f"{owner}/{repo}",
"short_id": c.get("sha", "")[:7],
"message": commit_data.get("message", "").split("\n")[0][:80],
"author": commit_data.get("author", {}).get("name", ""),
"url": c.get("html_url", ""),
})
return {"commits": all_commits[:count]}
@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import time
from functools import lru_cache
from typing import Any
import aiohttp
@@ -25,17 +26,21 @@ from ..database.models import (
ServiceProvider,
TelegramBot,
)
from .base import CommandResponse
from .parser import parse_command
from .registry import get_rate_category
_LOGGER = logging.getLogger(__name__)
# Singleton Jinja2 environment for template rendering (Phase 4d)
_JINJA_ENV = SandboxedEnvironment(autoescape=False)
_JINJA_ENV = SandboxedEnvironment(autoescape=True)
# Rate limit state with automatic TTL expiry (Phase 4e)
_rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
# Maximum responses per command to avoid Telegram rate limits
_MAX_RESPONSES_PER_COMMAND = 5
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
"""Check rate limit. Returns seconds to wait, or None if OK."""
@@ -60,6 +65,12 @@ def _resolve_template(
return locale_map.get(locale) or locale_map.get("en")
@lru_cache(maxsize=256)
def _compile_template(template_str: str):
"""Cache compiled Jinja2 templates to avoid re-parsing identical strings."""
return _JINJA_ENV.from_string(template_str)
def _render_cmd_template(
templates: dict[str, dict[str, str]], slot_name: str, locale: str,
context: dict[str, Any],
@@ -70,20 +81,28 @@ def _render_cmd_template(
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
return f"[No template: {slot_name}]"
try:
tmpl = _JINJA_ENV.from_string(template_str)
tmpl = _compile_template(template_str)
return tmpl.render(**context)
except Exception as e:
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
return f"[Template error: {slot_name}]"
# ---------------------------------------------------------------------------
# Context resolution
# ---------------------------------------------------------------------------
async def _resolve_command_context(
bot: TelegramBot,
) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, dict[str, str]]]:
) -> tuple[
list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
dict[int, dict[str, dict[str, str]]],
]:
"""Resolve all enabled command trackers, configs, and providers for a bot.
Returns (context_tuples, cmd_template_slots).
cmd_template_slots is {slot_name: {locale: template}}.
Returns:
(context_tuples, templates_by_config_id)
templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}.
"""
engine = get_engine()
async with AsyncSession(engine) as session:
@@ -142,8 +161,8 @@ async def _resolve_command_context(
continue
tuples.append((tracker, config, provider))
# Load command template slots — merge from all configs
cmd_template_slots: dict[str, dict[str, str]] = {}
# Load command template slots per config (not merged)
templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {}
seen_config_ids: set[int] = set()
for _, config, _ in tuples:
cfg_id = config.command_template_config_id
@@ -154,98 +173,136 @@ async def _resolve_command_context(
CommandTemplateSlot.config_id == cfg_id
)
)
slots: dict[str, dict[str, str]] = {}
for s in slot_result.all():
cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template
slots.setdefault(s.slot_name, {})[s.locale] = s.template
templates_by_config_id[cfg_id] = slots
return tuples, cmd_template_slots
return tuples, templates_by_config_id
def _merge_command_context(
def _templates_for_config(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
config: CommandConfig,
) -> dict[str, dict[str, str]]:
"""Get template slots for a specific command config."""
cfg_id = config.command_template_config_id
if cfg_id and cfg_id in templates_by_config_id:
return templates_by_config_id[cfg_id]
return {}
def _merge_all_templates(
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
) -> dict[str, dict[str, str]]:
"""Merge all template config slots into one dict (for universal commands)."""
merged: dict[str, dict[str, str]] = {}
for slots in templates_by_config_id.values():
for slot_name, locale_map in slots.items():
merged.setdefault(slot_name, {}).update(locale_map)
return merged
def _merge_enabled_commands(
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> tuple[list[str], str, int, dict[str, Any]]:
"""Merge enabled_commands from all configs and pick defaults from first config."""
) -> tuple[list[str], dict[str, Any]]:
"""Merge enabled_commands (union) and rate_limits from all configs.
Rate limits use the most restrictive (minimum) cooldown per category.
"""
if not ctx:
return [], "media", 5, {}
return [], {}
enabled: set[str] = set()
merged_limits: dict[str, int] = {}
for _, config, _ in ctx:
enabled.update(config.enabled_commands or [])
for category, cooldown in (config.rate_limits or {}).items():
if category not in merged_limits:
merged_limits[category] = cooldown
else:
merged_limits[category] = min(merged_limits[category], cooldown)
first_config = ctx[0][1]
response_mode = first_config.response_mode or "media"
default_count = first_config.default_count or 5
rate_limits = first_config.rate_limits or {}
return sorted(enabled), merged_limits
return sorted(enabled), response_mode, default_count, rate_limits
# ---------------------------------------------------------------------------
# Main dispatcher
# ---------------------------------------------------------------------------
async def handle_command(
bot: TelegramBot,
chat_id: str,
text: str,
language_code: str = "",
) -> str | list[dict[str, Any]] | None:
) -> list[CommandResponse] | None:
"""Handle a bot command. Routes to provider-specific handlers.
Returns text response, media list, or None.
Returns a list of CommandResponse objects (one per tracker), or None.
Universal commands (/start, /help) return a single-element list.
Provider-specific commands dispatch per-tracker with per-tracker config.
"""
cmd, args, count_override = parse_command(text)
if not cmd:
return None
ctx_tuples, cmd_templates = await _resolve_command_context(bot)
enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples)
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, rate_limits = _merge_enabled_commands(ctx_tuples)
locale = language_code[:2].lower() if language_code else "en"
if locale not in ("en", "ru"):
locale = "en"
# Merged templates for universal commands
merged_templates = _merge_all_templates(templates_by_config_id)
if cmd == "start":
return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name})
text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name})
return [CommandResponse(text=text_resp)]
if cmd not in enabled and cmd != "start":
return None
# Rate limit check
# Rate limit check (once per command, shared across all trackers)
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
if wait is not None:
return _render_cmd_template(cmd_templates, "rate_limited", locale, {"wait": wait})
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
return [CommandResponse(text=text_resp)]
count = min(count_override or default_count, 20)
# Build providers map from command context
providers_map: dict[int, ServiceProvider] = {}
for _, _, provider in ctx_tuples:
providers_map[provider.id] = provider
# Universal commands
# Universal commands — single merged response
if cmd == "help":
ctx = _cmd_help(enabled, locale, cmd_templates)
return _render_cmd_template(cmd_templates, "help", locale, ctx)
ctx = _cmd_help(enabled, locale, merged_templates)
text_resp = _render_cmd_template(merged_templates, "help", locale, ctx)
return [CommandResponse(text=text_resp)]
# Provider-specific dispatch
# Provider-specific dispatch — per-tracker
from .dispatch import get_handler
# Group ctx_tuples by provider type
by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {}
for t in ctx_tuples:
ptype = t[2].type
by_type.setdefault(ptype, []).append(t)
# Find which handler claims this command
for ptype, ptuples in by_type.items():
handler = get_handler(ptype)
if handler and cmd in handler.get_provider_commands():
# Build provider map filtered to this provider type
pmap = {p.id: p for _, _, p in ptuples}
result = await handler.handle(
cmd, args, count, locale, response_mode,
pmap, cmd_templates, bot, ptuples,
responses: list[CommandResponse] = []
for tracker, config, provider in ctx_tuples:
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
_LOGGER.warning(
"Truncated command responses at %d for bot %d cmd /%s",
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
)
if result is not None:
return result
break
return None
handler = get_handler(provider.type)
if not handler or cmd not in handler.get_provider_commands():
continue
tracker_templates = _templates_for_config(templates_by_config_id, config)
count = min(count_override or config.default_count or 5, 20)
response_mode = config.response_mode or "media"
result = await handler.handle(
cmd, args, count, locale, response_mode,
provider, tracker_templates, bot, tracker, config,
)
if result is not None:
responses.append(result)
return responses if responses else None
def _cmd_help(
@@ -283,17 +340,13 @@ async def send_reply(
session: aiohttp.ClientSession | None = None,
) -> None:
"""Send a text reply via TelegramClient."""
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
if session is None:
from ..services.http_session import get_http_session
session = await get_http_session()
client = TelegramClient(session, bot_token)
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
async def send_media_group(
@@ -319,52 +372,50 @@ async def send_media_group(
captions = [item.get("caption", "") for item in media_items if item.get("caption")]
caption = "\n".join(captions) if captions else None
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_notification(
chat_id, assets=assets, caption=caption,
reply_to_message_id=reply_to_message_id,
chat_action=None,
)
if not result.get("success"):
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
if session is None:
from ..services.http_session import get_http_session
session = await get_http_session()
client = TelegramClient(session, bot_token)
result = await client.send_notification(
chat_id, assets=assets, caption=caption,
reply_to_message_id=reply_to_message_id,
chat_action=None,
)
if not result.get("success"):
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
"""Register enabled commands with Telegram BotFather API via TelegramClient."""
ctx_tuples, templates = await _resolve_command_context(bot)
enabled, _, _, _ = _merge_command_context(ctx_tuples)
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
enabled, _ = _merge_enabled_commands(ctx_tuples)
templates = _merge_all_templates(templates_by_config_id)
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot.token)
success = False
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot.token)
success = False
# Register per-locale commands
for locale in ("en", "ru"):
commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
# Register default (no language_code) with EN descriptions
en_commands = []
# Register per-locale commands
for locale in ("en", "ru"):
commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands)
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
return success
# Register default (no language_code) with EN descriptions
en_commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(en_commands)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
return success
@@ -6,70 +6,48 @@ import asyncio
import logging
from typing import Any
import aiohttp
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ...database.models import ServiceProvider, TelegramBot
from ...database.models import ServiceProvider
from ...services import make_immich_provider
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets, build_asset_dict
from ...services.http_session import get_http_session
from ..command_utils import get_trackers_for_provider
from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__)
async def _cmd_albums(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
if not trackers:
return {"albums": []}
albums_data: list[dict] = []
async with aiohttp.ClientSession() as http:
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
album_ids = tracker.collection_ids or []
if not album_ids:
continue
# Deduplicate album IDs while preserving order
seen: set[str] = set()
album_ids: list[str] = []
for tracker in trackers:
for aid in tracker.collection_ids or []:
if aid not in seen:
seen.add(aid)
album_ids.append(aid)
if not album_ids:
return {"albums": []}
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
album_results = await asyncio.gather(
*[immich.client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[immich.client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
})
elif result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
http = await get_http_session()
immich = make_immich_provider(http, provider)
albums_data = await fetch_albums_with_links(immich.client, album_ids, ext_domain)
return {"albums": albums_data}
async def cmd_favorites(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
providers_map: dict[int, ServiceProvider],
all_album_ids: list[str], count: int, locale: str,
response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /favorites command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
@@ -104,28 +82,6 @@ async def cmd_summary(
if not all_album_ids:
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []})
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in all_album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
albums_data: list[dict] = []
for album_id, result, links in zip(all_album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url,
})
albums_data = await fetch_albums_with_links(client, all_album_ids, ext, include_failed=False)
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
@@ -2,10 +2,12 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
from ...services import make_immich_provider
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
@@ -17,6 +19,53 @@ _IMMICH_COMMANDS = {
}
async def fetch_albums_with_links(
client: Any,
album_ids: list[str],
ext_domain: str,
*,
include_failed: bool = True,
) -> list[dict[str, Any]]:
"""Fetch albums and their shared links concurrently.
Returns a list of album data dicts with keys: name, asset_count, id,
public_url, and ``_album`` (the raw album object for callers that need
asset-level access).
When *include_failed* is True, albums that fail to fetch are included
with placeholder data (``"?"`` for counts). When False, they are
silently skipped.
"""
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
albums_data: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
if include_failed:
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?",
"id": album_id, "public_url": "", "_album": None,
})
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext_domain:
pub_url = get_public_url(ext_domain, links) or ""
albums_data.append({
"name": result.name, "asset_count": result.asset_count,
"id": album_id, "public_url": pub_url, "_album": result,
})
return albums_data
def build_asset_dict(
asset: Any,
*,
@@ -56,8 +105,14 @@ def _format_assets(
assets: list[dict[str, Any]], cmd: str, query: str,
locale: str, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Format asset results as text or media payload."""
) -> str | dict[str, Any]:
"""Format asset results as text or a text-plus-media payload.
Returns:
str: rendered text when *response_mode* is ``"text"`` (or no assets).
dict: ``{"text": ..., "media": [...]}`` when *response_mode* is
``"media"`` and assets are present.
"""
if not assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
@@ -68,7 +123,7 @@ def _format_assets(
})
if response_mode == "media":
media_items = []
media_items: list[dict[str, Any]] = []
for asset in assets:
asset_id = asset.get("id", "")
media_items.append({
@@ -13,23 +13,22 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
EventLog, NotificationTarget, NotificationTrackerTarget,
ServiceProvider, TelegramBot, TrackingConfig,
EventLog, NotificationTracker, NotificationTrackerTarget,
ServiceProvider, TrackingConfig,
)
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets, build_asset_dict
from ..command_utils import get_trackers_for_provider
from ..handler import _render_cmd_template
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
_LOGGER = logging.getLogger(__name__)
async def _cmd_events(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
count: int, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return {"events": []}
@@ -57,32 +56,21 @@ async def cmd_latest(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
external_domain: str = "",
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /latest command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
latest_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
for aid, asset in list(result.assets.items())[:count]:
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if album_obj:
for aid, asset in list(album_obj.assets.items())[:count]:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
latest_assets.append(build_asset_dict(asset, public_url=asset_pub))
@@ -95,32 +83,21 @@ async def cmd_random(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
external_domain: str = "",
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /random command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "random", "", locale, response_mode, client, cmd_templates)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in album_ids],
return_exceptions=True,
)
ext = external_domain.rstrip("/")
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
random_assets: list[dict[str, Any]] = []
for album_id, result, links in zip(album_ids, album_results, link_results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
pub_url = ""
if not isinstance(links, Exception) and ext:
pub_url = get_public_url(ext, links) or ""
asset_list = list(result.assets.values())
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if album_obj:
asset_list = list(album_obj.assets.values())
sampled = rng.sample(asset_list, min(count, len(asset_list)))
for asset in sampled:
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
@@ -130,40 +107,40 @@ async def cmd_random(
return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates)
async def _check_native_memory(bot: TelegramBot) -> bool:
"""Check if any tracker-target linked to this bot uses native memory source."""
async def _check_native_memory(provider_id: int) -> bool:
"""Check if any notification tracker for this provider uses native memory source."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(NotificationTarget).where(
NotificationTarget.type == "telegram",
NotificationTarget.user_id == bot.user_id,
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
)
)
targets = result.all()
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
if not bot_target_ids:
trackers = tracker_result.all()
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return False
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.target_id.in_(bot_target_ids)
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
)
)
for tt in tt_result.all():
if tt.tracking_config_id:
tc = await session.get(TrackingConfig, tt.tracking_config_id)
if tc and tc.memory_source == "native":
return True
return False
tc_ids = list({tt.tracking_config_id for tt in tt_result.all() if tt.tracking_config_id})
if not tc_ids:
return False
tc_result = await session.exec(
select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
)
return any(tc.memory_source == "native" for tc in tc_result.all())
async def cmd_memory(
bot: TelegramBot, client: Any, all_album_ids: list[str], count: int,
provider_id: int, client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /memory command with concurrent album fetching."""
use_native = await _check_native_memory(bot)
use_native = await _check_native_memory(provider_id)
today = datetime.now(timezone.utc)
memory_assets: list[dict[str, Any]] = []
@@ -2,26 +2,21 @@
from __future__ import annotations
import asyncio
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
CommandConfig, CommandTracker, EventLog,
CommandConfig, CommandTracker,
ServiceProvider, TelegramBot,
)
from ...services import make_immich_provider
from ..base import ProviderCommandHandler
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from notify_bridge_core.providers.immich.asset_utils import get_public_url
from ...services.http_session import get_http_session
from ..base import CommandResponse, ProviderCommandHandler
from ..command_utils import get_last_event_str, get_trackers_for_provider
from ..handler import _render_cmd_template
from .albums import _cmd_albums, cmd_favorites, cmd_summary
from .common import _IMMICH_COMMANDS
from .common import _IMMICH_COMMANDS, fetch_albums_with_links
from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random
from .search import cmd_find, cmd_person, cmd_place, cmd_search
@@ -29,21 +24,15 @@ _LOGGER = logging.getLogger(__name__)
async def _cmd_status(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
trackers = await get_trackers_for_provider(provider.id)
active = sum(1 for t in trackers if t.enabled)
total = len(trackers)
total_albums = sum(len(t.collection_ids or []) for t in trackers)
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"trackers_active": active, "trackers_total": total,
@@ -52,16 +41,13 @@ async def _cmd_status(
async def _cmd_people(
providers_map: dict[int, ServiceProvider], locale: str,
provider: ServiceProvider, locale: str,
) -> dict[str, Any]:
all_people: dict[str, str] = {}
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
http = await get_http_session()
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
names = sorted(all_people.values())
return {"people": names}
@@ -87,106 +73,92 @@ class ImmichCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
if cmd == "status":
ctx = await _cmd_status(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
ctx = await _cmd_status(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "status", locale, ctx))
if cmd == "albums":
ctx = await _cmd_albums(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
ctx = await _cmd_albums(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "albums", locale, ctx))
if cmd == "events":
ctx = await _cmd_events(bot, providers_map, count, locale)
return _render_cmd_template(cmd_templates, "events", locale, ctx)
ctx = await _cmd_events(provider, count, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "events", locale, ctx))
if cmd == "people":
ctx = await _cmd_people(providers_map, locale)
return _render_cmd_template(cmd_templates, "people", locale, ctx)
ctx = await _cmd_people(provider, locale)
return CommandResponse(text=_render_cmd_template(cmd_templates, "people", locale, ctx))
if cmd in ("search", "find", "person", "place", "latest",
"random", "favorites", "summary", "memory"):
return await _cmd_immich(
bot, cmd, args, count, locale, response_mode,
providers_map, cmd_templates,
cmd, args, count, locale, response_mode,
provider, cmd_templates,
)
return None
async def _cmd_immich(
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
response_mode: str, providers_map: dict[int, ServiceProvider],
cmd: str, args: str, count: int, locale: str,
response_mode: str, provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
) -> CommandResponse | None:
"""Handle commands that need Immich API access and may return media."""
if not providers_map:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
provider_ids = set(providers_map.keys())
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
notification_trackers = await get_trackers_for_provider(provider.id)
all_album_ids: list[str] = []
for t in notification_trackers:
all_album_ids.extend(t.collection_ids or [])
provider: ServiceProvider | None = None
for p in providers_map.values():
if p.type == "immich":
provider = p
break
if not provider:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
async with aiohttp.ClientSession() as http:
immich = make_immich_provider(http, provider)
client = immich.client
http = await get_http_session()
immich = make_immich_provider(http, provider)
client = immich.client
# Build asset_id → public_url map from tracked albums' shared links
asset_public_urls: dict[str, str] = {}
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
link_results = await asyncio.gather(
*[client.get_shared_links(aid) for aid in all_album_ids],
return_exceptions=True,
)
album_results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
for album_id, links, album in zip(all_album_ids, link_results, album_results):
if isinstance(links, Exception) or isinstance(album, Exception):
continue
pub_url = get_public_url(ext_domain, links)
if pub_url and album:
for asset_id in album.assets:
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
# Build asset_id → public_url map from tracked albums' shared links
asset_public_urls: dict[str, str] = {}
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
fetched = await fetch_albums_with_links(client, all_album_ids, ext_domain, include_failed=False)
for album_data in fetched:
pub_url = album_data.get("public_url", "")
album_obj = album_data.get("_album")
if pub_url and album_obj:
for asset_id in album_obj.assets:
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
if cmd == "search":
return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
# Wrap single-provider in a map for functions that still expect it
providers_map = {provider.id: provider}
if cmd == "find":
return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
result: str | dict[str, Any] | None = None
if cmd == "person":
return await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
if cmd == "search":
result = await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "find":
result = await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "person":
result = await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "place":
result = await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
elif cmd == "favorites":
result = await cmd_favorites(providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
elif cmd == "latest":
result = await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "random":
result = await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
elif cmd == "summary":
result = await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
elif cmd == "memory":
result = await cmd_memory(provider.id, client, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "place":
return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
if cmd == "favorites":
return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
if cmd == "latest":
return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
if cmd == "random":
return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
if cmd == "summary":
return await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
if cmd == "memory":
return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates)
return None
if result is None:
return None
# _format_assets returns {"text": ..., "media": [...]} for media mode
if isinstance(result, dict):
return CommandResponse(
text=result.get("text"),
media=result.get("media", []),
)
return CommandResponse(text=result)
@@ -9,14 +9,15 @@ from .common import _format_assets
def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]:
"""Add public_url to assets from the pre-built map."""
"""Add public_url to assets from the pre-built map. Returns new list without mutating inputs."""
if not asset_public_urls:
return assets
for asset in assets:
aid = asset.get("id", "")
if aid and aid in asset_public_urls and not asset.get("public_url"):
asset["public_url"] = asset_public_urls[aid]
return assets
return [
{**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")}
if asset.get("id", "") in asset_public_urls and not asset.get("public_url")
else asset
for asset in assets
]
async def cmd_search(
@@ -24,7 +25,7 @@ async def cmd_search(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /search command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""})
@@ -38,7 +39,7 @@ async def cmd_find(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /find command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""})
@@ -52,7 +53,7 @@ async def cmd_person(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /person command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
@@ -74,7 +75,7 @@ async def cmd_place(
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
asset_public_urls: dict[str, str] | None = None,
) -> str | list[dict[str, Any]]:
) -> str | dict[str, Any]:
"""Handle /place command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
@@ -3,17 +3,31 @@
from __future__ import annotations
import logging
from collections.abc import Callable, Coroutine
from typing import Any
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
from ..services import make_nut_provider
from .base import ProviderCommandHandler
from .base import CommandResponse, ProviderCommandHandler
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_NUT_COMMANDS = {"status", "devices", "battery"}
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class NutCommandHandler(ProviderCommandHandler):
"""Handles NUT-specific bot commands."""
@@ -33,80 +47,73 @@ class NutCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "devices":
ctx = await _cmd_devices(providers_map)
return _render_cmd_template(cmd_templates, "devices", locale, ctx)
if cmd == "battery":
ctx = await _cmd_battery(providers_map)
return _render_cmd_template(cmd_templates, "battery", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
async def _query_all_ups(
providers_map: dict[int, ServiceProvider],
async def _query_ups(
provider: ServiceProvider,
) -> list[dict[str, Any]]:
"""Connect to all NUT providers and query UPS data."""
"""Connect to a NUT provider and query UPS data."""
from notify_bridge_core.providers.nut.models import NutUpsData
results: list[dict[str, Any]] = []
for provider in providers_map.values():
if provider.type != "nut":
continue
nut = make_nut_provider(provider)
nut = make_nut_provider(provider)
try:
client = nut._make_client()
await client.connect()
try:
client = nut._make_client()
await client.connect()
try:
devices = await client.list_ups()
for dev in devices:
variables = await client.list_var(dev.name)
data = NutUpsData.from_variables(dev.name, variables)
results.append({
"name": data.name,
"description": data.description,
"model": data.model,
"manufacturer": data.manufacturer,
"status": data.status,
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
"battery_runtime": data.battery_runtime_formatted,
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
})
finally:
await client.disconnect()
except Exception as exc:
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
devices = await client.list_ups()
for dev in devices:
variables = await client.list_var(dev.name)
data = NutUpsData.from_variables(dev.name, variables)
results.append({
"name": data.name,
"description": data.description,
"model": data.model,
"manufacturer": data.manufacturer,
"status": data.status,
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
"battery_runtime": data.battery_runtime_formatted,
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
})
finally:
await client.disconnect()
except Exception as exc:
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
return results
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
devices = await _query_all_ups(providers_map)
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices}
async def _cmd_devices(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
@_text_cmd
async def _cmd_devices(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices: list[dict[str, Any]] = []
for provider in providers_map.values():
if provider.type != "nut":
continue
nut = make_nut_provider(provider)
try:
device_list = await nut.list_collections()
devices.extend(device_list)
except Exception as exc:
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
nut = make_nut_provider(provider)
try:
device_list = await nut.list_collections()
devices.extend(device_list)
except Exception as exc:
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
return {"devices": devices}
async def _cmd_battery(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
devices = await _query_all_ups(providers_map)
@_text_cmd
async def _cmd_battery(provider: ServiceProvider, count: int) -> dict[str, Any]:
devices = await _query_ups(provider)
return {"devices": devices}
@@ -3,26 +3,47 @@
from __future__ import annotations
import logging
from collections.abc import Callable, Coroutine
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import (
CommandConfig, CommandTracker, EventLog,
NotificationTracker, ServiceProvider, TelegramBot,
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
)
from ..services import make_planka_provider
from .base import ProviderCommandHandler
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_PLANKA_COMMANDS = {"status", "boards", "cards", "lists"}
def _get_tracked_board_ids(
provider: ServiceProvider,
trackers: list,
) -> list[str]:
"""Get board IDs from tracked collection_ids for this provider."""
if not provider.config.get("api_key"):
return []
return get_tracked_collection_ids(provider, trackers)
# ---------------------------------------------------------------------------
# Command dispatch table
# ---------------------------------------------------------------------------
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
"""Register a function in the text command dispatch table."""
name = fn.__name__.removeprefix("_cmd_")
_TEXT_COMMANDS[name] = fn
return fn
class PlankaCommandHandler(ProviderCommandHandler):
"""Handles Planka-specific bot commands."""
@@ -43,69 +64,26 @@ class PlankaCommandHandler(ProviderCommandHandler):
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(providers_map)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "boards":
ctx = await _cmd_boards(providers_map)
return _render_cmd_template(cmd_templates, "boards", locale, ctx)
if cmd == "cards":
ctx = await _cmd_cards(providers_map, count)
return _render_cmd_template(cmd_templates, "cards", locale, ctx)
if cmd == "lists":
ctx = await _cmd_lists(providers_map)
return _render_cmd_template(cmd_templates, "lists", locale, ctx)
return None
tracker: CommandTracker,
config: CommandConfig,
) -> CommandResponse | None:
fn = _TEXT_COMMANDS.get(cmd)
if fn is None:
return None
ctx = await fn(provider, count)
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
def _get_tracked_board_ids(
providers_map: dict[int, ServiceProvider],
trackers: list[NotificationTracker],
) -> list[tuple[ServiceProvider, str]]:
"""Get (provider, board_id) tuples from tracked collection_ids."""
boards: list[tuple[ServiceProvider, str]] = []
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "planka":
continue
if not provider.config.get("api_key"):
continue
for board_id in (tracker.collection_ids or []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
# Also check filters.collections
for board_id in (tracker.filters or {}).get("collections", []):
entry = (provider, board_id)
if entry not in boards:
boards.append(entry)
return boards[:20]
@_text_cmd
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
# Last event
engine = get_engine()
async with AsyncSession(engine) as session:
tracker_ids = [t.id for t in trackers]
if tracker_ids:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
else:
last_event = None
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
tracker_ids = [t.id for t in trackers]
last_str = await get_last_event_str(tracker_ids)
return {
"boards_count": len(tracked_boards),
@@ -113,81 +91,69 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
}
async def _cmd_boards(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_boards(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
boards_data: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
all_boards = await planka.client.get_boards()
for b in all_boards:
if str(b.get("id", "")) == board_id:
boards_data.append({"name": b.get("name", board_id)})
break
else:
boards_data.append({"name": board_id})
http = await get_http_session()
planka = make_planka_provider(http, provider)
all_boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in all_boards}
for board_id in tracked_boards:
boards_data.append({"name": board_names.get(board_id, board_id)})
return {"boards": boards_data}
async def _cmd_cards(
providers_map: dict[int, ServiceProvider], count: int,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_cards(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
all_cards: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
cards = await planka.client.get_board_cards(board_id, limit=count)
lists = await planka.client.get_board_lists(board_id)
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
http = await get_http_session()
planka = make_planka_provider(http, provider)
boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
boards = await planka.client.get_boards()
board_name = board_id
for b in boards:
if str(b.get("id", "")) == board_id:
board_name = b.get("name", board_id)
break
for board_id in tracked_boards:
cards = await planka.client.get_board_cards(board_id, limit=count)
lists = await planka.client.get_board_lists(board_id)
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
board_name = board_names.get(board_id, board_id)
for card in cards:
list_id = str(card.get("listId", ""))
all_cards.append({
"name": card.get("name", ""),
"list_name": lists_by_id.get(list_id, ""),
"board_name": board_name,
})
for card in cards:
list_id = str(card.get("listId", ""))
all_cards.append({
"name": card.get("name", ""),
"list_name": lists_by_id.get(list_id, ""),
"board_name": board_name,
})
return {"cards": all_cards[:count]}
async def _cmd_lists(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
@_text_cmd
async def _cmd_lists(provider: ServiceProvider, count: int) -> dict[str, Any]:
trackers = await get_trackers_for_provider(provider.id)
tracked_boards = _get_tracked_board_ids(provider, trackers)
all_lists: list[dict[str, Any]] = []
async with aiohttp.ClientSession() as http:
for provider, board_id in tracked_boards:
planka = make_planka_provider(http, provider)
lists = await planka.client.get_board_lists(board_id)
http = await get_http_session()
planka = make_planka_provider(http, provider)
boards = await planka.client.get_boards()
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
boards = await planka.client.get_boards()
board_name = board_id
for b in boards:
if str(b.get("id", "")) == board_id:
board_name = b.get("name", board_id)
break
for board_id in tracked_boards:
lists = await planka.client.get_board_lists(board_id)
board_name = board_names.get(board_id, board_id)
for lst in lists:
all_lists.append({
"name": lst.get("name", ""),
"board_name": board_name,
})
for lst in lists:
all_lists.append({
"name": lst.get("name", ""),
"board_name": board_name,
})
return {"lists": all_lists}
@@ -6,7 +6,6 @@ import hmac
import logging
from typing import Any
import aiohttp
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -16,6 +15,7 @@ from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..database.engine import get_session
from ..database.models import TelegramBot, TelegramChat
from ..services.telegram import save_chat_from_webhook
from .base import CommandResponse
from .handler import handle_command, send_media_group, send_reply
_LOGGER = logging.getLogger(__name__)
@@ -89,15 +89,13 @@ async def telegram_webhook(
return {"ok": True, "skipped": "commands_disabled"}
effective_lang = chat_row.language_override or msg_language
message_id = message.get("message_id")
cmd_response = await handle_command(bot, chat_id, text, language_code=effective_lang)
if cmd_response is not None:
if isinstance(cmd_response, dict) and "media" in cmd_response:
await send_reply(bot.token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
await send_media_group(bot.token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
elif isinstance(cmd_response, list):
await send_media_group(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
else:
await send_reply(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
if responses:
for resp in responses:
if resp.text:
await send_reply(bot.token, chat_id, resp.text, reply_to_message_id=message_id)
if resp.media:
await send_media_group(bot.token, chat_id, resp.media, reply_to_message_id=message_id)
return {"ok": True}
return {"ok": True, "skipped": "not_a_command"}
@@ -105,13 +103,15 @@ async def telegram_webhook(
async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
"""Register webhook URL with Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot_token)
return await client.set_webhook(webhook_url, secret=secret)
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
return await client.set_webhook(webhook_url, secret=secret)
async def unregister_webhook(bot_token: str) -> dict:
"""Remove webhook from Telegram Bot API via TelegramClient."""
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot_token)
return await client.delete_webhook()
from ..services.http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
return await client.delete_webhook()
@@ -359,6 +359,7 @@ class NotificationTrackerState(SQLModel, table=True):
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
tracker_id: int = Field(
foreign_key="notification_tracker.id",
index=True,
sa_column_kwargs={"name": "notification_tracker_id"},
)
collection_id: str
@@ -458,7 +459,7 @@ class CommandTrackerListener(SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
command_tracker_id: int = Field(
foreign_key="command_tracker.id",
index=True,
)
listener_type: str # e.g. "telegram_bot"
@@ -73,6 +73,8 @@ async def lifespan(app: FastAPI):
await start_scheduler()
yield
# Graceful shutdown
from .services.http_session import close_http_session
await close_http_session()
scheduler = get_scheduler()
if scheduler.running:
scheduler.shutdown()
@@ -1,5 +1,11 @@
"""Shared service utilities."""
from __future__ import annotations
from typing import Any, Protocol
import aiohttp
from notify_bridge_core.providers.immich import ImmichServiceProvider
from notify_bridge_core.providers.gitea import GiteaServiceProvider
from notify_bridge_core.providers.planka import PlankaServiceProvider
@@ -8,8 +14,23 @@ from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvid
from ..database.models import ServiceProvider
# Default timeout for all outgoing HTTP requests to external services.
DEFAULT_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServiceProvider:
class CollectionProvider(Protocol):
"""Protocol for providers that can list collections."""
async def list_collections(self) -> list[dict[str, Any]]: ...
class TestableProvider(Protocol):
"""Protocol for providers that support connection testing."""
async def test_connection(self) -> dict[str, Any]: ...
def make_immich_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> ImmichServiceProvider:
"""Create an ImmichServiceProvider from a DB provider model."""
config = provider.config or {}
return ImmichServiceProvider(
@@ -21,7 +42,7 @@ def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServi
)
def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaServiceProvider:
def make_gitea_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GiteaServiceProvider:
"""Create a GiteaServiceProvider from a DB provider model."""
config = provider.config or {}
return GiteaServiceProvider(
@@ -32,7 +53,7 @@ def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaService
)
def make_planka_provider(http_session, provider: ServiceProvider) -> PlankaServiceProvider:
def make_planka_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> PlankaServiceProvider:
"""Create a PlankaServiceProvider from a DB provider model."""
config = provider.config or {}
return PlankaServiceProvider(
@@ -55,7 +76,7 @@ def make_nut_provider(provider: ServiceProvider) -> NutServiceProvider:
)
def make_google_photos_provider(http_session, provider: ServiceProvider) -> GooglePhotosServiceProvider:
def make_google_photos_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GooglePhotosServiceProvider:
"""Create a GooglePhotosServiceProvider from a DB provider model."""
config = provider.config or {}
return GooglePhotosServiceProvider(
@@ -65,3 +86,61 @@ def make_google_photos_provider(http_session, provider: ServiceProvider) -> Goog
config.get("refresh_token", ""),
provider.name,
)
# ---------------------------------------------------------------------------
# Provider factory registry — maps provider type strings to factory callables
# that create a provider with a ``list_collections`` method. Providers that
# require an API credential skip creation when the credential is missing
# (the factory returns None in that case).
# ---------------------------------------------------------------------------
def _make_collection_provider(
http_session: aiohttp.ClientSession,
provider: ServiceProvider,
) -> CollectionProvider | None:
"""Create a CollectionProvider for the given DB provider, or None if unsupported."""
ptype = provider.type
config = provider.config or {}
if ptype == "immich":
return make_immich_provider(http_session, provider)
if ptype == "gitea":
if not config.get("api_token"):
return None
return make_gitea_provider(http_session, provider)
if ptype == "planka":
if not config.get("api_key"):
return None
return make_planka_provider(http_session, provider)
if ptype == "google_photos":
return make_google_photos_provider(http_session, provider)
# NUT provider needs no http_session
if ptype == "nut":
return make_nut_provider(provider) # type: ignore[return-value]
return None
# Set of provider types that need an aiohttp session for collection listing.
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
"""List collections for any supported provider type.
Returns an empty list for providers that don't support collections or
are missing required credentials.
"""
if provider.type in _HTTP_COLLECTION_PROVIDERS:
from .http_session import get_http_session
http_session = await get_http_session()
svc = _make_collection_provider(http_session, provider)
if svc is None:
return []
return await svc.list_collections()
# Non-HTTP providers (e.g. NUT)
svc = _make_collection_provider(None, provider) # type: ignore[arg-type]
if svc is None:
return []
return await svc.list_collections()
@@ -6,7 +6,6 @@ import logging
from datetime import datetime, timezone
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -159,27 +158,28 @@ async def _execute_with_provider(
)
from notify_bridge_core.providers.immich.client import ImmichClient
async with aiohttp.ClientSession() as http_session:
client = ImmichClient(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
from .http_session import get_http_session
http_session = await get_http_session()
client = ImmichClient(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
)
external_domain = provider_config.get("external_domain")
if external_domain:
client.external_domain = external_domain
# Verify connectivity
if not await client.ping():
return ActionResult(
success=False,
error=f"Cannot connect to Immich server ({provider_name})",
)
external_domain = provider_config.get("external_domain")
if external_domain:
client.external_domain = external_domain
# Verify connectivity
if not await client.ping():
return ActionResult(
success=False,
error=f"Cannot connect to Immich server ({provider_name})",
)
executor = ImmichActionExecutor(client)
if dry_run:
return await executor.dry_run(action_type, rule_configs, action_config)
return await executor.execute(action_type, rule_configs, action_config)
executor = ImmichActionExecutor(client)
if dry_run:
return await executor.dry_run(action_type, rule_configs, action_config)
return await executor.execute(action_type, rule_configs, action_config)
return ActionResult(
success=False,
@@ -0,0 +1,33 @@
"""Application-level shared aiohttp.ClientSession.
All outgoing HTTP requests in the server package should use the shared
session returned by ``get_http_session()`` instead of creating
per-request ``aiohttp.ClientSession`` instances. This keeps a single
TCP connection pool alive for the lifetime of the process, avoiding
the overhead of pool creation/teardown on every request.
Call ``close_http_session()`` once during application shutdown.
"""
from __future__ import annotations
import aiohttp
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
_session: aiohttp.ClientSession | None = None
async def get_http_session() -> aiohttp.ClientSession:
"""Get or create the shared HTTP session."""
global _session
if _session is None or _session.closed:
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
return _session
async def close_http_session() -> None:
"""Close the shared HTTP session (call on app shutdown)."""
global _session
if _session is not None and not _session.closed:
await _session.close()
_session = None
@@ -3,8 +3,6 @@
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -90,19 +88,21 @@ async def _send_telegram_broadcast(target: NotificationTarget, message: str, rec
if not receivers:
return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = []
async with aiohttp.ClientSession() as session:
client = TelegramClient(session, bot_token)
for recv in receivers:
chat_id = recv.get("chat_id")
if not chat_id:
continue
result = await client.send_message(
chat_id=str(chat_id),
text=message,
disable_web_page_preview=bool(disable_preview),
)
results.append(result)
client = TelegramClient(http, bot_token)
for recv in receivers:
chat_id = recv.get("chat_id")
if not chat_id:
continue
result = await client.send_message(
chat_id=str(chat_id),
text=message,
disable_web_page_preview=bool(disable_preview),
)
results.append(result)
return _aggregate(results)
@@ -113,15 +113,17 @@ async def _send_webhook_broadcast(target: NotificationTarget, message: str, rece
if not receivers:
return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = []
async with aiohttp.ClientSession() as session:
for recv in receivers:
url = recv.get("url")
headers = recv.get("headers", {})
if not url:
continue
client = WebhookClient(session, url, headers)
results.append(await client.send({"message": message, "event_type": "notification"}))
for recv in receivers:
url = recv.get("url")
headers = recv.get("headers", {})
if not url:
continue
client = WebhookClient(http, url, headers)
results.append(await client.send({"message": message, "event_type": "notification"}))
return _aggregate(results)
@@ -178,22 +180,24 @@ async def _send_webhook_like_broadcast(target: NotificationTarget, message: str,
if not receivers:
return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = []
async with aiohttp.ClientSession() as session:
if target.type == "discord":
from notify_bridge_core.notifications.discord.client import DiscordClient
client = DiscordClient(session)
for recv in receivers:
url = recv.get("webhook_url")
if url:
results.append(await client.send(url, message, username=target.config.get("username")))
elif target.type == "slack":
from notify_bridge_core.notifications.slack.client import SlackClient
client = SlackClient(session)
for recv in receivers:
url = recv.get("webhook_url")
if url:
results.append(await client.send(url, message, username=target.config.get("username")))
if target.type == "discord":
from notify_bridge_core.notifications.discord.client import DiscordClient
client = DiscordClient(http)
for recv in receivers:
url = recv.get("webhook_url")
if url:
results.append(await client.send(url, message, username=target.config.get("username")))
elif target.type == "slack":
from notify_bridge_core.notifications.slack.client import SlackClient
client = SlackClient(http)
for recv in receivers:
url = recv.get("webhook_url")
if url:
results.append(await client.send(url, message, username=target.config.get("username")))
return _aggregate(results)
@@ -207,18 +211,20 @@ async def _send_ntfy_broadcast(target: NotificationTarget, message: str, receive
return {"success": False, "error": "No receivers configured"}
from notify_bridge_core.notifications.ntfy.client import NtfyClient
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = []
async with aiohttp.ClientSession() as session:
client = NtfyClient(session)
for recv in receivers:
topic = recv.get("topic")
if topic:
results.append(await client.send(
server_url, topic, message,
title="Notify Bridge",
priority=recv.get("priority", 3),
auth_token=auth_token,
))
client = NtfyClient(http)
for recv in receivers:
topic = recv.get("topic")
if topic:
results.append(await client.send(
server_url, topic, message,
title="Notify Bridge",
priority=recv.get("priority", 3),
auth_token=auth_token,
))
return _aggregate(results)
@@ -243,13 +249,15 @@ async def _send_matrix_broadcast(target: NotificationTarget, message: str, recei
if not receivers:
return {"success": False, "error": "No receivers configured"}
from .http_session import get_http_session
http = await get_http_session()
results: list[dict] = []
async with aiohttp.ClientSession() as http:
client = MatrixClient(http, homeserver, access_token)
for recv in receivers:
room_id = recv.get("room_id")
if room_id:
results.append(await client.send_message(room_id, message, html_message=message))
client = MatrixClient(http, homeserver, access_token)
for recv in receivers:
room_id = recv.get("room_id")
if room_id:
results.append(await client.send_message(room_id, message, html_message=message))
return _aggregate(results)
@@ -31,11 +31,50 @@ async def start_scheduler() -> None:
from .telegram_poller import start_command_listener_polling
await start_command_listener_polling()
# Schedule daily cleanup of old event log entries
_schedule_event_cleanup()
# Start debounced command auto-sync scheduler
from .command_sync import start_sync_scheduler
start_sync_scheduler()
def _schedule_event_cleanup() -> None:
"""Schedule a daily job to delete EventLog entries older than 90 days."""
from apscheduler.triggers.cron import CronTrigger
scheduler = get_scheduler()
job_id = "cleanup_old_events"
if scheduler.get_job(job_id):
return
scheduler.add_job(
_cleanup_old_events,
CronTrigger(hour=3, minute=0),
id=job_id,
replace_existing=True,
max_instances=1,
)
_LOGGER.info("Scheduled daily event log cleanup at 03:00 UTC")
async def _cleanup_old_events() -> None:
"""Delete EventLog entries older than 90 days."""
from datetime import datetime, timedelta, timezone
from sqlmodel import delete
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import EventLog
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
engine = get_engine()
async with AsyncSession(engine) as session:
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
await session.commit()
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
async def _load_tracker_jobs() -> None:
"""Load enabled trackers and schedule polling jobs."""
from sqlmodel import select
@@ -50,13 +89,16 @@ async def _load_tracker_jobs() -> None:
result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True))
trackers = result.all()
# Pre-load provider types for scheduler detection
# Batch-load provider types for scheduler detection
unique_provider_ids = list({t.provider_id for t in trackers})
provider_types: dict[int, str] = {}
for tracker in trackers:
if tracker.provider_id not in provider_types:
provider = await session.get(ServiceProviderModel, tracker.provider_id)
if provider:
provider_types[tracker.provider_id] = provider.type
if unique_provider_ids:
provider_result = await session.exec(
select(ServiceProviderModel).where(
ServiceProviderModel.id.in_(unique_provider_ids)
)
)
provider_types = {p.id: p.type for p in provider_result.all()}
for tracker in trackers:
job_id = f"tracker_{tracker.id}"
@@ -86,6 +128,7 @@ async def _load_tracker_jobs() -> None:
id=job_id,
args=[tracker.id],
replace_existing=True,
max_instances=1,
)
_LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval)
@@ -106,6 +149,7 @@ def _add_cron_job(
id=job_id,
args=[tracker_id],
replace_existing=True,
max_instances=1,
)
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
@@ -13,7 +13,6 @@ from __future__ import annotations
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -47,10 +46,18 @@ async def _get_bot_ids_with_active_listeners() -> set[int]:
listeners = result.all()
active_bot_ids: set[int] = set()
for listener in listeners:
tracker = await session.get(CommandTracker, listener.command_tracker_id)
if tracker and tracker.enabled:
active_bot_ids.add(listener.listener_id)
tracker_ids = list({l.command_tracker_id for l in listeners})
if tracker_ids:
tracker_result = await session.exec(
select(CommandTracker).where(
CommandTracker.id.in_(tracker_ids),
CommandTracker.enabled == True, # noqa: E712
)
)
enabled_tracker_ids = {t.id for t in tracker_result.all()}
for listener in listeners:
if listener.command_tracker_id in enabled_tracker_ids:
active_bot_ids.add(listener.listener_id)
return active_bot_ids
@@ -145,21 +152,23 @@ async def _poll_bot(bot_id: int) -> None:
if not bot or bot.update_mode != "polling":
unschedule_bot_polling(bot_id)
return
# Extract what we need before closing session
# Copy attributes before session closes to avoid detached-instance errors
from types import SimpleNamespace
bot_token = bot.token
bot_obj = bot
bot_obj = SimpleNamespace(id=bot.id, name=bot.name, token=bot.token)
offset = _last_update_id.get(bot_id, 0)
try:
async with aiohttp.ClientSession() as http:
client = TelegramClient(http, bot_token)
result = await client.get_updates(
offset=offset + 1 if offset else None, limit=50,
)
if not result.get("success"):
return
updates = result.get("result", [])
from .http_session import get_http_session
http = await get_http_session()
client = TelegramClient(http, bot_token)
result = await client.get_updates(
offset=offset + 1 if offset else None, limit=50,
)
if not result.get("success"):
return
updates = result.get("result", [])
except Exception as e:
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
return
@@ -209,17 +218,13 @@ async def _poll_bot(bot_id: int) -> None:
continue
effective_lang = chat_row.language_override or msg_language
message_id = message.get("message_id")
cmd_response = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
if cmd_response is not None:
if isinstance(cmd_response, dict) and "media" in cmd_response:
# Text + media: send text first, media as reply
from ..commands.handler import send_reply as _reply
await _reply(bot_token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
await send_media_group(bot_token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
elif isinstance(cmd_response, list):
await send_media_group(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
else:
await send_reply(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
if responses:
for resp in responses:
if resp.text:
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
if resp.media:
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
except Exception:
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
@@ -7,7 +7,6 @@ objects and dispatches through the same path the watcher uses.
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -183,58 +182,59 @@ async def _build_immich_event(
memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums"
is_memory = test_type == "memory"
async with aiohttp.ClientSession() as http_session:
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
if not await immich.connect():
return None
from .http_session import get_http_session
http_session = await get_http_session()
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
if not await immich.connect():
return None
# Native Immich memories API path
if is_memory and memory_source == "native":
return await _build_native_memory_event(
immich, ext_domain, provider_name, tracker_name,
collection_ids, limit, asset_type, favorite_only, min_rating,
)
# Album-based path: use shared collect_scheduled_assets
albums: dict[str, ImmichAlbumData] = {}
shared_links: dict[str, list[SharedLinkInfo]] = {}
for album_id in collection_ids:
album = await immich.client.get_album(album_id)
if album:
albums[album_id] = album
shared_links[album_id] = await immich.client.get_shared_links(album_id)
assets, collections_extra = collect_scheduled_assets(
albums, shared_links, ext_domain,
limit=limit,
asset_type=asset_type,
favorite_only=favorite_only,
min_rating=min_rating,
is_memory=is_memory,
# Native Immich memories API path
if is_memory and memory_source == "native":
return await _build_native_memory_event(
immich, ext_domain, provider_name, tracker_name,
collection_ids, limit, asset_type, favorite_only, min_rating,
)
first_col = collections_extra[0] if collections_extra else {}
return ServiceEvent(
event_type=EventType.SCHEDULED_MESSAGE,
provider_type=ServiceProviderType.IMMICH,
provider_name=provider_name,
collection_id=collection_ids[0] if collection_ids else "",
collection_name=first_col.get("name", tracker_name),
timestamp=datetime.now(timezone.utc),
added_assets=assets,
added_count=len(assets),
extra={
"collections": collections_extra,
"albums": collections_extra,
**(first_col if first_col else {}),
},
)
# Album-based path: use shared collect_scheduled_assets
albums: dict[str, ImmichAlbumData] = {}
shared_links: dict[str, list[SharedLinkInfo]] = {}
for album_id in collection_ids:
album = await immich.client.get_album(album_id)
if album:
albums[album_id] = album
shared_links[album_id] = await immich.client.get_shared_links(album_id)
assets, collections_extra = collect_scheduled_assets(
albums, shared_links, ext_domain,
limit=limit,
asset_type=asset_type,
favorite_only=favorite_only,
min_rating=min_rating,
is_memory=is_memory,
)
first_col = collections_extra[0] if collections_extra else {}
return ServiceEvent(
event_type=EventType.SCHEDULED_MESSAGE,
provider_type=ServiceProviderType.IMMICH,
provider_name=provider_name,
collection_id=collection_ids[0] if collection_ids else "",
collection_name=first_col.get("name", tracker_name),
timestamp=datetime.now(timezone.utc),
added_assets=assets,
added_count=len(assets),
extra={
"collections": collections_extra,
"albums": collections_extra,
**(first_col if first_col else {}),
},
)
async def _build_native_memory_event(
@@ -6,7 +6,6 @@ import asyncio
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -102,19 +101,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
if provider_type == "immich":
from notify_bridge_core.providers.immich import ImmichServiceProvider
async with aiohttp.ClientSession() as http_session:
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
connected = await immich.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to provider"}
from .http_session import get_http_session
http_session = await get_http_session()
immich = ImmichServiceProvider(
http_session,
provider_config.get("url", ""),
provider_config.get("api_key", ""),
provider_config.get("external_domain"),
provider_name,
)
connected = await immich.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to provider"}
events, new_state = await immich.poll(collection_ids, state_dict)
events, new_state = await immich.poll(collection_ids, state_dict)
elif provider_type == "gitea":
# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
# The scheduler still calls check_tracker but there's nothing to poll.
@@ -143,18 +143,22 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
events, new_state = await nut.poll(collection_ids, state_dict)
elif provider_type == "google_photos":
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
async with aiohttp.ClientSession() as http_session:
gp = GooglePhotosServiceProvider(
http_session,
provider_config.get("client_id", ""),
provider_config.get("client_secret", ""),
provider_config.get("refresh_token", ""),
provider_name,
)
connected = await gp.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to Google Photos"}
events, new_state = await gp.poll(collection_ids, state_dict)
from .http_session import get_http_session
http_session = await get_http_session()
gp = GooglePhotosServiceProvider(
http_session,
provider_config.get("client_id", ""),
provider_config.get("client_secret", ""),
provider_config.get("refresh_token", ""),
provider_name,
)
connected = await gp.connect()
if not connected:
return {"status": "error", "reason": "failed to connect to Google Photos"}
events, new_state = await gp.poll(collection_ids, state_dict)
elif provider_type == "webhook":
# Webhook providers receive events via inbound HTTP; no polling needed.
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
else:
return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}