refactor: comprehensive codebase review — security, performance, quality, UX
Security: - Fix NUT protocol command injection (validate names against safe regex) - Enable Jinja2 autoescape=True to prevent HTML injection via external data - Add WebhookProviderConfig validation model Performance: - Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions) - Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast) - asyncio.gather for Gitea commands and notification dispatcher - Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener - LRU cache for compiled Jinja2 templates - Daily EventLog cleanup job (90-day retention) - 30s HTTP timeout on all external calls - GROUP BY for target type counts (replaces 7 sequential queries) Code quality: - Extract get_owned_entity() helper (replaces 11 duplicate functions) - Extract slot_helpers.py (load_slots, save_slots, render_template_preview) - Extract command_utils.py (tracker lookup, last event, collection IDs) - Extract http_session.py (shared session lifecycle) - Provider connection validation dedup (3x → 1 helper) - Command dispatch tables replacing if/elif chains - Album+links fetch helper (fetch_albums_with_links) - Provider dispatch polymorphism (list_provider_collections) - Immutable _enrich_assets (no longer mutates in-place) - Fix _format_assets return type + handler unpacking Frontend: - Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru) - Mobile "More" nav panel with provider filter and search - Shared Button.svelte component (4 variants, 2 sizes) - Shared ErrorBanner.svelte component (8 pages updated) - SvelteKit goto() replacing window.location.href - Dashboard grid fixed for 4 cards, paginator opacity consistency Functionality: - max_instances=1 on scheduler jobs (prevents duplicate events) - Webhook provider in watcher (prevents error spam) - Fix stale SQLModel reference in poller - Gitea get_repo() direct API call
This commit is contained in:
@@ -43,6 +43,17 @@ jobs:
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
||||
|
||||
- name: Trigger Portainer redeploy
|
||||
continue-on-error: true
|
||||
run: |
|
||||
if [ -n "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" ]; then
|
||||
echo "Triggering Portainer redeploy..."
|
||||
curl -sf -X POST "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" \
|
||||
--max-time 30 || echo "::warning::Portainer webhook failed"
|
||||
else
|
||||
echo "DOCKER_REDEPLOY_WEBHOOK_URL not set — skipping auto-deploy"
|
||||
fi
|
||||
|
||||
- name: Generate changelog
|
||||
id: changelog
|
||||
run: |
|
||||
@@ -56,7 +67,29 @@ jobs:
|
||||
|
||||
- name: Create Gitea Release
|
||||
run: |
|
||||
BODY=$(cat /tmp/changelog.txt | python3 -c "import sys,json; print(json.dumps(sys.stdin.read()))")
|
||||
if [ -f RELEASE_NOTES.md ]; then
|
||||
export RELEASE_NOTES=$(cat RELEASE_NOTES.md)
|
||||
echo "Found RELEASE_NOTES.md"
|
||||
else
|
||||
export RELEASE_NOTES=""
|
||||
echo "No RELEASE_NOTES.md found"
|
||||
fi
|
||||
|
||||
BODY=$(python3 -c "
|
||||
import json, os, sys
|
||||
|
||||
release_notes = os.environ.get('RELEASE_NOTES', '')
|
||||
changelog = open('/tmp/changelog.txt').read().strip()
|
||||
|
||||
sections = []
|
||||
if release_notes.strip():
|
||||
sections.append(release_notes.strip())
|
||||
if changelog:
|
||||
sections.append('## Changelog\n\n' + changelog)
|
||||
|
||||
print(json.dumps('\n\n'.join(sections)))
|
||||
")
|
||||
|
||||
curl -s -X POST \
|
||||
"https://${{ env.REGISTRY }}/api/v1/repos/${{ env.IMAGE_NAME }}/releases" \
|
||||
-H "Authorization: token ${{ secrets.RELEASE_TOKEN }}" \
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
<script lang="ts">
|
||||
import type { Snippet } from 'svelte';
|
||||
|
||||
let {
|
||||
variant = 'primary',
|
||||
size = 'md',
|
||||
disabled = false,
|
||||
type = 'button',
|
||||
href,
|
||||
onclick,
|
||||
children,
|
||||
class: extraClass = '',
|
||||
}: {
|
||||
variant?: 'primary' | 'secondary' | 'danger' | 'ghost';
|
||||
size?: 'sm' | 'md';
|
||||
disabled?: boolean;
|
||||
type?: 'button' | 'submit';
|
||||
href?: string;
|
||||
onclick?: (e: MouseEvent) => void;
|
||||
children: Snippet;
|
||||
class?: string;
|
||||
} = $props();
|
||||
|
||||
const baseClasses = 'inline-flex items-center justify-center gap-1.5 rounded-md text-sm font-medium transition-colors disabled:opacity-50';
|
||||
const sizeClasses: Record<string, string> = {
|
||||
sm: 'px-2.5 py-1 text-xs',
|
||||
md: 'px-4 py-2',
|
||||
};
|
||||
const variantClasses: Record<string, string> = {
|
||||
primary: 'btn-primary',
|
||||
secondary: 'btn-secondary',
|
||||
danger: 'btn-danger',
|
||||
ghost: 'btn-ghost',
|
||||
};
|
||||
|
||||
const classes = $derived(
|
||||
`${baseClasses} ${sizeClasses[size]} ${variantClasses[variant]} ${extraClass}`.trim()
|
||||
);
|
||||
</script>
|
||||
|
||||
{#if href && !disabled}
|
||||
<a {href} class={classes} onclick={onclick}>
|
||||
{@render children()}
|
||||
</a>
|
||||
{:else}
|
||||
<button {type} {disabled} class={classes} onclick={onclick}>
|
||||
{@render children()}
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.btn-primary {
|
||||
background: var(--color-primary);
|
||||
color: var(--color-primary-foreground);
|
||||
}
|
||||
.btn-primary:hover:not(:disabled) {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.btn-secondary {
|
||||
background: var(--color-muted);
|
||||
color: var(--color-foreground);
|
||||
border: 1px solid var(--color-border);
|
||||
}
|
||||
.btn-secondary:hover:not(:disabled) {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.btn-danger {
|
||||
background: var(--color-error-fg);
|
||||
color: white;
|
||||
}
|
||||
.btn-danger:hover:not(:disabled) {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.btn-ghost {
|
||||
background: transparent;
|
||||
color: var(--color-muted-foreground);
|
||||
}
|
||||
.btn-ghost:hover:not(:disabled) {
|
||||
background: var(--color-muted);
|
||||
color: var(--color-foreground);
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import MdiIcon from './MdiIcon.svelte';
|
||||
import { t } from '$lib/i18n';
|
||||
|
||||
export interface EntityItem {
|
||||
value: string | number;
|
||||
@@ -142,7 +143,7 @@
|
||||
|
||||
<div class="ep-list" bind:this={listEl} role="listbox">
|
||||
{#if filtered.length === 0}
|
||||
<div class="ep-empty">No matches</div>
|
||||
<div class="ep-empty">{t('common.noMatches')}</div>
|
||||
{:else}
|
||||
{#each filtered as item, i}
|
||||
<button
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
<script lang="ts">
|
||||
interface Props {
|
||||
message: string;
|
||||
class?: string;
|
||||
}
|
||||
let { message, class: className = '' }: Props = $props();
|
||||
</script>
|
||||
|
||||
{#if message}
|
||||
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4 {className}">
|
||||
{message}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import MdiIcon from './MdiIcon.svelte';
|
||||
import { t } from '$lib/i18n';
|
||||
|
||||
export interface GridItem {
|
||||
value: string | number;
|
||||
@@ -117,7 +118,7 @@
|
||||
</button>
|
||||
{/each}
|
||||
{#if filtered.length === 0}
|
||||
<div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">No matches</div>
|
||||
<div class="icon-grid-empty" style="grid-column: 1 / -1; text-align: center; padding: 0.75rem; color: var(--color-muted-foreground); font-size: 0.75rem;">{t('common.noMatches')}</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from 'svelte';
|
||||
import MdiIcon from './MdiIcon.svelte';
|
||||
import { t } from '$lib/i18n';
|
||||
|
||||
let { open = false, title = '', onclose, children } = $props<{
|
||||
open: boolean;
|
||||
@@ -93,7 +94,7 @@
|
||||
>
|
||||
<div style="display: flex; align-items: center; justify-content: space-between; padding: 1.5rem 1.5rem 1rem;">
|
||||
<h3 id="modal-title-{uniqueId}" style="font-size: 1.125rem; font-weight: 600;">{title}</h3>
|
||||
<button class="modal-close" onclick={onclose} aria-label="Close">
|
||||
<button class="modal-close" onclick={onclose} aria-label={t('common.close')}>
|
||||
<MdiIcon name="mdiClose" size={18} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
<script lang="ts">
|
||||
import MdiIcon from './MdiIcon.svelte';
|
||||
import { t } from '$lib/i18n';
|
||||
|
||||
export interface MultiEntityItem {
|
||||
value: string;
|
||||
@@ -132,7 +133,7 @@
|
||||
|
||||
<div class="mes-list" bind:this={listEl} role="listbox">
|
||||
{#if filtered.length === 0}
|
||||
<div class="mes-empty">No matches</div>
|
||||
<div class="mes-empty">{t('common.noMatches')}</div>
|
||||
{:else}
|
||||
{#each filtered as item, i}
|
||||
{@const checked = (values || []).includes(item.value)}
|
||||
|
||||
@@ -56,7 +56,7 @@
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
<button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label="Dismiss">
|
||||
<button class="snack-close" onclick={() => removeSnack(snack.id)} aria-label={t('common.dismiss')}>
|
||||
<MdiIcon name="mdiClose" size={14} />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
@@ -110,6 +110,8 @@ export const previewTargetTypeItems = (): GridItem[] => [
|
||||
{ value: 'email', icon: 'mdiEmailOutline', label: 'Email', desc: t('gridDesc.previewEmail') },
|
||||
{ value: 'discord', icon: 'mdiChat', label: 'Discord', desc: t('gridDesc.previewDiscord') },
|
||||
{ value: 'slack', icon: 'mdiSlack', label: 'Slack', desc: t('gridDesc.previewSlack') },
|
||||
{ value: 'ntfy', icon: 'mdiBellOutline', label: 'ntfy', desc: t('gridDesc.previewNtfy') },
|
||||
{ value: 'matrix', icon: 'mdiMatrix', label: 'Matrix', desc: t('gridDesc.previewMatrix') },
|
||||
];
|
||||
|
||||
// --- Provider type items (derived from descriptor registry) ---
|
||||
|
||||
@@ -36,7 +36,8 @@
|
||||
"targetMatrix": "Matrix",
|
||||
"targetBroadcast": "Broadcast",
|
||||
"automation": "Automation",
|
||||
"actions": "Actions"
|
||||
"actions": "Actions",
|
||||
"more": "More"
|
||||
},
|
||||
"auth": {
|
||||
"signIn": "Sign in",
|
||||
@@ -51,7 +52,9 @@
|
||||
"creatingAccount": "Creating account...",
|
||||
"passwordMismatch": "Passwords do not match",
|
||||
"passwordTooShort": "Password must be at least 8 characters",
|
||||
"or": "or"
|
||||
"or": "or",
|
||||
"loginFailed": "Login failed",
|
||||
"setupFailed": "Setup failed"
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Dashboard",
|
||||
@@ -150,7 +153,9 @@
|
||||
"gpRefreshTokenHint": "Obtain from Google OAuth Playground (developers.google.com/oauthplayground) with the Photos Library API scope.",
|
||||
"gpAllFieldsRequired": "Client ID, Client Secret, and Refresh Token are all required",
|
||||
"testAndSave": "Test & Save",
|
||||
"saveWithoutTest": "Save without testing"
|
||||
"saveWithoutTest": "Save without testing",
|
||||
"selectType": "Select a provider type",
|
||||
"testFailed": "Connection test failed"
|
||||
},
|
||||
"notificationTracker": {
|
||||
"title": "Notification Trackers",
|
||||
@@ -231,7 +236,8 @@
|
||||
"noLink": "No Link",
|
||||
"saveWithoutLinks": "Save without links",
|
||||
"createLinks": "Create {count} link(s)",
|
||||
"linksNote": "You can also create links manually in Immich."
|
||||
"linksNote": "You can also create links manually in Immich.",
|
||||
"createdLinks": "Created {count} public link(s)"
|
||||
},
|
||||
"templates": {
|
||||
"title": "Templates",
|
||||
@@ -409,7 +415,9 @@
|
||||
"cacheTtl": "Media cache TTL (hours)",
|
||||
"cacheTtlHint": "How long to cache uploaded Telegram file_ids before re-uploading (default: 48h)",
|
||||
"settingsSaved": "Settings saved",
|
||||
"noExternalDomain": "External domain URL not configured"
|
||||
"noExternalDomain": "External domain URL not configured",
|
||||
"saveFailed": "Failed to save bot",
|
||||
"webhookFailed": "Failed to register webhook"
|
||||
},
|
||||
"trackingConfig": {
|
||||
"title": "Tracking Configs",
|
||||
@@ -584,7 +592,7 @@
|
||||
"added_assets": "List of asset dicts (use {% for asset in added_assets %})",
|
||||
"removed_assets": "List of removed asset IDs (strings)",
|
||||
"shared": "Whether album is shared (boolean)",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
"has_videos": "Whether added assets contain videos (boolean)",
|
||||
"has_photos": "Whether added assets contain photos (boolean)",
|
||||
"old_name": "Previous album name (rename events)",
|
||||
@@ -675,7 +683,8 @@
|
||||
"displayName": "Display Name",
|
||||
"testConnection": "Test connection",
|
||||
"noBots": "No Matrix bots yet.",
|
||||
"confirmDelete": "Delete this Matrix bot?"
|
||||
"confirmDelete": "Delete this Matrix bot?",
|
||||
"operationFailed": "Operation failed"
|
||||
},
|
||||
"emailBot": {
|
||||
"title": "Email Bots",
|
||||
@@ -693,7 +702,8 @@
|
||||
"useTls": "Use TLS/SSL",
|
||||
"testConnection": "Send test email",
|
||||
"noBots": "No email bots yet.",
|
||||
"confirmDelete": "Delete this email bot?"
|
||||
"confirmDelete": "Delete this email bot?",
|
||||
"operationFailed": "Operation failed"
|
||||
},
|
||||
"cmdTemplateConfig": {
|
||||
"title": "Command Templates",
|
||||
@@ -841,7 +851,12 @@
|
||||
"allTypes": "All types",
|
||||
"allProviders": "All providers",
|
||||
"noFilterResults": "No items match the current filter.",
|
||||
"redirecting": "Redirecting..."
|
||||
"redirecting": "Redirecting...",
|
||||
"noMatches": "No matches",
|
||||
"saveFailed": "Save failed",
|
||||
"loadFailed": "Failed to load data",
|
||||
"dismiss": "Dismiss",
|
||||
"systemSuffix": " (System)"
|
||||
},
|
||||
"templateSlot": {
|
||||
"message_assets_added": "New assets added to album",
|
||||
@@ -926,12 +941,15 @@
|
||||
"previewEmail": "Preview with email HTML format",
|
||||
"previewDiscord": "Preview with Discord markdown",
|
||||
"previewSlack": "Preview with Slack markdown",
|
||||
"previewNtfy": "Preview as ntfy notification",
|
||||
"previewMatrix": "Preview with Matrix HTML format",
|
||||
"providerImmich": "Self-hosted photo server",
|
||||
"providerGitea": "Self-hosted Git service",
|
||||
"providerPlanka": "Self-hosted Kanban board",
|
||||
"providerScheduler": "Time-based scheduled messages",
|
||||
"providerNut": "Network UPS monitoring",
|
||||
"providerGooglePhotos": "Google Photos albums & shared libraries"
|
||||
"providerGooglePhotos": "Google Photos albums & shared libraries",
|
||||
"providerWebhook": "Receive events via HTTP POST"
|
||||
},
|
||||
"error": {
|
||||
"notFound": "Page not found",
|
||||
|
||||
@@ -36,7 +36,8 @@
|
||||
"targetMatrix": "Matrix",
|
||||
"targetBroadcast": "Рассылка",
|
||||
"automation": "Автоматизация",
|
||||
"actions": "Действия"
|
||||
"actions": "Действия",
|
||||
"more": "Ещё"
|
||||
},
|
||||
"auth": {
|
||||
"signIn": "Войти",
|
||||
@@ -51,7 +52,9 @@
|
||||
"creatingAccount": "Создание...",
|
||||
"passwordMismatch": "Пароли не совпадают",
|
||||
"passwordTooShort": "Пароль должен быть не менее 8 символов",
|
||||
"or": "или"
|
||||
"or": "или",
|
||||
"loginFailed": "Ошибка входа",
|
||||
"setupFailed": "Ошибка настройки"
|
||||
},
|
||||
"dashboard": {
|
||||
"title": "Главная",
|
||||
@@ -150,7 +153,9 @@
|
||||
"gpRefreshTokenHint": "Получите через Google OAuth Playground (developers.google.com/oauthplayground) с областью Photos Library API.",
|
||||
"gpAllFieldsRequired": "Client ID, Client Secret и Refresh Token обязательны",
|
||||
"testAndSave": "Проверить и сохранить",
|
||||
"saveWithoutTest": "Сохранить без проверки"
|
||||
"saveWithoutTest": "Сохранить без проверки",
|
||||
"selectType": "Выберите тип провайдера",
|
||||
"testFailed": "Ошибка проверки подключения"
|
||||
},
|
||||
"notificationTracker": {
|
||||
"title": "Трекеры уведомлений",
|
||||
@@ -231,7 +236,8 @@
|
||||
"noLink": "Нет ссылки",
|
||||
"saveWithoutLinks": "Сохранить без ссылок",
|
||||
"createLinks": "Создать {count} ссылку(и)",
|
||||
"linksNote": "Вы также можете создать ссылки вручную в Immich."
|
||||
"linksNote": "Вы также можете создать ссылки вручную в Immich.",
|
||||
"createdLinks": "Создано публичных ссылок: {count}"
|
||||
},
|
||||
"templates": {
|
||||
"title": "Шаблоны",
|
||||
@@ -409,7 +415,9 @@
|
||||
"cacheTtl": "TTL кэша медиа (часы)",
|
||||
"cacheTtlHint": "Сколько хранить кэш Telegram file_id перед повторной загрузкой (по умолчанию: 48ч)",
|
||||
"settingsSaved": "Настройки сохранены",
|
||||
"noExternalDomain": "Внешний URL домена не настроен"
|
||||
"noExternalDomain": "Внешний URL домена не настроен",
|
||||
"saveFailed": "Не удалось сохранить бота",
|
||||
"webhookFailed": "Не удалось зарегистрировать webhook"
|
||||
},
|
||||
"trackingConfig": {
|
||||
"title": "Конфигурации отслеживания",
|
||||
@@ -584,7 +592,7 @@
|
||||
"added_assets": "Список файлов ({% for asset in added_assets %})",
|
||||
"removed_assets": "Список ID удалённых файлов (строки)",
|
||||
"shared": "Общий альбом (boolean)",
|
||||
"target_type": "Тип получателя: 'telegram' или 'webhook'",
|
||||
"target_type": "Тип получателя: telegram, webhook, email, discord, slack, ntfy или matrix",
|
||||
"has_videos": "Содержат ли добавленные файлы видео (boolean)",
|
||||
"has_photos": "Содержат ли добавленные файлы фото (boolean)",
|
||||
"old_name": "Прежнее название альбома (при переименовании)",
|
||||
@@ -675,7 +683,8 @@
|
||||
"displayName": "Отображаемое имя",
|
||||
"testConnection": "Проверить подключение",
|
||||
"noBots": "Matrix ботов пока нет.",
|
||||
"confirmDelete": "Удалить этот Matrix бот?"
|
||||
"confirmDelete": "Удалить этот Matrix бот?",
|
||||
"operationFailed": "Операция не удалась"
|
||||
},
|
||||
"emailBot": {
|
||||
"title": "Email боты",
|
||||
@@ -693,7 +702,8 @@
|
||||
"useTls": "Использовать TLS/SSL",
|
||||
"testConnection": "Отправить тестовое письмо",
|
||||
"noBots": "Email ботов пока нет.",
|
||||
"confirmDelete": "Удалить этот email бот?"
|
||||
"confirmDelete": "Удалить этот email бот?",
|
||||
"operationFailed": "Операция не удалась"
|
||||
},
|
||||
"cmdTemplateConfig": {
|
||||
"title": "Шаблоны команд",
|
||||
@@ -841,7 +851,12 @@
|
||||
"allTypes": "Все типы",
|
||||
"allProviders": "Все провайдеры",
|
||||
"noFilterResults": "Нет элементов, соответствующих фильтру.",
|
||||
"redirecting": "Перенаправление..."
|
||||
"redirecting": "Перенаправление...",
|
||||
"noMatches": "Ничего не найдено",
|
||||
"saveFailed": "Не удалось сохранить",
|
||||
"loadFailed": "Не удалось загрузить данные",
|
||||
"dismiss": "Закрыть",
|
||||
"systemSuffix": " (Системный)"
|
||||
},
|
||||
"templateSlot": {
|
||||
"message_assets_added": "Новые файлы добавлены в альбом",
|
||||
@@ -926,12 +941,15 @@
|
||||
"previewEmail": "Предпросмотр в формате Email HTML",
|
||||
"previewDiscord": "Предпросмотр в формате Discord",
|
||||
"previewSlack": "Предпросмотр в формате Slack",
|
||||
"previewNtfy": "Предпросмотр уведомления ntfy",
|
||||
"previewMatrix": "Предпросмотр в формате Matrix HTML",
|
||||
"providerImmich": "Фотосервер для самостоятельного размещения",
|
||||
"providerGitea": "Git-сервер для самостоятельного размещения",
|
||||
"providerPlanka": "Канбан-доска для самостоятельного размещения",
|
||||
"providerScheduler": "Запланированные сообщения по расписанию",
|
||||
"providerNut": "Мониторинг ИБП через NUT",
|
||||
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото"
|
||||
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото",
|
||||
"providerWebhook": "Приём событий через HTTP POST"
|
||||
},
|
||||
"error": {
|
||||
"notFound": "Страница не найдена",
|
||||
|
||||
@@ -215,15 +215,31 @@
|
||||
});
|
||||
}
|
||||
|
||||
// Mobile: flatten nav for bottom bar
|
||||
// Mobile: flatten nav for bottom bar (first 4 + "More" button)
|
||||
const mobileNavItems = $derived<NavItem[]>([
|
||||
{ href: '/', key: 'nav.dashboard', icon: 'mdiViewDashboard' },
|
||||
{ href: '/notification-trackers', key: 'nav.notification', icon: 'mdiBellOutline' },
|
||||
{ href: '/command-trackers', key: 'nav.commands', icon: 'mdiConsoleLine' },
|
||||
{ href: '/targets', key: 'nav.targets', icon: 'mdiTarget' },
|
||||
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
|
||||
]);
|
||||
|
||||
// "More" panel items — everything not in the bottom bar
|
||||
const mobileMoreItems = $derived<NavItem[]>([
|
||||
{ href: '/providers', key: 'nav.providers', icon: 'mdiServer' },
|
||||
{ href: '/bots?tab=telegram', key: 'nav.bots', icon: 'mdiRobot' },
|
||||
{ href: '/actions', key: 'nav.actions', icon: 'mdiPlayCircleOutline' },
|
||||
{ href: '/tracking-configs', key: 'nav.configs', icon: 'mdiCog' },
|
||||
{ href: '/template-configs', key: 'nav.templates', icon: 'mdiFileDocumentEdit' },
|
||||
{ href: '/command-configs', key: 'nav.configs', icon: 'mdiConsoleLine' },
|
||||
{ href: '/command-template-configs', key: 'nav.templates', icon: 'mdiCodeBracesBox' },
|
||||
...(auth.isAdmin ? [
|
||||
{ href: '/settings', key: 'nav.settings', icon: 'mdiCogOutline' },
|
||||
{ href: '/users', key: 'nav.users', icon: 'mdiAccountGroup' },
|
||||
] : []),
|
||||
]);
|
||||
|
||||
let mobileMoreOpen = $state(false);
|
||||
|
||||
const isAuthPage = $derived(
|
||||
page.url.pathname === '/login' || page.url.pathname === '/setup'
|
||||
);
|
||||
@@ -526,12 +542,50 @@
|
||||
<MdiIcon name={item.icon} size={20} />
|
||||
</a>
|
||||
{/each}
|
||||
<button onclick={logout} aria-label={t('nav.logout')}
|
||||
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs" style="color: var(--color-muted-foreground);">
|
||||
<MdiIcon name="mdiLogout" size={20} />
|
||||
<button onclick={() => openSearch?.()} aria-label={t('searchPalette.placeholder')}
|
||||
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
|
||||
style="color: var(--color-muted-foreground);">
|
||||
<MdiIcon name="mdiMagnify" size={20} />
|
||||
</button>
|
||||
<button onclick={() => mobileMoreOpen = !mobileMoreOpen} aria-label={t('nav.more')}
|
||||
class="flex flex-col items-center gap-0.5 px-2 py-1.5 text-xs rounded-lg transition-all duration-200"
|
||||
style="color: {mobileMoreOpen ? 'var(--color-primary)' : 'var(--color-muted-foreground)'};">
|
||||
<MdiIcon name="mdiDotsHorizontal" size={20} />
|
||||
</button>
|
||||
</nav>
|
||||
|
||||
<!-- Mobile "More" panel -->
|
||||
{#if mobileMoreOpen}
|
||||
<div class="mobile-more-backdrop" style="position: fixed; inset: 0; z-index: 49; background: rgba(0,0,0,0.4); backdrop-filter: blur(2px);"
|
||||
onclick={() => mobileMoreOpen = false} role="presentation"></div>
|
||||
<div class="mobile-more-panel" style="position: fixed; bottom: 3.25rem; left: 0; right: 0; z-index: 50; background: var(--color-sidebar); border-top: 1px solid var(--color-border); border-radius: 1rem 1rem 0 0; padding: 1rem; max-height: 60vh; overflow-y: auto;"
|
||||
transition:slide={{ duration: 200, easing: cubicOut }}>
|
||||
{#if allProviders.length > 1}
|
||||
<div class="mb-3 pb-3" style="border-bottom: 1px solid var(--color-border);">
|
||||
<IconGridSelect items={providerFilterItems} bind:value={providerFilterValue} columns={Math.min(providerFilterItems.length, 4)} compact />
|
||||
</div>
|
||||
{/if}
|
||||
<div class="grid grid-cols-3 gap-2">
|
||||
{#each mobileMoreItems as item}
|
||||
<a href={item.href}
|
||||
onclick={() => mobileMoreOpen = false}
|
||||
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
|
||||
style="color: {isActive(item.href) ? 'var(--color-primary)' : 'var(--color-muted-foreground)'}; background: {isActive(item.href) ? 'var(--color-sidebar-active)' : 'transparent'};"
|
||||
>
|
||||
<MdiIcon name={item.icon} size={20} />
|
||||
<span class="text-xs text-center leading-tight">{t(item.key)}</span>
|
||||
</a>
|
||||
{/each}
|
||||
<button onclick={() => { mobileMoreOpen = false; logout(); }}
|
||||
class="flex flex-col items-center gap-1 p-3 rounded-lg transition-all duration-200"
|
||||
style="color: var(--color-muted-foreground);">
|
||||
<MdiIcon name="mdiLogout" size={20} />
|
||||
<span class="text-xs text-center leading-tight">{t('nav.logout')}</span>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Main content -->
|
||||
<main class="flex-1 overflow-auto pb-16 md:pb-0">
|
||||
{#key page.url.pathname}
|
||||
@@ -579,6 +633,10 @@
|
||||
<style>
|
||||
@media (max-width: 767px) {
|
||||
.mobile-nav { display: flex !important; }
|
||||
.mobile-more-panel a:hover,
|
||||
.mobile-more-panel button:hover {
|
||||
background: var(--color-muted);
|
||||
}
|
||||
}
|
||||
|
||||
/* Provider filter chips */
|
||||
|
||||
@@ -231,7 +231,7 @@
|
||||
</div>
|
||||
</Card>
|
||||
{:else if status}
|
||||
<div class="grid grid-cols-1 sm:grid-cols-3 gap-4 mb-8 stagger-children">
|
||||
<div class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-4 mb-8 stagger-children">
|
||||
{#each statCards as card, i}
|
||||
<div class="stat-card" style="--accent: {card.color};">
|
||||
<div class="stat-card-inner">
|
||||
@@ -289,7 +289,7 @@
|
||||
<div class="flex items-center justify-center gap-1">
|
||||
{#if totalPages > 1}
|
||||
<button onclick={() => goToPage(currentPage - 1)} disabled={currentPage <= 1}
|
||||
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default">
|
||||
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
|
||||
<MdiIcon name="mdiChevronLeft" size={16} />
|
||||
</button>
|
||||
{#each Array.from({ length: totalPages }, (_, i) => i + 1) as page}
|
||||
@@ -305,7 +305,7 @@
|
||||
{/if}
|
||||
{/each}
|
||||
<button onclick={() => goToPage(currentPage + 1)} disabled={currentPage >= totalPages}
|
||||
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-30 disabled:cursor-default">
|
||||
class="px-2 py-1 text-sm border border-[var(--color-border)] rounded-md hover:bg-[var(--color-muted)] transition-colors disabled:opacity-50 disabled:cursor-default">
|
||||
<MdiIcon name="mdiChevronRight" size={16} />
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
||||
import IconButton from '$lib/components/IconButton.svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import type { EmailBot } from '$lib/types';
|
||||
|
||||
let { onreload }: { onreload: () => Promise<void> } = $props();
|
||||
@@ -72,22 +74,21 @@
|
||||
try {
|
||||
const res = await api(`/email-bots/${botId}/test`, { method: 'POST' });
|
||||
if (res.success) snackSuccess(t('snack.emailBotTestSent'));
|
||||
else snackError(res.error || 'Failed');
|
||||
else snackError(res.error || t('emailBot.operationFailed'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
emailTesting = { ...emailTesting, [botId]: false };
|
||||
}
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('emailBot.title')} description={t('emailBot.description')}>
|
||||
<button onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showEmailForm ? (showEmailForm = false, editingEmail = null) : openNewEmail(); }}>
|
||||
{showEmailForm ? t('common.cancel') : t('emailBot.addBot')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if showEmailForm}
|
||||
<Card class="mb-6">
|
||||
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
|
||||
<ErrorBanner message={error} />
|
||||
<form onsubmit={saveEmailBot} class="space-y-3">
|
||||
<div>
|
||||
<label for="ebot-name" class="block text-sm font-medium mb-1">{t('emailBot.name')}</label>
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
||||
import IconButton from '$lib/components/IconButton.svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import type { MatrixBot } from '$lib/types';
|
||||
|
||||
let { onreload }: { onreload: () => Promise<void> } = $props();
|
||||
@@ -70,22 +72,21 @@
|
||||
try {
|
||||
const res = await api(`/matrix-bots/${botId}/test`, { method: 'POST' });
|
||||
if (res.success) snackSuccess(t('snack.matrixBotTestOk'));
|
||||
else snackError(res.error || 'Failed');
|
||||
else snackError(res.error || t('matrixBot.operationFailed'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
matrixTesting = { ...matrixTesting, [botId]: false };
|
||||
}
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('matrixBot.title')} description={t('matrixBot.description')}>
|
||||
<button onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showMatrixForm ? (showMatrixForm = false, editingMatrix = null) : openNewMatrix(); }}>
|
||||
{showMatrixForm ? t('common.cancel') : t('matrixBot.addBot')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if showMatrixForm}
|
||||
<Card class="mb-6">
|
||||
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
|
||||
<ErrorBanner message={error} />
|
||||
<form onsubmit={saveMatrixBot} class="space-y-3">
|
||||
<div>
|
||||
<label for="mbot-name" class="block text-sm font-medium mb-1">{t('matrixBot.name')}</label>
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
import IconButton from '$lib/components/IconButton.svelte';
|
||||
import EntitySelect from '$lib/components/EntitySelect.svelte';
|
||||
import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import type { TelegramBot, TelegramChat } from '$lib/types';
|
||||
|
||||
interface CommandTrackerSummary { id: number; name: string; icon?: string; enabled: boolean }
|
||||
@@ -186,7 +188,7 @@
|
||||
try {
|
||||
const res = await api<ApiResult>(`/telegram-bots/${botId}/sync-commands`, { method: 'POST' });
|
||||
if (res.success) snackSuccess(t('telegramBot.commandsSynced'));
|
||||
else snackError(res.error || 'Failed');
|
||||
else snackError(res.error || t('telegramBot.saveFailed'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
modeChanging = { ...modeChanging, [botId]: false };
|
||||
}
|
||||
@@ -218,7 +220,7 @@
|
||||
snackSuccess(res.verified ? t('telegramBot.webhookVerified') : t('telegramBot.webhookRegistered'));
|
||||
await loadWebhookStatus(botId);
|
||||
} else {
|
||||
snackError(res.error || 'Failed to register webhook');
|
||||
snackError(res.error || t('telegramBot.webhookFailed'));
|
||||
}
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
modeChanging = { ...modeChanging, [botId]: false };
|
||||
@@ -229,7 +231,7 @@
|
||||
try {
|
||||
const res = await api<ApiResult>(`/telegram-bots/${botId}/webhook/unregister`, { method: 'POST' });
|
||||
if (res.success) { snackSuccess(t('telegramBot.webhookUnregistered')); await loadWebhookStatus(botId); }
|
||||
else snackError(res.error || 'Failed');
|
||||
else snackError(res.error || t('telegramBot.saveFailed'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
modeChanging = { ...modeChanging, [botId]: false };
|
||||
}
|
||||
@@ -260,7 +262,7 @@
|
||||
try {
|
||||
const res = await api<ApiResult>(`/telegram-bots/${botId}/chats/${chatId}/test?locale=${getLocale()}`, { method: 'POST' });
|
||||
if (res.success) snackSuccess(t('snack.targetTestSent'));
|
||||
else snackError(res.error || 'Failed');
|
||||
else snackError(res.error || t('telegramBot.saveFailed'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
chatTesting = { ...chatTesting, [key]: false };
|
||||
}
|
||||
@@ -277,15 +279,14 @@
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('telegramBot.title')} description={t('telegramBot.description')}>
|
||||
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
|
||||
{showForm ? t('common.cancel') : t('telegramBot.addBot')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if showForm}
|
||||
<Card class="mb-6">
|
||||
{#if error}<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>{/if}
|
||||
<ErrorBanner message={error} />
|
||||
<form onsubmit={saveBot} class="space-y-3">
|
||||
<div>
|
||||
<label for="bot-name" class="block text-sm font-medium mb-1">{t('telegramBot.name')}</label>
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
|
||||
import { providerTypeItems, providerTypeFilterItems, responseModeItems } from '$lib/grid-items';
|
||||
import EntitySelect from '$lib/components/EntitySelect.svelte';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import { highlightFromUrl } from '$lib/highlight';
|
||||
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
||||
@@ -37,7 +38,7 @@
|
||||
let cmdTemplateConfigs = $derived(commandTemplateConfigsCache.items);
|
||||
const templateItems = $derived(cmdTemplateConfigs
|
||||
.filter((c) => c.provider_type === form.provider_type)
|
||||
.map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? ' (System)' : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type }))
|
||||
.map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? t('common.systemSuffix') : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type }))
|
||||
);
|
||||
let loaded = $state(false);
|
||||
let showForm = $state(false);
|
||||
@@ -151,10 +152,9 @@
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('commandConfig.title')} description={t('commandConfig.description')}>
|
||||
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
|
||||
{showForm ? t('common.cancel') : t('commandConfig.newConfig')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if !loaded}<Loading />{:else}
|
||||
|
||||
@@ -32,7 +32,7 @@
|
||||
await login(username, password);
|
||||
window.location.href = '/';
|
||||
} catch (err: any) {
|
||||
error = err.message || 'Login failed';
|
||||
error = err.message || t('auth.loginFailed');
|
||||
}
|
||||
submitting = false;
|
||||
}
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
import { providerDefaultIcon } from '$lib/grid-items';
|
||||
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
||||
import { getDescriptor } from '$lib/providers';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import type { Tracker, TrackerTarget, TrackingConfig, TemplateConfig, NotificationTarget } from '$lib/types';
|
||||
|
||||
import TrackerForm from './TrackerForm.svelte';
|
||||
@@ -119,7 +121,7 @@
|
||||
capabilitiesCache.fetch(),
|
||||
]);
|
||||
} catch (err: any) {
|
||||
loadError = err.message || 'Failed to load data';
|
||||
loadError = err.message || t('common.loadFailed');
|
||||
snackError(loadError);
|
||||
} finally { loaded = true; highlightFromUrl(); }
|
||||
}
|
||||
@@ -212,7 +214,7 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
if (created > 0) snackSuccess(`Created ${created} public link(s)`);
|
||||
if (created > 0) snackSuccess(t('notificationTracker.createdLinks').replace('{count}', String(created)));
|
||||
linkWarning = null;
|
||||
linkCreating = false;
|
||||
await doSave();
|
||||
@@ -361,17 +363,16 @@
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('notificationTracker.title')} description={t('notificationTracker.description')}>
|
||||
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
|
||||
{showForm ? t('notificationTracker.cancel') : t('notificationTracker.newTracker')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if !loaded}
|
||||
<Loading />
|
||||
{:else if loadError}
|
||||
<Card>
|
||||
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3">{loadError}</div>
|
||||
<ErrorBanner message={loadError} class="mb-0" />
|
||||
</Card>
|
||||
{:else if showForm}
|
||||
<TrackerForm
|
||||
|
||||
@@ -12,12 +12,14 @@
|
||||
import EmptyState from '$lib/components/EmptyState.svelte';
|
||||
import ConfirmModal from '$lib/components/ConfirmModal.svelte';
|
||||
import IconButton from '$lib/components/IconButton.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
|
||||
import { providerTypeItems, providerDefaultIcon } from '$lib/grid-items';
|
||||
import { globalProviderFilter } from '$lib/stores/provider-filter.svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import { highlightFromUrl } from '$lib/highlight';
|
||||
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import type { ServiceProvider } from '$lib/types';
|
||||
|
||||
let allProviders = $derived(providersCache.items);
|
||||
@@ -136,10 +138,9 @@
|
||||
</script>
|
||||
|
||||
<PageHeader title={t('providers.title')} description={t('providers.description')}>
|
||||
<button onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}
|
||||
class="px-3 py-1.5 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90">
|
||||
<Button size="sm" onclick={() => { showForm ? (showForm = false, editing = null) : openNew(); }}>
|
||||
{showForm ? t('providers.cancel') : t('providers.addProvider')}
|
||||
</button>
|
||||
</Button>
|
||||
</PageHeader>
|
||||
|
||||
{#if !loaded}
|
||||
@@ -158,9 +159,7 @@
|
||||
{#if showForm}
|
||||
<div in:slide={{ duration: 200 }}>
|
||||
<Card class="mb-6">
|
||||
{#if error}
|
||||
<div class="bg-[var(--color-error-bg)] text-[var(--color-error-fg)] text-sm rounded-md p-3 mb-4">{error}</div>
|
||||
{/if}
|
||||
<ErrorBanner message={error} />
|
||||
<form onsubmit={save} class="space-y-3">
|
||||
<div>
|
||||
<label class="block text-sm font-medium mb-1">{t('providers.type')}</label>
|
||||
@@ -211,10 +210,9 @@
|
||||
<p class="text-xs text-[var(--color-muted-foreground)] mt-1">{t('providers.webhookUrlHint')}</p>
|
||||
</div>
|
||||
{/if}
|
||||
<button type="submit" disabled={submitting}
|
||||
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
|
||||
<Button type="submit" disabled={submitting}>
|
||||
{submitting ? t('providers.connecting') : (editing ? t('common.save') : t('providers.addProvider'))}
|
||||
</button>
|
||||
</Button>
|
||||
</form>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
import Card from '$lib/components/Card.svelte';
|
||||
import IconPicker from '$lib/components/IconPicker.svelte';
|
||||
import IconGridSelect from '$lib/components/IconGridSelect.svelte';
|
||||
import { goto } from '$app/navigation';
|
||||
import { providerTypeItems } from '$lib/grid-items';
|
||||
import { getDescriptor, buildProviderFormDefaults } from '$lib/providers';
|
||||
import Button from '$lib/components/Button.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
|
||||
let form = $state(buildProviderFormDefaults());
|
||||
let error = $state('');
|
||||
@@ -16,7 +19,7 @@
|
||||
|
||||
async function testAndSave() {
|
||||
const desc = descriptor;
|
||||
if (!desc) { error = 'Select a provider type'; return; }
|
||||
if (!desc) { error = t('providers.selectType'); return; }
|
||||
const { config, error: buildError } = desc.buildConfig(form, false);
|
||||
if (buildError) { error = t(buildError); snackError(error); return; }
|
||||
|
||||
@@ -32,22 +35,22 @@
|
||||
if (!result.ok) {
|
||||
await api(`/providers/${provider.id}`, { method: 'DELETE' }).catch(() => {});
|
||||
createdId = null;
|
||||
error = result.message || 'Connection test failed';
|
||||
error = result.message || t('providers.testFailed');
|
||||
snackError(error);
|
||||
} else {
|
||||
snackSuccess(t('snack.providerSaved'));
|
||||
window.location.href = '/providers';
|
||||
goto('/providers');
|
||||
}
|
||||
} catch (e: any) {
|
||||
if (createdId) await api(`/providers/${createdId}`, { method: 'DELETE' }).catch(() => {});
|
||||
error = e.message || 'Test failed'; snackError(error);
|
||||
error = e.message || t('providers.testFailed'); snackError(error);
|
||||
}
|
||||
finally { testing = false; }
|
||||
}
|
||||
|
||||
async function saveWithoutTest() {
|
||||
const desc = descriptor;
|
||||
if (!desc) { error = 'Select a provider type'; return; }
|
||||
if (!desc) { error = t('providers.selectType'); return; }
|
||||
const { config, error: buildError } = desc.buildConfig(form, false);
|
||||
if (buildError) { error = t(buildError); snackError(error); return; }
|
||||
|
||||
@@ -58,8 +61,8 @@
|
||||
body: JSON.stringify({ type: form.type, name: form.name || desc.defaultName, icon: form.icon, config }),
|
||||
});
|
||||
snackSuccess(t('snack.providerSaved'));
|
||||
window.location.href = '/providers';
|
||||
} catch (e: any) { error = e.message || 'Save failed'; snackError(error); }
|
||||
goto('/providers');
|
||||
} catch (e: any) { error = e.message || t('common.saveFailed'); snackError(error); }
|
||||
finally { saving = false; }
|
||||
}
|
||||
</script>
|
||||
@@ -112,22 +115,18 @@
|
||||
</div>
|
||||
{/each}
|
||||
|
||||
{#if error}
|
||||
<p class="text-sm text-[var(--color-error-fg)]">{error}</p>
|
||||
{/if}
|
||||
<ErrorBanner message={error} />
|
||||
|
||||
<div class="flex gap-3 pt-2">
|
||||
<button onclick={testAndSave} disabled={testing || saving}
|
||||
class="px-4 py-2 bg-[var(--color-primary)] text-[var(--color-primary-foreground)] rounded-md text-sm font-medium hover:opacity-90 disabled:opacity-50">
|
||||
<Button onclick={testAndSave} disabled={testing || saving}>
|
||||
{testing ? t('providers.connecting') : t('providers.testAndSave')}
|
||||
</button>
|
||||
<button onclick={saveWithoutTest} disabled={testing || saving}
|
||||
class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-foreground)] rounded-md text-sm font-medium hover:opacity-80 disabled:opacity-50">
|
||||
</Button>
|
||||
<Button variant="secondary" onclick={saveWithoutTest} disabled={testing || saving}>
|
||||
{saving ? t('common.loading') : t('providers.saveWithoutTest')}
|
||||
</button>
|
||||
<a href="/providers" class="px-4 py-2 bg-[var(--color-muted)] text-[var(--color-muted-foreground)] rounded-md text-sm font-medium hover:opacity-80">
|
||||
</Button>
|
||||
<Button variant="secondary" href="/providers">
|
||||
{t('common.cancel')}
|
||||
</a>
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
import Loading from '$lib/components/Loading.svelte';
|
||||
import MdiIcon from '$lib/components/MdiIcon.svelte';
|
||||
import Hint from '$lib/components/Hint.svelte';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
|
||||
let loaded = $state(false);
|
||||
let saving = $state(false);
|
||||
let error = $state('');
|
||||
let settings = $state({
|
||||
external_url: '',
|
||||
telegram_webhook_secret: '',
|
||||
@@ -20,16 +22,16 @@
|
||||
onMount(async () => {
|
||||
try {
|
||||
settings = await api('/settings');
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
} catch (err: any) { error = err.message; snackError(err.message); }
|
||||
finally { loaded = true; }
|
||||
});
|
||||
|
||||
async function save() {
|
||||
saving = true;
|
||||
saving = true; error = '';
|
||||
try {
|
||||
settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) });
|
||||
snackSuccess(t('settings.saved'));
|
||||
} catch (err: any) { snackError(err.message); }
|
||||
} catch (err: any) { error = err.message; snackError(err.message); }
|
||||
saving = false;
|
||||
}
|
||||
</script>
|
||||
@@ -39,6 +41,7 @@
|
||||
{#if !loaded}
|
||||
<Loading />
|
||||
{:else}
|
||||
<ErrorBanner message={error} />
|
||||
<div class="space-y-6">
|
||||
<!-- General section -->
|
||||
<Card>
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
try {
|
||||
await setup(username, password);
|
||||
window.location.href = '/';
|
||||
} catch (err: any) { error = err.message || 'Setup failed'; }
|
||||
} catch (err: any) { error = err.message || t('auth.setupFailed'); }
|
||||
submitting = false;
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import { chatActionItems } from '$lib/grid-items';
|
||||
import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte';
|
||||
import { highlightFromUrl } from '$lib/highlight';
|
||||
import ErrorBanner from '$lib/components/ErrorBanner.svelte';
|
||||
import type { NotificationTarget, TargetReceiver, TelegramChat } from '$lib/types';
|
||||
|
||||
import TargetForm from './TargetForm.svelte';
|
||||
@@ -419,7 +420,7 @@
|
||||
{#if !loaded}<Loading />{:else}
|
||||
|
||||
{#if loadError}
|
||||
<div class="mb-4 p-3 rounded-md text-sm bg-[var(--color-error-bg)] text-[var(--color-error-fg)]">{loadError}</div>
|
||||
<ErrorBanner message={loadError} />
|
||||
{/if}
|
||||
|
||||
{#if showForm}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
@@ -68,14 +69,17 @@ class NotificationDispatcher:
|
||||
|
||||
Returns list of results (one per target).
|
||||
"""
|
||||
raw_results = await asyncio.gather(
|
||||
*[self._send_to_target(event, t) for t in targets],
|
||||
return_exceptions=True,
|
||||
)
|
||||
results = []
|
||||
for target in targets:
|
||||
try:
|
||||
result = await self._send_to_target(event, target)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
_LOGGER.error("Failed to dispatch to target: %s", e)
|
||||
results.append({"success": False, "error": str(e)})
|
||||
for raw in raw_results:
|
||||
if isinstance(raw, Exception):
|
||||
_LOGGER.error("Failed to dispatch to target: %s", raw)
|
||||
results.append({"success": False, "error": str(raw)})
|
||||
else:
|
||||
results.append(raw)
|
||||
return results
|
||||
|
||||
def _resolve_template(
|
||||
|
||||
@@ -85,6 +85,20 @@ class GiteaClient:
|
||||
return repos
|
||||
|
||||
|
||||
async def get_repo(self, owner: str, repo: str) -> dict[str, Any] | None:
|
||||
"""Fetch a single repository by owner/repo name."""
|
||||
try:
|
||||
async with self._session.get(
|
||||
f"{self._url}/api/v1/repos/{owner}/{repo}",
|
||||
headers=self._headers,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
_LOGGER.warning("Failed to fetch repo %s/%s: HTTP %s", owner, repo, response.status)
|
||||
except aiohttp.ClientError as err:
|
||||
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, err)
|
||||
return None
|
||||
|
||||
async def get_repo_issues(
|
||||
self, owner: str, repo: str, state: str = "open", limit: int = 10,
|
||||
) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -14,12 +14,28 @@ _DEFAULT_PORT = 3493
|
||||
_READ_TIMEOUT = 10.0
|
||||
_CONNECT_TIMEOUT = 5.0
|
||||
|
||||
# Allowed characters for NUT protocol identifiers (UPS names, variable names).
|
||||
# Prevents command injection via newlines or special characters.
|
||||
_SAFE_NAME_RE = re.compile(r"^[\w.\-]+$")
|
||||
|
||||
# Regex to parse VAR lines: VAR <ups> <name> "<value>"
|
||||
_VAR_RE = re.compile(r'^VAR\s+(\S+)\s+(\S+)\s+"(.*)"$')
|
||||
# Regex to parse UPS lines: UPS <name> "<description>"
|
||||
_UPS_RE = re.compile(r'^UPS\s+(\S+)\s+"(.*)"$')
|
||||
|
||||
|
||||
def _validate_name(value: str, label: str) -> None:
|
||||
"""Validate that *value* is a safe NUT protocol identifier.
|
||||
|
||||
Raises ``NutClientError`` if *value* contains characters outside
|
||||
``[\\w.\\-]``, which could be used for protocol command injection.
|
||||
"""
|
||||
if not _SAFE_NAME_RE.match(value):
|
||||
raise NutClientError(
|
||||
f"Invalid {label}: {value!r} contains disallowed characters"
|
||||
)
|
||||
|
||||
|
||||
class NutClientError(Exception):
|
||||
"""Error communicating with NUT server."""
|
||||
|
||||
@@ -91,6 +107,7 @@ class NutClient:
|
||||
|
||||
async def list_var(self, ups_name: str) -> dict[str, str]:
|
||||
"""Get all variables for a UPS device."""
|
||||
_validate_name(ups_name, "UPS name")
|
||||
lines = await self._list_command(f"LIST VAR {ups_name}")
|
||||
variables: dict[str, str] = {}
|
||||
for line in lines:
|
||||
@@ -101,6 +118,8 @@ class NutClient:
|
||||
|
||||
async def get_var(self, ups_name: str, var_name: str) -> str:
|
||||
"""Get a single variable value."""
|
||||
_validate_name(ups_name, "UPS name")
|
||||
_validate_name(var_name, "variable name")
|
||||
response = await self._command(f"GET VAR {ups_name} {var_name}")
|
||||
m = _VAR_RE.match(response)
|
||||
if m:
|
||||
|
||||
@@ -10,7 +10,7 @@ from jinja2.sandbox import SandboxedEnvironment
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_env = SandboxedEnvironment(autoescape=False)
|
||||
_env = SandboxedEnvironment(autoescape=True)
|
||||
|
||||
|
||||
def render_template(template_str: str, context: dict[str, Any]) -> str:
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import Action, ActionRule, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,10 +60,9 @@ def _rule_response(rule: ActionRule) -> dict:
|
||||
async def _get_user_action(
|
||||
session: AsyncSession, action_id: int, user: User
|
||||
) -> Action:
|
||||
action = await session.get(Action, action_id)
|
||||
if not action or action.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Action not found")
|
||||
return action
|
||||
return await get_owned_entity(
|
||||
session, Action, action_id, user.id, not_found_msg="Action not found",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -12,12 +12,10 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User
|
||||
from .slot_helpers import load_slots, render_template_preview, save_slots
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,38 +42,11 @@ class CommandTemplateConfigUpdate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
|
||||
"""Load slots as {slot_name: {locale: template}}."""
|
||||
result = await session.exec(
|
||||
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id)
|
||||
)
|
||||
nested: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
nested.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
return nested
|
||||
return await load_slots(session, CommandTemplateSlot, config_id)
|
||||
|
||||
|
||||
async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None:
|
||||
"""Save slots from {slot_name: {locale: template}} format."""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(CommandTemplateSlot).where(
|
||||
CommandTemplateSlot.config_id == config_id,
|
||||
CommandTemplateSlot.slot_name == slot_name,
|
||||
CommandTemplateSlot.locale == locale,
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(CommandTemplateSlot(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
await save_slots(session, CommandTemplateSlot, config_id, slots)
|
||||
|
||||
|
||||
async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]:
|
||||
@@ -367,18 +338,4 @@ async def preview_raw(
|
||||
"wait": 15,
|
||||
}
|
||||
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(body.template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {"rendered": None, "error": e.message, "error_line": e.lineno}
|
||||
|
||||
try:
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(body.template)
|
||||
rendered = tmpl.render(**sample_ctx)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
return render_template_preview(body.template, sample_ctx)
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..database.models import (
|
||||
TelegramBot,
|
||||
User,
|
||||
)
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -401,7 +402,7 @@ async def _listener_response(session: AsyncSession, l: CommandTrackerListener) -
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> CommandTracker:
|
||||
tracker = await session.get(CommandTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Command tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, CommandTracker, tracker_id, user_id,
|
||||
not_found_msg="Command tracker not found",
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import EmailBot, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -156,7 +157,6 @@ def _response(bot: EmailBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot:
|
||||
bot = await session.get(EmailBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Email bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, EmailBot, bot_id, user_id, not_found_msg="Email bot not found",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Shared helpers for API route modules."""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
async def get_owned_entity(
|
||||
session: AsyncSession,
|
||||
model: type[T],
|
||||
entity_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
owner_field: str = "user_id",
|
||||
not_found_msg: str = "Not found",
|
||||
) -> T:
|
||||
"""Fetch an entity by PK and verify ownership, or raise 404."""
|
||||
entity = await session.get(model, entity_id)
|
||||
if not entity or getattr(entity, owner_field) != user_id:
|
||||
raise HTTPException(status_code=404, detail=not_found_msg)
|
||||
return entity
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import MatrixBot, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -108,33 +109,34 @@ async def test_matrix_bot(
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as http:
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
whoami = await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": f"Connection failed: {e}"}
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
whoami = await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": f"Connection failed: {e}"}
|
||||
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
|
||||
send_result = await client.send_message(
|
||||
room_id,
|
||||
"Test message from Notify Bridge",
|
||||
html_message="<b>Test message</b> from Notify Bridge",
|
||||
)
|
||||
result["send_result"] = send_result
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
|
||||
send_result = await client.send_message(
|
||||
room_id,
|
||||
"Test message from Notify Bridge",
|
||||
html_message="<b>Test message</b> from Notify Bridge",
|
||||
)
|
||||
result["send_result"] = send_result
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
def _response(bot: MatrixBot) -> dict:
|
||||
@@ -150,7 +152,6 @@ def _response(bot: MatrixBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot:
|
||||
bot = await session.get(MatrixBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Matrix bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, MatrixBot, bot_id, user_id, not_found_msg="Matrix bot not found",
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ..database.models import (
|
||||
)
|
||||
from ..services.notifier import send_test_notification
|
||||
from ..services.test_dispatch import dispatch_test_notification
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -277,7 +278,7 @@ async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) ->
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, NotificationTracker, tracker_id, user_id,
|
||||
not_found_msg="Tracker not found",
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..database.models import (
|
||||
User,
|
||||
)
|
||||
from ..services.scheduler import schedule_tracker, unschedule_tracker
|
||||
from .helpers import get_owned_entity
|
||||
from .notification_tracker_targets import _tt_response
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -205,7 +206,7 @@ async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> di
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, NotificationTracker, tracker_id, user_id,
|
||||
not_found_msg="Tracker not found",
|
||||
)
|
||||
|
||||
@@ -13,7 +13,12 @@ import aiohttp
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import ServiceProvider, User
|
||||
from ..services import make_immich_provider, make_gitea_provider, make_planka_provider, make_nut_provider, make_google_photos_provider
|
||||
from ..services import (
|
||||
make_immich_provider, make_gitea_provider, make_planka_provider,
|
||||
make_nut_provider, make_google_photos_provider, list_provider_collections,
|
||||
)
|
||||
from ..services.http_session import get_http_session
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -82,6 +87,20 @@ class GooglePhotosProviderConfig(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class PayloadMapping(BaseModel):
|
||||
variable: str
|
||||
jsonpath: str
|
||||
default: str | None = None
|
||||
|
||||
|
||||
class WebhookProviderConfig(BaseModel):
|
||||
auth_mode: str = "none"
|
||||
webhook_secret: str | None = None
|
||||
payload_mappings: list[PayloadMapping] = []
|
||||
event_type_path: str | None = None
|
||||
collection_path: str | None = None
|
||||
|
||||
|
||||
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"immich": ImmichProviderConfig,
|
||||
"gitea": GiteaProviderConfig,
|
||||
@@ -89,6 +108,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"scheduler": SchedulerProviderConfig,
|
||||
"nut": NutProviderConfig,
|
||||
"google_photos": GooglePhotosProviderConfig,
|
||||
"webhook": WebhookProviderConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -106,6 +126,70 @@ def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> Non
|
||||
)
|
||||
|
||||
|
||||
async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""Test provider connection and return the result dict.
|
||||
|
||||
For providers that lack optional credentials (gitea without api_token,
|
||||
planka without api_key), returns a success stub.
|
||||
"""
|
||||
http_session = await get_http_session()
|
||||
|
||||
if provider.type == "immich":
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.test_connection()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.test_connection()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.test_connection()
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.test_connection()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.test_connection()
|
||||
|
||||
if provider.type in ("scheduler", "webhook"):
|
||||
return {"ok": True, "message": "Virtual provider — always available"}
|
||||
|
||||
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
|
||||
|
||||
|
||||
async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""Test provider connection. Raise HTTPException on failure.
|
||||
|
||||
Returns the test_result dict on success (caller may inspect extra fields
|
||||
like ``external_domain``).
|
||||
"""
|
||||
try:
|
||||
test_result = await _test_provider_connection(provider)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
except OSError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
|
||||
)
|
||||
return test_result
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_providers(
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -128,96 +212,15 @@ async def create_provider(
|
||||
"""Add a new service provider (validates connection for known types)."""
|
||||
_validate_provider_config(body.type, body.config)
|
||||
|
||||
# Validate connection for known provider types
|
||||
try:
|
||||
if body.type == "immich":
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
config = body.config
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_key", ""),
|
||||
config.get("external_domain"), body.name,
|
||||
)
|
||||
test_result = await immich.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {body.type} provider"),
|
||||
)
|
||||
# Store external_domain from server config if available
|
||||
if test_result.get("external_domain"):
|
||||
config["external_domain"] = test_result["external_domain"]
|
||||
# Build a temporary ServiceProvider for connection testing
|
||||
temp_provider = ServiceProvider(
|
||||
id=0, user_id=0, type=body.type, name=body.name, config=body.config,
|
||||
)
|
||||
test_result = await _validate_provider_connection(temp_provider)
|
||||
|
||||
elif body.type == "gitea":
|
||||
config = body.config
|
||||
# api_token is optional (webhook_secret is required, but token only for repo listing)
|
||||
if config.get("api_token"):
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.gitea import GiteaServiceProvider
|
||||
gitea = GiteaServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_token", ""), body.name,
|
||||
)
|
||||
test_result = await gitea.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Gitea"),
|
||||
)
|
||||
|
||||
elif body.type == "planka":
|
||||
config = body.config
|
||||
if config.get("api_key"):
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.planka import PlankaServiceProvider
|
||||
planka = PlankaServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_key", ""), body.name,
|
||||
)
|
||||
test_result = await planka.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Planka"),
|
||||
)
|
||||
|
||||
elif body.type == "nut":
|
||||
nut = make_nut_provider(ServiceProvider(
|
||||
id=0, user_id=0, type="nut", name=body.name, config=body.config,
|
||||
))
|
||||
test_result = await nut.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to NUT server"),
|
||||
)
|
||||
|
||||
elif body.type == "google_photos":
|
||||
config = body.config
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session, config.get("client_id", ""), config.get("client_secret", ""),
|
||||
config.get("refresh_token", ""), body.name,
|
||||
)
|
||||
test_result = await gp.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Google Photos"),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
except OSError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
|
||||
# Scheduler: no validation needed (virtual provider)
|
||||
# Store external_domain from Immich server config if available
|
||||
if test_result.get("external_domain"):
|
||||
body.config["external_domain"] = test_result["external_domain"]
|
||||
|
||||
provider = ServiceProvider(
|
||||
user_id=user.id,
|
||||
@@ -307,78 +310,10 @@ async def update_provider(
|
||||
provider.config = body.config
|
||||
|
||||
# Re-validate connection when config changes for known provider types
|
||||
if config_changed and provider.type == "immich":
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
test_result = await immich.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
|
||||
)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "gitea":
|
||||
if provider.config.get("api_token"):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
test_result = await gitea.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Gitea"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "planka":
|
||||
if provider.config.get("api_key"):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
test_result = await planka.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Planka"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
test_result = await nut.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to NUT server"),
|
||||
)
|
||||
elif config_changed and provider.type == "google_photos":
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
test_result = await gp.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Google Photos"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
@@ -408,39 +343,7 @@ async def test_provider(
|
||||
):
|
||||
"""Check if a service provider is reachable."""
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.test_connection()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.test_connection()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.test_connection()
|
||||
|
||||
if provider.type == "scheduler":
|
||||
return {"ok": True, "message": "Virtual provider — always available"}
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.test_connection()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.test_connection()
|
||||
|
||||
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
|
||||
return await _test_provider_connection(provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/people")
|
||||
@@ -454,14 +357,14 @@ async def list_people(
|
||||
|
||||
if provider.type == "immich":
|
||||
from notify_bridge_core.providers.immich.client import ImmichClient
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider.config.get("url", ""),
|
||||
provider.config.get("api_key", ""),
|
||||
)
|
||||
people = await client.get_people()
|
||||
return [{"id": pid, "name": name} for pid, name in people.items()]
|
||||
http_session = await get_http_session()
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider.config.get("url", ""),
|
||||
provider.config.get("api_key", ""),
|
||||
)
|
||||
people = await client.get_people()
|
||||
return [{"id": pid, "name": name} for pid, name in people.items()]
|
||||
|
||||
return []
|
||||
|
||||
@@ -475,35 +378,7 @@ async def list_collections(
|
||||
"""Fetch collections from a service provider."""
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.list_collections()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return []
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.list_collections()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return []
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.list_collections()
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.list_collections()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.list_collections()
|
||||
|
||||
return []
|
||||
return await list_provider_collections(provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/albums/{album_id}/shared-links")
|
||||
@@ -517,19 +392,19 @@ async def get_album_shared_links(
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
links = await immich.client.get_shared_links(album_id)
|
||||
return [
|
||||
{
|
||||
"id": link.id,
|
||||
"key": link.key,
|
||||
"has_password": link.has_password,
|
||||
"is_expired": link.is_expired,
|
||||
"is_accessible": link.is_accessible,
|
||||
}
|
||||
for link in links
|
||||
]
|
||||
http_session = await get_http_session()
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
links = await immich.client.get_shared_links(album_id)
|
||||
return [
|
||||
{
|
||||
"id": link.id,
|
||||
"key": link.key,
|
||||
"has_password": link.has_password,
|
||||
"is_expired": link.is_expired,
|
||||
"is_accessible": link.is_accessible,
|
||||
}
|
||||
for link in links
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@@ -545,15 +420,13 @@ async def create_album_shared_link(
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
success = await immich.client.create_shared_link(album_id)
|
||||
if success:
|
||||
return {"success": True}
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail="Failed to create shared link")
|
||||
http_session = await get_http_session()
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
success = await immich.client.create_shared_link(album_id)
|
||||
if success:
|
||||
return {"success": True}
|
||||
raise HTTPException(status_code=400, detail="Failed to create shared link")
|
||||
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail="Provider type does not support shared links")
|
||||
|
||||
|
||||
@@ -580,7 +453,7 @@ async def _get_user_provider(
|
||||
session: AsyncSession, provider_id: int, user_id: int
|
||||
) -> ServiceProvider:
|
||||
"""Get a provider owned by the user, or raise 404."""
|
||||
provider = await session.get(ServiceProvider, provider_id)
|
||||
if not provider or provider.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return provider
|
||||
return await get_owned_entity(
|
||||
session, ServiceProvider, provider_id, user_id,
|
||||
not_found_msg="Provider not found",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Shared slot load/save and Jinja2 preview helpers for template config APIs."""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
S = TypeVar("S", bound=SQLModel)
|
||||
|
||||
|
||||
async def load_slots(
|
||||
session: AsyncSession,
|
||||
slot_model: type[S],
|
||||
config_id: int,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Load all template slots for a config as {slot_name: {locale: template}}.
|
||||
|
||||
Works for both TemplateSlot and CommandTemplateSlot — they share the same
|
||||
column names (config_id, slot_name, locale, template).
|
||||
"""
|
||||
result = await session.exec(
|
||||
select(slot_model).where(slot_model.config_id == config_id) # type: ignore[attr-defined]
|
||||
)
|
||||
slots: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
slots.setdefault(s.slot_name, {})[s.locale] = s.template # type: ignore[attr-defined]
|
||||
return slots
|
||||
|
||||
|
||||
async def save_slots(
|
||||
session: AsyncSession,
|
||||
slot_model: type[S],
|
||||
config_id: int,
|
||||
slots: dict[str, dict[str, str]],
|
||||
) -> None:
|
||||
"""Create or update template slots for a config (locale-aware).
|
||||
|
||||
Works for both TemplateSlot and CommandTemplateSlot.
|
||||
"""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(slot_model).where(
|
||||
slot_model.config_id == config_id, # type: ignore[attr-defined]
|
||||
slot_model.slot_name == slot_name, # type: ignore[attr-defined]
|
||||
slot_model.locale == locale, # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text # type: ignore[attr-defined]
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(slot_model(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
|
||||
|
||||
def render_template_preview(template: str, context: dict) -> dict:
|
||||
"""Two-pass Jinja2 render: syntax check, then strict render.
|
||||
|
||||
Returns a dict with either ``{"rendered": str}`` on success, or
|
||||
``{"rendered": None, "error": str, ...}`` on failure.
|
||||
"""
|
||||
# Pass 1: syntax check (default Undefined — catches parse errors only)
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {
|
||||
"rendered": None,
|
||||
"error": e.message,
|
||||
"error_line": e.lineno,
|
||||
}
|
||||
|
||||
# Pass 2: render with StrictUndefined to catch unknown variables
|
||||
try:
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(template)
|
||||
rendered = tmpl.render(**context)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
@@ -112,8 +112,16 @@ async def get_nav_counts(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Return entity counts for sidebar navigation badges."""
|
||||
counts = {}
|
||||
"""Return entity counts for sidebar navigation badges.
|
||||
|
||||
Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
|
||||
for concurrent use within a single session (no asyncio.gather). We
|
||||
minimise round-trips by combining user + system counts and per-type
|
||||
target counts into single aggregate queries where possible.
|
||||
"""
|
||||
counts: dict[str, int] = {}
|
||||
|
||||
# --- 1) User-owned entity counts (one query per model) ---
|
||||
for model, key in [
|
||||
(ServiceProvider, "providers"),
|
||||
(NotificationTracker, "notification_trackers"),
|
||||
@@ -132,7 +140,7 @@ async def get_nav_counts(
|
||||
)).one()
|
||||
counts[key] = count
|
||||
|
||||
# System-owned entities (user_id=0) count as well
|
||||
# --- 2) Add system-owned counts (user_id=0) for shared entities ---
|
||||
for model, key in [
|
||||
(TemplateConfig, "template_configs"),
|
||||
(CommandTemplateConfig, "command_template_configs"),
|
||||
@@ -144,15 +152,22 @@ async def get_nav_counts(
|
||||
)).one()
|
||||
counts[key] += system_count
|
||||
|
||||
# Per-type target counts for nav badges
|
||||
for target_type in ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"):
|
||||
type_count = (await session.exec(
|
||||
select(func.count()).select_from(NotificationTarget).where(
|
||||
NotificationTarget.user_id == user.id,
|
||||
NotificationTarget.type == target_type,
|
||||
)
|
||||
)).one()
|
||||
counts[f"targets_{target_type}"] = type_count
|
||||
# --- 3) Per-type target counts in a single query using conditional aggregation ---
|
||||
target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
|
||||
type_counts_result = (await session.exec(
|
||||
select(
|
||||
NotificationTarget.type,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
NotificationTarget.user_id == user.id,
|
||||
NotificationTarget.type.in_(target_types),
|
||||
)
|
||||
.group_by(NotificationTarget.type)
|
||||
)).all()
|
||||
type_counts_map = dict(type_counts_result)
|
||||
for target_type in target_types:
|
||||
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, TargetReceiver, User
|
||||
from ..services.notifier import send_to_receiver
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -170,7 +171,7 @@ def _response(r: TargetReceiver) -> dict:
|
||||
|
||||
|
||||
async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget:
|
||||
target = await session.get(NotificationTarget, target_id)
|
||||
if not target or target.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Target not found")
|
||||
return target
|
||||
return await get_owned_entity(
|
||||
session, NotificationTarget, target_id, user_id,
|
||||
not_found_msg="Target not found",
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User
|
||||
from ..services.notifier import send_test_notification
|
||||
from .helpers import get_owned_entity
|
||||
from .target_receivers import _receiver_key
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -306,8 +307,15 @@ async def _validate_broadcast_children(
|
||||
return
|
||||
if exclude_target_id and exclude_target_id in child_ids:
|
||||
raise HTTPException(status_code=400, detail="A broadcast target cannot include itself")
|
||||
|
||||
# Batch-load all children in a single IN query instead of N+1 individual fetches
|
||||
children = (await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.id.in_(child_ids))
|
||||
)).all()
|
||||
children_by_id = {c.id: c for c in children}
|
||||
|
||||
for child_id in child_ids:
|
||||
child = await session.get(NotificationTarget, child_id)
|
||||
child = children_by_id.get(child_id)
|
||||
if not child or child.user_id != user_id:
|
||||
raise HTTPException(status_code=400, detail=f"Child target {child_id} not found")
|
||||
if child.type == "broadcast":
|
||||
@@ -378,7 +386,7 @@ def _safe_config(target: NotificationTarget) -> dict:
|
||||
async def _get_user_target(
|
||||
session: AsyncSession, target_id: int, user_id: int
|
||||
) -> NotificationTarget:
|
||||
target = await session.get(NotificationTarget, target_id)
|
||||
if not target or target.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Target not found")
|
||||
return target
|
||||
return await get_owned_entity(
|
||||
session, NotificationTarget, target_id, user_id,
|
||||
not_found_msg="Target not found",
|
||||
)
|
||||
|
||||
@@ -7,8 +7,6 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
@@ -19,6 +17,7 @@ from ..database.models import AppSetting, NotificationTarget, TargetReceiver, Te
|
||||
from ..services.notifier import _get_test_message
|
||||
from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling
|
||||
from .app_settings import get_setting
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -290,10 +289,11 @@ async def test_chat(
|
||||
):
|
||||
"""Send a test message to a chat via the bot."""
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
from ..services.http_session import get_http_session
|
||||
message = _get_test_message(locale, "telegram")
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot.token)
|
||||
return await client.send_message(chat_id, message)
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot.token)
|
||||
return await client.send_message(chat_id, message)
|
||||
|
||||
|
||||
class ChatUpdate(BaseModel):
|
||||
@@ -344,41 +344,44 @@ async def delete_chat(
|
||||
|
||||
async def _get_webhook_info(token: str) -> dict | None:
|
||||
"""Call Telegram getWebhookInfo via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_webhook_info()
|
||||
return result.get("result") if result.get("success") else None
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_webhook_info()
|
||||
return result.get("result") if result.get("success") else None
|
||||
|
||||
|
||||
async def _get_me(token: str) -> dict | None:
|
||||
"""Call Telegram getMe via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_me()
|
||||
return result.get("result") if result.get("success") else None
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_me()
|
||||
return result.get("result") if result.get("success") else None
|
||||
|
||||
|
||||
async def _fetch_chats_from_telegram(token: str) -> list[dict]:
|
||||
"""Fetch chats from Telegram getUpdates via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_updates(limit=100)
|
||||
if not result.get("success"):
|
||||
return []
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_updates(limit=100)
|
||||
if not result.get("success"):
|
||||
return []
|
||||
|
||||
seen: dict[int, dict] = {}
|
||||
for update in result.get("result", []):
|
||||
msg = update.get("message", {})
|
||||
chat = msg.get("chat", {})
|
||||
chat_id = chat.get("id")
|
||||
if chat_id and chat_id not in seen:
|
||||
seen[chat_id] = {
|
||||
"id": chat_id,
|
||||
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
|
||||
"type": chat.get("type", "private"),
|
||||
"username": chat.get("username", ""),
|
||||
}
|
||||
return list(seen.values())
|
||||
seen: dict[int, dict] = {}
|
||||
for update in result.get("result", []):
|
||||
msg = update.get("message", {})
|
||||
chat = msg.get("chat", {})
|
||||
chat_id = chat.get("id")
|
||||
if chat_id and chat_id not in seen:
|
||||
seen[chat_id] = {
|
||||
"id": chat_id,
|
||||
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
|
||||
"type": chat.get("type", "private"),
|
||||
"username": chat.get("username", ""),
|
||||
}
|
||||
return list(seen.values())
|
||||
|
||||
|
||||
def _chat_response(c: TelegramChat) -> dict:
|
||||
@@ -410,10 +413,9 @@ def _bot_response(b: TelegramBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot:
|
||||
bot = await session.get(TelegramBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, TelegramBot, bot_id, user_id, not_found_msg="Bot not found",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import TemplateConfig, TemplateSlot, User
|
||||
from ..services.sample_context import _SAMPLE_CONTEXT
|
||||
from .slot_helpers import load_slots, render_template_preview, save_slots
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,40 +49,13 @@ class TemplateConfigUpdate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
|
||||
"""Load all template slots for a config as {slot_name: {locale: template}}."""
|
||||
result = await session.exec(
|
||||
select(TemplateSlot).where(TemplateSlot.config_id == config_id)
|
||||
)
|
||||
slots: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
slots.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
return slots
|
||||
return await load_slots(session, TemplateSlot, config_id)
|
||||
|
||||
|
||||
async def _save_slots(
|
||||
session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]
|
||||
) -> None:
|
||||
"""Create or update template slots for a config (locale-aware)."""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(TemplateSlot).where(
|
||||
TemplateSlot.config_id == config_id,
|
||||
TemplateSlot.slot_name == slot_name,
|
||||
TemplateSlot.locale == locale,
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(TemplateSlot(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
await save_slots(session, TemplateSlot, config_id, slots)
|
||||
|
||||
|
||||
async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]:
|
||||
@@ -155,7 +128,7 @@ async def get_template_variables(
|
||||
"photo_count": "Total photo count in album",
|
||||
"video_count": "Total video count in album",
|
||||
"owner": "Album owner name",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
"has_videos": "Whether added assets contain videos (boolean)",
|
||||
"has_photos": "Whether added assets contain photos (boolean)",
|
||||
"has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)",
|
||||
@@ -206,7 +179,7 @@ async def get_template_variables(
|
||||
}
|
||||
scheduled_vars = {
|
||||
"date": "Current date string",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -284,7 +257,7 @@ def _webhook_variables() -> dict:
|
||||
"source_ip": "IP address of the webhook sender",
|
||||
"raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)",
|
||||
"timestamp": "When the webhook was received",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -529,7 +502,7 @@ async def preview_date_format(
|
||||
|
||||
class PreviewRequest(BaseModel):
|
||||
template: str
|
||||
target_type: str = "telegram" # "telegram" or "webhook"
|
||||
target_type: str = "telegram" # telegram, webhook, email, discord, slack, ntfy, matrix
|
||||
date_format: str = "%d.%m.%Y, %H:%M UTC"
|
||||
date_only_format: str = "%d.%m.%Y"
|
||||
|
||||
@@ -545,33 +518,12 @@ async def preview_raw(
|
||||
1. Parse with default Undefined (catches syntax errors)
|
||||
2. Render with StrictUndefined (catches unknown variables like {{ asset.a }})
|
||||
"""
|
||||
# Pass 1: syntax check
|
||||
from datetime import datetime
|
||||
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
|
||||
"date_format": body.date_format, "date_only_format": body.date_only_format}
|
||||
# Format common_date using the provided date_only_format
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(body.template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {
|
||||
"rendered": None,
|
||||
"error": e.message,
|
||||
"error_line": e.lineno,
|
||||
}
|
||||
|
||||
# Pass 2: render with strict undefined to catch unknown variables
|
||||
try:
|
||||
from datetime import datetime
|
||||
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
|
||||
"date_format": body.date_format, "date_only_format": body.date_only_format}
|
||||
# Format common_date using the provided date_only_format
|
||||
try:
|
||||
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
|
||||
except (ValueError, TypeError):
|
||||
ctx["common_date"] = "19.03.2026"
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(body.template)
|
||||
rendered = tmpl.render(**ctx)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
# Still a valid template syntactically, but references unknown variable
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
|
||||
except (ValueError, TypeError):
|
||||
ctx["common_date"] = "19.03.2026"
|
||||
return render_template_preview(body.template, ctx)
|
||||
|
||||
@@ -3,9 +3,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot
|
||||
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CommandResponse:
|
||||
"""A single response from one tracker's command execution."""
|
||||
|
||||
text: str | None = None
|
||||
media: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ProviderCommandHandler(ABC):
|
||||
@@ -14,6 +23,8 @@ class ProviderCommandHandler(ABC):
|
||||
Each provider (Immich, Gitea, etc.) implements this interface to handle
|
||||
its own set of commands. The dispatch layer routes commands to the
|
||||
correct handler based on the provider type.
|
||||
|
||||
Each handler call receives a single (tracker, config, provider) context.
|
||||
"""
|
||||
|
||||
provider_type: str
|
||||
@@ -35,26 +46,28 @@ class ProviderCommandHandler(ABC):
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
"""Handle a provider-specific command.
|
||||
tracker: CommandTracker,
|
||||
config: CommandConfig,
|
||||
) -> CommandResponse | None:
|
||||
"""Handle a provider-specific command for a single tracker.
|
||||
|
||||
Args:
|
||||
cmd: The command name (without '/').
|
||||
args: Arguments after the command.
|
||||
count: Number of results to return.
|
||||
locale: User's locale ('en', 'ru').
|
||||
response_mode: 'media' or 'text'.
|
||||
providers_map: Provider instances keyed by ID.
|
||||
cmd_templates: Template slots {slot_name: {locale: template}}.
|
||||
response_mode: 'media' or 'text' (from this tracker's config).
|
||||
provider: The service provider instance for this tracker.
|
||||
cmd_templates: Template slots for this tracker's command template config.
|
||||
bot: The Telegram bot instance.
|
||||
ctx_tuples: Command context tuples for this provider type.
|
||||
tracker: The command tracker being dispatched.
|
||||
config: The command config for this tracker.
|
||||
|
||||
Returns:
|
||||
Text response, media list, or None if unhandled.
|
||||
A CommandResponse, or None if unhandled.
|
||||
"""
|
||||
|
||||
def get_rate_categories(self) -> dict[str, str]:
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Shared command handler utilities to reduce boilerplate across providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import EventLog, NotificationTracker, ServiceProvider
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]:
|
||||
"""Get notification trackers for a single provider."""
|
||||
from .handler import _get_notification_trackers_for_providers
|
||||
|
||||
return await _get_notification_trackers_for_providers({provider_id})
|
||||
|
||||
|
||||
async def get_last_event_str(tracker_ids: list[int]) -> str:
|
||||
"""Get formatted timestamp of most recent event for given trackers.
|
||||
|
||||
Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist.
|
||||
"""
|
||||
if not tracker_ids:
|
||||
return "-"
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
|
||||
|
||||
def get_tracked_collection_ids(
|
||||
provider: ServiceProvider,
|
||||
trackers: list[NotificationTracker],
|
||||
*,
|
||||
max_items: int = 20,
|
||||
) -> list[str]:
|
||||
"""Get deduplicated collection IDs from trackers for a provider.
|
||||
|
||||
Iterates all trackers belonging to *provider*, collects IDs from both
|
||||
``collection_ids`` and ``filters.collections``, deduplicates while
|
||||
preserving order, and caps at *max_items*.
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for tracker in trackers:
|
||||
if tracker.provider_id != provider.id:
|
||||
continue
|
||||
for cid in tracker.collection_ids or []:
|
||||
if cid not in seen:
|
||||
seen.add(cid)
|
||||
result.append(cid)
|
||||
for cid in (tracker.filters or {}).get("collections", []):
|
||||
if cid not in seen:
|
||||
seen.add(cid)
|
||||
result.append(cid)
|
||||
return result[:max_items]
|
||||
@@ -2,27 +2,55 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
CommandConfig, CommandTracker, EventLog,
|
||||
NotificationTracker, ServiceProvider, TelegramBot,
|
||||
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
|
||||
)
|
||||
from ..services import make_gitea_provider
|
||||
from .base import ProviderCommandHandler
|
||||
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
|
||||
from ..services.http_session import get_http_session
|
||||
from .base import CommandResponse, ProviderCommandHandler
|
||||
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
|
||||
from .handler import _render_cmd_template
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"}
|
||||
|
||||
|
||||
def _get_tracked_repos(
|
||||
provider: ServiceProvider,
|
||||
trackers: list,
|
||||
) -> list[tuple[ServiceProvider, str, str]]:
|
||||
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
|
||||
if not provider.config.get("api_token"):
|
||||
return []
|
||||
collection_ids = get_tracked_collection_ids(provider, trackers)
|
||||
repos: list[tuple[ServiceProvider, str, str]] = []
|
||||
for full_name in collection_ids:
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
repos.append((provider, parts[0], parts[1]))
|
||||
return repos
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command dispatch table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
|
||||
|
||||
|
||||
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
|
||||
"""Register a function in the text command dispatch table."""
|
||||
name = fn.__name__.removeprefix("_cmd_")
|
||||
_TEXT_COMMANDS[name] = fn
|
||||
return fn
|
||||
|
||||
|
||||
class GiteaCommandHandler(ProviderCommandHandler):
|
||||
"""Handles Gitea-specific bot commands."""
|
||||
|
||||
@@ -44,91 +72,35 @@ class GiteaCommandHandler(ProviderCommandHandler):
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
if cmd == "repos":
|
||||
ctx = await _cmd_repos(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "repos", locale, ctx)
|
||||
if cmd == "issues":
|
||||
ctx = await _cmd_issues(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "issues", locale, ctx)
|
||||
if cmd == "prs":
|
||||
ctx = await _cmd_prs(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "prs", locale, ctx)
|
||||
if cmd == "commits":
|
||||
ctx = await _cmd_commits(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "commits", locale, ctx)
|
||||
return None
|
||||
tracker: CommandTracker,
|
||||
config: CommandConfig,
|
||||
) -> CommandResponse | None:
|
||||
fn = _TEXT_COMMANDS.get(cmd)
|
||||
if fn is None:
|
||||
return None
|
||||
ctx = await fn(provider, count)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
|
||||
|
||||
|
||||
def _get_tracked_repos(
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
trackers: list[NotificationTracker],
|
||||
) -> list[tuple[ServiceProvider, str, str]]:
|
||||
"""Get (provider, owner, repo) tuples from tracked collection_ids."""
|
||||
repos: list[tuple[ServiceProvider, str, str]] = []
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "gitea":
|
||||
continue
|
||||
if not provider.config.get("api_token"):
|
||||
continue
|
||||
for full_name in (tracker.collection_ids or []):
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
repos.append((provider, parts[0], parts[1]))
|
||||
# Also check filters.collections
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "gitea":
|
||||
continue
|
||||
if not provider.config.get("api_token"):
|
||||
continue
|
||||
for full_name in (tracker.filters or {}).get("collections", []):
|
||||
parts = full_name.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
entry = (provider, parts[0], parts[1])
|
||||
if entry not in repos:
|
||||
repos.append(entry)
|
||||
return repos[:20] # Cap to prevent API hammering
|
||||
@_text_cmd
|
||||
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_repos = _get_tracked_repos(provider, trackers)
|
||||
|
||||
|
||||
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
|
||||
# Get server version from first Gitea provider with token
|
||||
# Get server version
|
||||
server_version = "unknown"
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider in providers_map.values():
|
||||
if provider.type == "gitea" and provider.config.get("api_token"):
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
version = await gitea.client.get_server_version()
|
||||
if version:
|
||||
server_version = version
|
||||
break
|
||||
if provider.config.get("api_token"):
|
||||
http = await get_http_session()
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
version = await gitea.client.get_server_version()
|
||||
if version:
|
||||
server_version = version
|
||||
|
||||
# Last event
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if tracker_ids:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
else:
|
||||
last_event = None
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
last_str = await get_last_event_str(tracker_ids)
|
||||
|
||||
return {
|
||||
"repos_count": len(tracked_repos),
|
||||
@@ -137,116 +109,139 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_repos(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_repos = _get_tracked_repos(provider, trackers)
|
||||
|
||||
repos_data: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
try:
|
||||
all_repos = await gitea.client.get_repos(limit=50)
|
||||
for r in all_repos:
|
||||
if r.get("full_name") == f"{owner}/{repo}":
|
||||
repos_data.append({
|
||||
"full_name": r.get("full_name", ""),
|
||||
"description": r.get("description", ""),
|
||||
"stars": r.get("stars_count", 0),
|
||||
"url": r.get("html_url", ""),
|
||||
})
|
||||
break
|
||||
else:
|
||||
repos_data.append({
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
})
|
||||
except Exception:
|
||||
repos_data.append({
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "?",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
})
|
||||
http = await get_http_session()
|
||||
|
||||
async def _fetch_repo(prov: ServiceProvider, owner: str, repo: str) -> dict[str, Any]:
|
||||
gitea = make_gitea_provider(http, prov)
|
||||
# Use direct get_repo endpoint instead of listing all repos
|
||||
r = await gitea.client.get_repo(owner, repo)
|
||||
if r:
|
||||
return {
|
||||
"full_name": r.get("full_name", ""),
|
||||
"description": r.get("description", ""),
|
||||
"stars": r.get("stars_count", 0),
|
||||
"url": r.get("html_url", ""),
|
||||
}
|
||||
return {
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
}
|
||||
|
||||
tasks = [_fetch_repo(prov, owner, repo) for prov, owner, repo in tracked_repos]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for (prov, owner, repo), result in zip(tracked_repos, results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, result)
|
||||
repos_data.append({
|
||||
"full_name": f"{owner}/{repo}",
|
||||
"description": "?",
|
||||
"stars": 0,
|
||||
"url": "",
|
||||
})
|
||||
else:
|
||||
repos_data.append(result)
|
||||
|
||||
return {"repos": repos_data}
|
||||
|
||||
|
||||
async def _cmd_issues(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_issues(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_repos = _get_tracked_repos(provider, trackers)
|
||||
|
||||
all_issues: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
issues = await gitea.client.get_repo_issues(owner, repo, limit=count)
|
||||
for issue in issues:
|
||||
all_issues.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": issue.get("number", 0),
|
||||
"title": issue.get("title", ""),
|
||||
"url": issue.get("html_url", ""),
|
||||
"user": issue.get("user", {}).get("login", ""),
|
||||
"state": issue.get("state", ""),
|
||||
})
|
||||
http = await get_http_session()
|
||||
|
||||
async def _fetch_issues(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
|
||||
gitea = make_gitea_provider(http, prov)
|
||||
return await gitea.client.get_repo_issues(owner, repo, limit=count)
|
||||
|
||||
tasks = [_fetch_issues(prov, owner, repo) for prov, owner, repo in tracked_repos]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for (prov, owner, repo), result in zip(tracked_repos, results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, result)
|
||||
continue
|
||||
for issue in result:
|
||||
all_issues.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": issue.get("number", 0),
|
||||
"title": issue.get("title", ""),
|
||||
"url": issue.get("html_url", ""),
|
||||
"user": issue.get("user", {}).get("login", ""),
|
||||
"state": issue.get("state", ""),
|
||||
})
|
||||
|
||||
all_issues.sort(key=lambda i: i.get("number", 0), reverse=True)
|
||||
return {"issues": all_issues[:count]}
|
||||
|
||||
|
||||
async def _cmd_prs(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_prs(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_repos = _get_tracked_repos(provider, trackers)
|
||||
|
||||
all_prs: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
prs = await gitea.client.get_repo_pulls(owner, repo, limit=count)
|
||||
for pr in prs:
|
||||
all_prs.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": pr.get("number", 0),
|
||||
"title": pr.get("title", ""),
|
||||
"url": pr.get("html_url", ""),
|
||||
"user": pr.get("user", {}).get("login", ""),
|
||||
"state": pr.get("state", ""),
|
||||
})
|
||||
http = await get_http_session()
|
||||
|
||||
async def _fetch_prs(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
|
||||
gitea = make_gitea_provider(http, prov)
|
||||
return await gitea.client.get_repo_pulls(owner, repo, limit=count)
|
||||
|
||||
tasks = [_fetch_prs(prov, owner, repo) for prov, owner, repo in tracked_repos]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for (prov, owner, repo), result in zip(tracked_repos, results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, result)
|
||||
continue
|
||||
for pr in result:
|
||||
all_prs.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"number": pr.get("number", 0),
|
||||
"title": pr.get("title", ""),
|
||||
"url": pr.get("html_url", ""),
|
||||
"user": pr.get("user", {}).get("login", ""),
|
||||
"state": pr.get("state", ""),
|
||||
})
|
||||
|
||||
all_prs.sort(key=lambda p: p.get("number", 0), reverse=True)
|
||||
return {"prs": all_prs[:count]}
|
||||
|
||||
|
||||
async def _cmd_commits(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_repos = _get_tracked_repos(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_commits(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_repos = _get_tracked_repos(provider, trackers)
|
||||
|
||||
all_commits: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, owner, repo in tracked_repos:
|
||||
gitea = make_gitea_provider(http, provider)
|
||||
commits = await gitea.client.get_repo_commits(owner, repo, limit=count)
|
||||
for c in commits:
|
||||
commit_data = c.get("commit", {})
|
||||
all_commits.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"short_id": c.get("sha", "")[:7],
|
||||
"message": commit_data.get("message", "").split("\n")[0][:80],
|
||||
"author": commit_data.get("author", {}).get("name", ""),
|
||||
"url": c.get("html_url", ""),
|
||||
})
|
||||
http = await get_http_session()
|
||||
|
||||
async def _fetch_commits(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]:
|
||||
gitea = make_gitea_provider(http, prov)
|
||||
return await gitea.client.get_repo_commits(owner, repo, limit=count)
|
||||
|
||||
tasks = [_fetch_commits(prov, owner, repo) for prov, owner, repo in tracked_repos]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for (prov, owner, repo), result in zip(tracked_repos, results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, result)
|
||||
continue
|
||||
for c in result:
|
||||
commit_data = c.get("commit", {})
|
||||
all_commits.append({
|
||||
"repo": f"{owner}/{repo}",
|
||||
"short_id": c.get("sha", "")[:7],
|
||||
"message": commit_data.get("message", "").split("\n")[0][:80],
|
||||
"author": commit_data.get("author", {}).get("name", ""),
|
||||
"url": c.get("html_url", ""),
|
||||
})
|
||||
|
||||
return {"commits": all_commits[:count]}
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
@@ -25,17 +26,21 @@ from ..database.models import (
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
)
|
||||
from .base import CommandResponse
|
||||
from .parser import parse_command
|
||||
from .registry import get_rate_category
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Singleton Jinja2 environment for template rendering (Phase 4d)
|
||||
_JINJA_ENV = SandboxedEnvironment(autoescape=False)
|
||||
_JINJA_ENV = SandboxedEnvironment(autoescape=True)
|
||||
|
||||
# Rate limit state with automatic TTL expiry (Phase 4e)
|
||||
_rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
|
||||
|
||||
# Maximum responses per command to avoid Telegram rate limits
|
||||
_MAX_RESPONSES_PER_COMMAND = 5
|
||||
|
||||
|
||||
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
|
||||
"""Check rate limit. Returns seconds to wait, or None if OK."""
|
||||
@@ -60,6 +65,12 @@ def _resolve_template(
|
||||
return locale_map.get(locale) or locale_map.get("en")
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _compile_template(template_str: str):
|
||||
"""Cache compiled Jinja2 templates to avoid re-parsing identical strings."""
|
||||
return _JINJA_ENV.from_string(template_str)
|
||||
|
||||
|
||||
def _render_cmd_template(
|
||||
templates: dict[str, dict[str, str]], slot_name: str, locale: str,
|
||||
context: dict[str, Any],
|
||||
@@ -70,20 +81,28 @@ def _render_cmd_template(
|
||||
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
|
||||
return f"[No template: {slot_name}]"
|
||||
try:
|
||||
tmpl = _JINJA_ENV.from_string(template_str)
|
||||
tmpl = _compile_template(template_str)
|
||||
return tmpl.render(**context)
|
||||
except Exception as e:
|
||||
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
|
||||
return f"[Template error: {slot_name}]"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _resolve_command_context(
|
||||
bot: TelegramBot,
|
||||
) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, dict[str, str]]]:
|
||||
) -> tuple[
|
||||
list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
dict[int, dict[str, dict[str, str]]],
|
||||
]:
|
||||
"""Resolve all enabled command trackers, configs, and providers for a bot.
|
||||
|
||||
Returns (context_tuples, cmd_template_slots).
|
||||
cmd_template_slots is {slot_name: {locale: template}}.
|
||||
Returns:
|
||||
(context_tuples, templates_by_config_id)
|
||||
templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}.
|
||||
"""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
@@ -142,8 +161,8 @@ async def _resolve_command_context(
|
||||
continue
|
||||
tuples.append((tracker, config, provider))
|
||||
|
||||
# Load command template slots — merge from all configs
|
||||
cmd_template_slots: dict[str, dict[str, str]] = {}
|
||||
# Load command template slots per config (not merged)
|
||||
templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {}
|
||||
seen_config_ids: set[int] = set()
|
||||
for _, config, _ in tuples:
|
||||
cfg_id = config.command_template_config_id
|
||||
@@ -154,98 +173,136 @@ async def _resolve_command_context(
|
||||
CommandTemplateSlot.config_id == cfg_id
|
||||
)
|
||||
)
|
||||
slots: dict[str, dict[str, str]] = {}
|
||||
for s in slot_result.all():
|
||||
cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
slots.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
templates_by_config_id[cfg_id] = slots
|
||||
|
||||
return tuples, cmd_template_slots
|
||||
return tuples, templates_by_config_id
|
||||
|
||||
|
||||
def _merge_command_context(
|
||||
def _templates_for_config(
|
||||
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
|
||||
config: CommandConfig,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Get template slots for a specific command config."""
|
||||
cfg_id = config.command_template_config_id
|
||||
if cfg_id and cfg_id in templates_by_config_id:
|
||||
return templates_by_config_id[cfg_id]
|
||||
return {}
|
||||
|
||||
|
||||
def _merge_all_templates(
|
||||
templates_by_config_id: dict[int, dict[str, dict[str, str]]],
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Merge all template config slots into one dict (for universal commands)."""
|
||||
merged: dict[str, dict[str, str]] = {}
|
||||
for slots in templates_by_config_id.values():
|
||||
for slot_name, locale_map in slots.items():
|
||||
merged.setdefault(slot_name, {}).update(locale_map)
|
||||
return merged
|
||||
|
||||
|
||||
def _merge_enabled_commands(
|
||||
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> tuple[list[str], str, int, dict[str, Any]]:
|
||||
"""Merge enabled_commands from all configs and pick defaults from first config."""
|
||||
) -> tuple[list[str], dict[str, Any]]:
|
||||
"""Merge enabled_commands (union) and rate_limits from all configs.
|
||||
|
||||
Rate limits use the most restrictive (minimum) cooldown per category.
|
||||
"""
|
||||
if not ctx:
|
||||
return [], "media", 5, {}
|
||||
return [], {}
|
||||
|
||||
enabled: set[str] = set()
|
||||
merged_limits: dict[str, int] = {}
|
||||
for _, config, _ in ctx:
|
||||
enabled.update(config.enabled_commands or [])
|
||||
for category, cooldown in (config.rate_limits or {}).items():
|
||||
if category not in merged_limits:
|
||||
merged_limits[category] = cooldown
|
||||
else:
|
||||
merged_limits[category] = min(merged_limits[category], cooldown)
|
||||
|
||||
first_config = ctx[0][1]
|
||||
response_mode = first_config.response_mode or "media"
|
||||
default_count = first_config.default_count or 5
|
||||
rate_limits = first_config.rate_limits or {}
|
||||
return sorted(enabled), merged_limits
|
||||
|
||||
return sorted(enabled), response_mode, default_count, rate_limits
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_command(
|
||||
bot: TelegramBot,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
language_code: str = "",
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
) -> list[CommandResponse] | None:
|
||||
"""Handle a bot command. Routes to provider-specific handlers.
|
||||
|
||||
Returns text response, media list, or None.
|
||||
Returns a list of CommandResponse objects (one per tracker), or None.
|
||||
Universal commands (/start, /help) return a single-element list.
|
||||
Provider-specific commands dispatch per-tracker with per-tracker config.
|
||||
"""
|
||||
cmd, args, count_override = parse_command(text)
|
||||
if not cmd:
|
||||
return None
|
||||
|
||||
ctx_tuples, cmd_templates = await _resolve_command_context(bot)
|
||||
enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples)
|
||||
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
|
||||
enabled, rate_limits = _merge_enabled_commands(ctx_tuples)
|
||||
|
||||
locale = language_code[:2].lower() if language_code else "en"
|
||||
if locale not in ("en", "ru"):
|
||||
locale = "en"
|
||||
|
||||
# Merged templates for universal commands
|
||||
merged_templates = _merge_all_templates(templates_by_config_id)
|
||||
|
||||
if cmd == "start":
|
||||
return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name})
|
||||
text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name})
|
||||
return [CommandResponse(text=text_resp)]
|
||||
|
||||
if cmd not in enabled and cmd != "start":
|
||||
return None
|
||||
|
||||
# Rate limit check
|
||||
# Rate limit check (once per command, shared across all trackers)
|
||||
wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits)
|
||||
if wait is not None:
|
||||
return _render_cmd_template(cmd_templates, "rate_limited", locale, {"wait": wait})
|
||||
text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait})
|
||||
return [CommandResponse(text=text_resp)]
|
||||
|
||||
count = min(count_override or default_count, 20)
|
||||
|
||||
# Build providers map from command context
|
||||
providers_map: dict[int, ServiceProvider] = {}
|
||||
for _, _, provider in ctx_tuples:
|
||||
providers_map[provider.id] = provider
|
||||
|
||||
# Universal commands
|
||||
# Universal commands — single merged response
|
||||
if cmd == "help":
|
||||
ctx = _cmd_help(enabled, locale, cmd_templates)
|
||||
return _render_cmd_template(cmd_templates, "help", locale, ctx)
|
||||
ctx = _cmd_help(enabled, locale, merged_templates)
|
||||
text_resp = _render_cmd_template(merged_templates, "help", locale, ctx)
|
||||
return [CommandResponse(text=text_resp)]
|
||||
|
||||
# Provider-specific dispatch
|
||||
# Provider-specific dispatch — per-tracker
|
||||
from .dispatch import get_handler
|
||||
|
||||
# Group ctx_tuples by provider type
|
||||
by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {}
|
||||
for t in ctx_tuples:
|
||||
ptype = t[2].type
|
||||
by_type.setdefault(ptype, []).append(t)
|
||||
|
||||
# Find which handler claims this command
|
||||
for ptype, ptuples in by_type.items():
|
||||
handler = get_handler(ptype)
|
||||
if handler and cmd in handler.get_provider_commands():
|
||||
# Build provider map filtered to this provider type
|
||||
pmap = {p.id: p for _, _, p in ptuples}
|
||||
result = await handler.handle(
|
||||
cmd, args, count, locale, response_mode,
|
||||
pmap, cmd_templates, bot, ptuples,
|
||||
responses: list[CommandResponse] = []
|
||||
for tracker, config, provider in ctx_tuples:
|
||||
if len(responses) >= _MAX_RESPONSES_PER_COMMAND:
|
||||
_LOGGER.warning(
|
||||
"Truncated command responses at %d for bot %d cmd /%s",
|
||||
_MAX_RESPONSES_PER_COMMAND, bot.id, cmd,
|
||||
)
|
||||
if result is not None:
|
||||
return result
|
||||
break
|
||||
|
||||
return None
|
||||
handler = get_handler(provider.type)
|
||||
if not handler or cmd not in handler.get_provider_commands():
|
||||
continue
|
||||
|
||||
tracker_templates = _templates_for_config(templates_by_config_id, config)
|
||||
count = min(count_override or config.default_count or 5, 20)
|
||||
response_mode = config.response_mode or "media"
|
||||
|
||||
result = await handler.handle(
|
||||
cmd, args, count, locale, response_mode,
|
||||
provider, tracker_templates, bot, tracker, config,
|
||||
)
|
||||
if result is not None:
|
||||
responses.append(result)
|
||||
|
||||
return responses if responses else None
|
||||
|
||||
|
||||
def _cmd_help(
|
||||
@@ -283,17 +340,13 @@ async def send_reply(
|
||||
session: aiohttp.ClientSession | None = None,
|
||||
) -> None:
|
||||
"""Send a text reply via TelegramClient."""
|
||||
async def _send(http: aiohttp.ClientSession) -> None:
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
|
||||
|
||||
if session is not None:
|
||||
await _send(session)
|
||||
else:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
await _send(http)
|
||||
if session is None:
|
||||
from ..services.http_session import get_http_session
|
||||
session = await get_http_session()
|
||||
client = TelegramClient(session, bot_token)
|
||||
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
|
||||
|
||||
|
||||
async def send_media_group(
|
||||
@@ -319,52 +372,50 @@ async def send_media_group(
|
||||
captions = [item.get("caption", "") for item in media_items if item.get("caption")]
|
||||
caption = "\n".join(captions) if captions else None
|
||||
|
||||
async def _send(http: aiohttp.ClientSession) -> None:
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.send_notification(
|
||||
chat_id, assets=assets, caption=caption,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
chat_action=None,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
|
||||
|
||||
if session is not None:
|
||||
await _send(session)
|
||||
else:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
await _send(http)
|
||||
if session is None:
|
||||
from ..services.http_session import get_http_session
|
||||
session = await get_http_session()
|
||||
client = TelegramClient(session, bot_token)
|
||||
result = await client.send_notification(
|
||||
chat_id, assets=assets, caption=caption,
|
||||
reply_to_message_id=reply_to_message_id,
|
||||
chat_action=None,
|
||||
)
|
||||
if not result.get("success"):
|
||||
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
|
||||
|
||||
|
||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
"""Register enabled commands with Telegram BotFather API via TelegramClient."""
|
||||
ctx_tuples, templates = await _resolve_command_context(bot)
|
||||
enabled, _, _, _ = _merge_command_context(ctx_tuples)
|
||||
ctx_tuples, templates_by_config_id = await _resolve_command_context(bot)
|
||||
enabled, _ = _merge_enabled_commands(ctx_tuples)
|
||||
templates = _merge_all_templates(templates_by_config_id)
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot.token)
|
||||
success = False
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot.token)
|
||||
success = False
|
||||
|
||||
# Register per-locale commands
|
||||
for locale in ("en", "ru"):
|
||||
commands = []
|
||||
for cmd in enabled:
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
|
||||
commands.append({"command": cmd, "description": desc})
|
||||
result = await client.set_my_commands(commands, language_code=locale)
|
||||
if result.get("success"):
|
||||
success = True
|
||||
else:
|
||||
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
|
||||
|
||||
# Register default (no language_code) with EN descriptions
|
||||
en_commands = []
|
||||
# Register per-locale commands
|
||||
for locale in ("en", "ru"):
|
||||
commands = []
|
||||
for cmd in enabled:
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
|
||||
en_commands.append({"command": cmd, "description": desc})
|
||||
result = await client.set_my_commands(en_commands)
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
|
||||
commands.append({"command": cmd, "description": desc})
|
||||
result = await client.set_my_commands(commands, language_code=locale)
|
||||
if result.get("success"):
|
||||
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
|
||||
success = True
|
||||
else:
|
||||
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
|
||||
|
||||
return success
|
||||
# Register default (no language_code) with EN descriptions
|
||||
en_commands = []
|
||||
for cmd in enabled:
|
||||
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
|
||||
en_commands.append({"command": cmd, "description": desc})
|
||||
result = await client.set_my_commands(en_commands)
|
||||
if result.get("success"):
|
||||
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
|
||||
success = True
|
||||
|
||||
return success
|
||||
|
||||
@@ -6,70 +6,48 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.providers.immich.asset_utils import get_public_url
|
||||
|
||||
from ...database.models import ServiceProvider, TelegramBot
|
||||
from ...database.models import ServiceProvider
|
||||
from ...services import make_immich_provider
|
||||
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
|
||||
from .common import _format_assets, build_asset_dict
|
||||
from ...services.http_session import get_http_session
|
||||
from ..command_utils import get_trackers_for_provider
|
||||
from ..handler import _render_cmd_template
|
||||
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _cmd_albums(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
|
||||
provider: ServiceProvider, locale: str,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
if not trackers:
|
||||
return {"albums": []}
|
||||
|
||||
albums_data: list[dict] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
album_ids = tracker.collection_ids or []
|
||||
if not album_ids:
|
||||
continue
|
||||
# Deduplicate album IDs while preserving order
|
||||
seen: set[str] = set()
|
||||
album_ids: list[str] = []
|
||||
for tracker in trackers:
|
||||
for aid in tracker.collection_ids or []:
|
||||
if aid not in seen:
|
||||
seen.add(aid)
|
||||
album_ids.append(aid)
|
||||
if not album_ids:
|
||||
return {"albums": []}
|
||||
|
||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||
album_results = await asyncio.gather(
|
||||
*[immich.client.get_album(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
link_results = await asyncio.gather(
|
||||
*[immich.client.get_shared_links(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for album_id, result, links in zip(album_ids, album_results, link_results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
|
||||
albums_data.append({
|
||||
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
|
||||
})
|
||||
elif result:
|
||||
pub_url = ""
|
||||
if not isinstance(links, Exception) and ext_domain:
|
||||
pub_url = get_public_url(ext_domain, links) or ""
|
||||
albums_data.append({
|
||||
"name": result.name, "asset_count": result.asset_count,
|
||||
"id": album_id, "public_url": pub_url,
|
||||
})
|
||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||
http = await get_http_session()
|
||||
immich = make_immich_provider(http, provider)
|
||||
albums_data = await fetch_albums_with_links(immich.client, album_ids, ext_domain)
|
||||
|
||||
return {"albums": albums_data}
|
||||
|
||||
|
||||
async def cmd_favorites(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
all_album_ids: list[str], count: int, locale: str,
|
||||
response_mode: str, client: Any,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /favorites command with concurrent album fetching."""
|
||||
album_ids = all_album_ids[:10]
|
||||
if not album_ids:
|
||||
@@ -104,28 +82,6 @@ async def cmd_summary(
|
||||
if not all_album_ids:
|
||||
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []})
|
||||
|
||||
album_results = await asyncio.gather(
|
||||
*[client.get_album(aid) for aid in all_album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
link_results = await asyncio.gather(
|
||||
*[client.get_shared_links(aid) for aid in all_album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
ext = external_domain.rstrip("/")
|
||||
|
||||
albums_data: list[dict] = []
|
||||
for album_id, result, links in zip(all_album_ids, album_results, link_results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
|
||||
continue
|
||||
if result:
|
||||
pub_url = ""
|
||||
if not isinstance(links, Exception) and ext:
|
||||
pub_url = get_public_url(ext, links) or ""
|
||||
albums_data.append({
|
||||
"name": result.name, "asset_count": result.asset_count,
|
||||
"id": album_id, "public_url": pub_url,
|
||||
})
|
||||
|
||||
albums_data = await fetch_albums_with_links(client, all_album_ids, ext, include_failed=False)
|
||||
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ...services import make_immich_provider
|
||||
from notify_bridge_core.providers.immich.asset_utils import get_public_url
|
||||
|
||||
from ..handler import _render_cmd_template
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -17,6 +19,53 @@ _IMMICH_COMMANDS = {
|
||||
}
|
||||
|
||||
|
||||
async def fetch_albums_with_links(
|
||||
client: Any,
|
||||
album_ids: list[str],
|
||||
ext_domain: str,
|
||||
*,
|
||||
include_failed: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch albums and their shared links concurrently.
|
||||
|
||||
Returns a list of album data dicts with keys: name, asset_count, id,
|
||||
public_url, and ``_album`` (the raw album object for callers that need
|
||||
asset-level access).
|
||||
|
||||
When *include_failed* is True, albums that fail to fetch are included
|
||||
with placeholder data (``"?"`` for counts). When False, they are
|
||||
silently skipped.
|
||||
"""
|
||||
album_results = await asyncio.gather(
|
||||
*[client.get_album(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
link_results = await asyncio.gather(
|
||||
*[client.get_shared_links(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
albums_data: list[dict[str, Any]] = []
|
||||
for album_id, result, links in zip(album_ids, album_results, link_results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
|
||||
if include_failed:
|
||||
albums_data.append({
|
||||
"name": f"{album_id[:8]}...", "asset_count": "?",
|
||||
"id": album_id, "public_url": "", "_album": None,
|
||||
})
|
||||
continue
|
||||
if result:
|
||||
pub_url = ""
|
||||
if not isinstance(links, Exception) and ext_domain:
|
||||
pub_url = get_public_url(ext_domain, links) or ""
|
||||
albums_data.append({
|
||||
"name": result.name, "asset_count": result.asset_count,
|
||||
"id": album_id, "public_url": pub_url, "_album": result,
|
||||
})
|
||||
return albums_data
|
||||
|
||||
|
||||
def build_asset_dict(
|
||||
asset: Any,
|
||||
*,
|
||||
@@ -56,8 +105,14 @@ def _format_assets(
|
||||
assets: list[dict[str, Any]], cmd: str, query: str,
|
||||
locale: str, response_mode: str, client: Any,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Format asset results as text or media payload."""
|
||||
) -> str | dict[str, Any]:
|
||||
"""Format asset results as text or a text-plus-media payload.
|
||||
|
||||
Returns:
|
||||
str: rendered text when *response_mode* is ``"text"`` (or no assets).
|
||||
dict: ``{"text": ..., "media": [...]}`` when *response_mode* is
|
||||
``"media"`` and assets are present.
|
||||
"""
|
||||
if not assets:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
|
||||
|
||||
@@ -68,7 +123,7 @@ def _format_assets(
|
||||
})
|
||||
|
||||
if response_mode == "media":
|
||||
media_items = []
|
||||
media_items: list[dict[str, Any]] = []
|
||||
for asset in assets:
|
||||
asset_id = asset.get("id", "")
|
||||
media_items.append({
|
||||
|
||||
@@ -13,23 +13,22 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ...database.engine import get_engine
|
||||
from ...database.models import (
|
||||
EventLog, NotificationTarget, NotificationTrackerTarget,
|
||||
ServiceProvider, TelegramBot, TrackingConfig,
|
||||
EventLog, NotificationTracker, NotificationTrackerTarget,
|
||||
ServiceProvider, TrackingConfig,
|
||||
)
|
||||
from notify_bridge_core.providers.immich.asset_utils import get_public_url
|
||||
|
||||
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
|
||||
from .common import _format_assets, build_asset_dict
|
||||
from ..command_utils import get_trackers_for_provider
|
||||
from ..handler import _render_cmd_template
|
||||
from .common import _format_assets, build_asset_dict, fetch_albums_with_links
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _cmd_events(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
count: int, locale: str,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if not tracker_ids:
|
||||
return {"events": []}
|
||||
@@ -57,32 +56,21 @@ async def cmd_latest(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
external_domain: str = "",
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /latest command with concurrent album fetching."""
|
||||
album_ids = all_album_ids[:10]
|
||||
if not album_ids:
|
||||
return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
album_results = await asyncio.gather(
|
||||
*[client.get_album(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
link_results = await asyncio.gather(
|
||||
*[client.get_shared_links(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
ext = external_domain.rstrip("/")
|
||||
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
|
||||
|
||||
latest_assets: list[dict[str, Any]] = []
|
||||
for album_id, result, links in zip(album_ids, album_results, link_results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
|
||||
continue
|
||||
if result:
|
||||
pub_url = ""
|
||||
if not isinstance(links, Exception) and ext:
|
||||
pub_url = get_public_url(ext, links) or ""
|
||||
for aid, asset in list(result.assets.items())[:count]:
|
||||
for album_data in fetched:
|
||||
pub_url = album_data.get("public_url", "")
|
||||
album_obj = album_data.get("_album")
|
||||
if album_obj:
|
||||
for aid, asset in list(album_obj.assets.items())[:count]:
|
||||
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
|
||||
latest_assets.append(build_asset_dict(asset, public_url=asset_pub))
|
||||
|
||||
@@ -95,32 +83,21 @@ async def cmd_random(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
external_domain: str = "",
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /random command with concurrent album fetching."""
|
||||
album_ids = all_album_ids[:10]
|
||||
if not album_ids:
|
||||
return _format_assets([], "random", "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
album_results = await asyncio.gather(
|
||||
*[client.get_album(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
link_results = await asyncio.gather(
|
||||
*[client.get_shared_links(aid) for aid in album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
ext = external_domain.rstrip("/")
|
||||
fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False)
|
||||
|
||||
random_assets: list[dict[str, Any]] = []
|
||||
for album_id, result, links in zip(album_ids, album_results, link_results):
|
||||
if isinstance(result, Exception):
|
||||
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
|
||||
continue
|
||||
if result:
|
||||
pub_url = ""
|
||||
if not isinstance(links, Exception) and ext:
|
||||
pub_url = get_public_url(ext, links) or ""
|
||||
asset_list = list(result.assets.values())
|
||||
for album_data in fetched:
|
||||
pub_url = album_data.get("public_url", "")
|
||||
album_obj = album_data.get("_album")
|
||||
if album_obj:
|
||||
asset_list = list(album_obj.assets.values())
|
||||
sampled = rng.sample(asset_list, min(count, len(asset_list)))
|
||||
for asset in sampled:
|
||||
asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else ""
|
||||
@@ -130,40 +107,40 @@ async def cmd_random(
|
||||
return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates)
|
||||
|
||||
|
||||
async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
"""Check if any tracker-target linked to this bot uses native memory source."""
|
||||
async def _check_native_memory(provider_id: int) -> bool:
|
||||
"""Check if any notification tracker for this provider uses native memory source."""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(NotificationTarget).where(
|
||||
NotificationTarget.type == "telegram",
|
||||
NotificationTarget.user_id == bot.user_id,
|
||||
tracker_result = await session.exec(
|
||||
select(NotificationTracker).where(
|
||||
NotificationTracker.provider_id == provider_id,
|
||||
)
|
||||
)
|
||||
targets = result.all()
|
||||
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
|
||||
if not bot_target_ids:
|
||||
trackers = tracker_result.all()
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if not tracker_ids:
|
||||
return False
|
||||
tt_result = await session.exec(
|
||||
select(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.target_id.in_(bot_target_ids)
|
||||
NotificationTrackerTarget.tracker_id.in_(tracker_ids)
|
||||
)
|
||||
)
|
||||
for tt in tt_result.all():
|
||||
if tt.tracking_config_id:
|
||||
tc = await session.get(TrackingConfig, tt.tracking_config_id)
|
||||
if tc and tc.memory_source == "native":
|
||||
return True
|
||||
return False
|
||||
tc_ids = list({tt.tracking_config_id for tt in tt_result.all() if tt.tracking_config_id})
|
||||
if not tc_ids:
|
||||
return False
|
||||
tc_result = await session.exec(
|
||||
select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids))
|
||||
)
|
||||
return any(tc.memory_source == "native" for tc in tc_result.all())
|
||||
|
||||
|
||||
async def cmd_memory(
|
||||
bot: TelegramBot, client: Any, all_album_ids: list[str], count: int,
|
||||
provider_id: int, client: Any, all_album_ids: list[str], count: int,
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /memory command with concurrent album fetching."""
|
||||
use_native = await _check_native_memory(bot)
|
||||
use_native = await _check_native_memory(provider_id)
|
||||
today = datetime.now(timezone.utc)
|
||||
memory_assets: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
@@ -2,26 +2,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ...database.engine import get_engine
|
||||
from ...database.models import (
|
||||
CommandConfig, CommandTracker, EventLog,
|
||||
CommandConfig, CommandTracker,
|
||||
ServiceProvider, TelegramBot,
|
||||
)
|
||||
from ...services import make_immich_provider
|
||||
from ..base import ProviderCommandHandler
|
||||
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
|
||||
from notify_bridge_core.providers.immich.asset_utils import get_public_url
|
||||
from ...services.http_session import get_http_session
|
||||
from ..base import CommandResponse, ProviderCommandHandler
|
||||
from ..command_utils import get_last_event_str, get_trackers_for_provider
|
||||
from ..handler import _render_cmd_template
|
||||
|
||||
from .albums import _cmd_albums, cmd_favorites, cmd_summary
|
||||
from .common import _IMMICH_COMMANDS
|
||||
from .common import _IMMICH_COMMANDS, fetch_albums_with_links
|
||||
from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random
|
||||
from .search import cmd_find, cmd_person, cmd_place, cmd_search
|
||||
|
||||
@@ -29,21 +24,15 @@ _LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _cmd_status(
|
||||
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
|
||||
provider: ServiceProvider, locale: str,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
active = sum(1 for t in trackers if t.enabled)
|
||||
total = len(trackers)
|
||||
total_albums = sum(len(t.collection_ids or []) for t in trackers)
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
last_str = await get_last_event_str(tracker_ids)
|
||||
|
||||
return {
|
||||
"trackers_active": active, "trackers_total": total,
|
||||
@@ -52,16 +41,13 @@ async def _cmd_status(
|
||||
|
||||
|
||||
async def _cmd_people(
|
||||
providers_map: dict[int, ServiceProvider], locale: str,
|
||||
provider: ServiceProvider, locale: str,
|
||||
) -> dict[str, Any]:
|
||||
all_people: dict[str, str] = {}
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider in providers_map.values():
|
||||
if provider.type != "immich":
|
||||
continue
|
||||
immich = make_immich_provider(http, provider)
|
||||
people = await immich.client.get_people()
|
||||
all_people.update(people)
|
||||
http = await get_http_session()
|
||||
immich = make_immich_provider(http, provider)
|
||||
people = await immich.client.get_people()
|
||||
all_people.update(people)
|
||||
names = sorted(all_people.values())
|
||||
return {"people": names}
|
||||
|
||||
@@ -87,106 +73,92 @@ class ImmichCommandHandler(ProviderCommandHandler):
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
tracker: CommandTracker,
|
||||
config: CommandConfig,
|
||||
) -> CommandResponse | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(bot, providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
ctx = await _cmd_status(provider, locale)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, "status", locale, ctx))
|
||||
if cmd == "albums":
|
||||
ctx = await _cmd_albums(bot, providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
|
||||
ctx = await _cmd_albums(provider, locale)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, "albums", locale, ctx))
|
||||
if cmd == "events":
|
||||
ctx = await _cmd_events(bot, providers_map, count, locale)
|
||||
return _render_cmd_template(cmd_templates, "events", locale, ctx)
|
||||
ctx = await _cmd_events(provider, count, locale)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, "events", locale, ctx))
|
||||
if cmd == "people":
|
||||
ctx = await _cmd_people(providers_map, locale)
|
||||
return _render_cmd_template(cmd_templates, "people", locale, ctx)
|
||||
ctx = await _cmd_people(provider, locale)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, "people", locale, ctx))
|
||||
if cmd in ("search", "find", "person", "place", "latest",
|
||||
"random", "favorites", "summary", "memory"):
|
||||
return await _cmd_immich(
|
||||
bot, cmd, args, count, locale, response_mode,
|
||||
providers_map, cmd_templates,
|
||||
cmd, args, count, locale, response_mode,
|
||||
provider, cmd_templates,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def _cmd_immich(
|
||||
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
|
||||
response_mode: str, providers_map: dict[int, ServiceProvider],
|
||||
cmd: str, args: str, count: int, locale: str,
|
||||
response_mode: str, provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> CommandResponse | None:
|
||||
"""Handle commands that need Immich API access and may return media."""
|
||||
if not providers_map:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
provider_ids = set(providers_map.keys())
|
||||
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
notification_trackers = await get_trackers_for_provider(provider.id)
|
||||
|
||||
all_album_ids: list[str] = []
|
||||
for t in notification_trackers:
|
||||
all_album_ids.extend(t.collection_ids or [])
|
||||
|
||||
provider: ServiceProvider | None = None
|
||||
for p in providers_map.values():
|
||||
if p.type == "immich":
|
||||
provider = p
|
||||
break
|
||||
if not provider:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
|
||||
|
||||
ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/")
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
immich = make_immich_provider(http, provider)
|
||||
client = immich.client
|
||||
http = await get_http_session()
|
||||
immich = make_immich_provider(http, provider)
|
||||
client = immich.client
|
||||
|
||||
# Build asset_id → public_url map from tracked albums' shared links
|
||||
asset_public_urls: dict[str, str] = {}
|
||||
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
|
||||
link_results = await asyncio.gather(
|
||||
*[client.get_shared_links(aid) for aid in all_album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
album_results = await asyncio.gather(
|
||||
*[client.get_album(aid) for aid in all_album_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
for album_id, links, album in zip(all_album_ids, link_results, album_results):
|
||||
if isinstance(links, Exception) or isinstance(album, Exception):
|
||||
continue
|
||||
pub_url = get_public_url(ext_domain, links)
|
||||
if pub_url and album:
|
||||
for asset_id in album.assets:
|
||||
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
|
||||
# Build asset_id → public_url map from tracked albums' shared links
|
||||
asset_public_urls: dict[str, str] = {}
|
||||
if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"):
|
||||
fetched = await fetch_albums_with_links(client, all_album_ids, ext_domain, include_failed=False)
|
||||
for album_data in fetched:
|
||||
pub_url = album_data.get("public_url", "")
|
||||
album_obj = album_data.get("_album")
|
||||
if pub_url and album_obj:
|
||||
for asset_id in album_obj.assets:
|
||||
asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}"
|
||||
|
||||
if cmd == "search":
|
||||
return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
# Wrap single-provider in a map for functions that still expect it
|
||||
providers_map = {provider.id: provider}
|
||||
|
||||
if cmd == "find":
|
||||
return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
result: str | dict[str, Any] | None = None
|
||||
|
||||
if cmd == "person":
|
||||
return await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
if cmd == "search":
|
||||
result = await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
elif cmd == "find":
|
||||
result = await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
elif cmd == "person":
|
||||
result = await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
elif cmd == "place":
|
||||
result = await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
elif cmd == "favorites":
|
||||
result = await cmd_favorites(providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
|
||||
elif cmd == "latest":
|
||||
result = await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
|
||||
elif cmd == "random":
|
||||
result = await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
|
||||
elif cmd == "summary":
|
||||
result = await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
|
||||
elif cmd == "memory":
|
||||
result = await cmd_memory(provider.id, client, all_album_ids, count, locale, response_mode, cmd_templates)
|
||||
|
||||
if cmd == "place":
|
||||
return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls)
|
||||
|
||||
if cmd == "favorites":
|
||||
return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
|
||||
|
||||
if cmd == "latest":
|
||||
return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
|
||||
|
||||
if cmd == "random":
|
||||
return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain)
|
||||
|
||||
if cmd == "summary":
|
||||
return await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain)
|
||||
|
||||
if cmd == "memory":
|
||||
return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates)
|
||||
|
||||
return None
|
||||
if result is None:
|
||||
return None
|
||||
# _format_assets returns {"text": ..., "media": [...]} for media mode
|
||||
if isinstance(result, dict):
|
||||
return CommandResponse(
|
||||
text=result.get("text"),
|
||||
media=result.get("media", []),
|
||||
)
|
||||
return CommandResponse(text=result)
|
||||
|
||||
@@ -9,14 +9,15 @@ from .common import _format_assets
|
||||
|
||||
|
||||
def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]:
|
||||
"""Add public_url to assets from the pre-built map."""
|
||||
"""Add public_url to assets from the pre-built map. Returns new list without mutating inputs."""
|
||||
if not asset_public_urls:
|
||||
return assets
|
||||
for asset in assets:
|
||||
aid = asset.get("id", "")
|
||||
if aid and aid in asset_public_urls and not asset.get("public_url"):
|
||||
asset["public_url"] = asset_public_urls[aid]
|
||||
return assets
|
||||
return [
|
||||
{**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")}
|
||||
if asset.get("id", "") in asset_public_urls and not asset.get("public_url")
|
||||
else asset
|
||||
for asset in assets
|
||||
]
|
||||
|
||||
|
||||
async def cmd_search(
|
||||
@@ -24,7 +25,7 @@ async def cmd_search(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
asset_public_urls: dict[str, str] | None = None,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /search command."""
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""})
|
||||
@@ -38,7 +39,7 @@ async def cmd_find(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
asset_public_urls: dict[str, str] | None = None,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /find command."""
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""})
|
||||
@@ -52,7 +53,7 @@ async def cmd_person(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
asset_public_urls: dict[str, str] | None = None,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /person command."""
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
|
||||
@@ -74,7 +75,7 @@ async def cmd_place(
|
||||
locale: str, response_mode: str,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
asset_public_urls: dict[str, str] | None = None,
|
||||
) -> str | list[dict[str, Any]]:
|
||||
) -> str | dict[str, Any]:
|
||||
"""Handle /place command."""
|
||||
if not args:
|
||||
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
|
||||
|
||||
@@ -3,17 +3,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot
|
||||
from ..services import make_nut_provider
|
||||
from .base import ProviderCommandHandler
|
||||
from .base import CommandResponse, ProviderCommandHandler
|
||||
from .handler import _render_cmd_template
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_NUT_COMMANDS = {"status", "devices", "battery"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command dispatch table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
|
||||
|
||||
|
||||
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
|
||||
"""Register a function in the text command dispatch table."""
|
||||
name = fn.__name__.removeprefix("_cmd_")
|
||||
_TEXT_COMMANDS[name] = fn
|
||||
return fn
|
||||
|
||||
|
||||
class NutCommandHandler(ProviderCommandHandler):
|
||||
"""Handles NUT-specific bot commands."""
|
||||
@@ -33,80 +47,73 @@ class NutCommandHandler(ProviderCommandHandler):
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
if cmd == "devices":
|
||||
ctx = await _cmd_devices(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "devices", locale, ctx)
|
||||
if cmd == "battery":
|
||||
ctx = await _cmd_battery(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "battery", locale, ctx)
|
||||
return None
|
||||
tracker: CommandTracker,
|
||||
config: CommandConfig,
|
||||
) -> CommandResponse | None:
|
||||
fn = _TEXT_COMMANDS.get(cmd)
|
||||
if fn is None:
|
||||
return None
|
||||
ctx = await fn(provider, count)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
|
||||
|
||||
|
||||
async def _query_all_ups(
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
async def _query_ups(
|
||||
provider: ServiceProvider,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Connect to all NUT providers and query UPS data."""
|
||||
"""Connect to a NUT provider and query UPS data."""
|
||||
from notify_bridge_core.providers.nut.models import NutUpsData
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
for provider in providers_map.values():
|
||||
if provider.type != "nut":
|
||||
continue
|
||||
nut = make_nut_provider(provider)
|
||||
nut = make_nut_provider(provider)
|
||||
try:
|
||||
client = nut._make_client()
|
||||
await client.connect()
|
||||
try:
|
||||
client = nut._make_client()
|
||||
await client.connect()
|
||||
try:
|
||||
devices = await client.list_ups()
|
||||
for dev in devices:
|
||||
variables = await client.list_var(dev.name)
|
||||
data = NutUpsData.from_variables(dev.name, variables)
|
||||
results.append({
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"model": data.model,
|
||||
"manufacturer": data.manufacturer,
|
||||
"status": data.status,
|
||||
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
|
||||
"battery_runtime": data.battery_runtime_formatted,
|
||||
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
|
||||
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
|
||||
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
|
||||
})
|
||||
finally:
|
||||
await client.disconnect()
|
||||
except Exception as exc:
|
||||
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
|
||||
devices = await client.list_ups()
|
||||
for dev in devices:
|
||||
variables = await client.list_var(dev.name)
|
||||
data = NutUpsData.from_variables(dev.name, variables)
|
||||
results.append({
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"model": data.model,
|
||||
"manufacturer": data.manufacturer,
|
||||
"status": data.status,
|
||||
"battery_charge": int(data.battery_charge) if data.battery_charge is not None else None,
|
||||
"battery_runtime": data.battery_runtime_formatted,
|
||||
"ups_load": int(data.ups_load) if data.ups_load is not None else None,
|
||||
"input_voltage": str(data.input_voltage) if data.input_voltage is not None else None,
|
||||
"output_voltage": str(data.output_voltage) if data.output_voltage is not None else None,
|
||||
})
|
||||
finally:
|
||||
await client.disconnect()
|
||||
except Exception as exc:
|
||||
_LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc)
|
||||
return results
|
||||
|
||||
|
||||
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
devices = await _query_all_ups(providers_map)
|
||||
@_text_cmd
|
||||
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
devices = await _query_ups(provider)
|
||||
return {"devices": devices}
|
||||
|
||||
|
||||
async def _cmd_devices(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
@_text_cmd
|
||||
async def _cmd_devices(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
devices: list[dict[str, Any]] = []
|
||||
for provider in providers_map.values():
|
||||
if provider.type != "nut":
|
||||
continue
|
||||
nut = make_nut_provider(provider)
|
||||
try:
|
||||
device_list = await nut.list_collections()
|
||||
devices.extend(device_list)
|
||||
except Exception as exc:
|
||||
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
|
||||
nut = make_nut_provider(provider)
|
||||
try:
|
||||
device_list = await nut.list_collections()
|
||||
devices.extend(device_list)
|
||||
except Exception as exc:
|
||||
_LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc)
|
||||
return {"devices": devices}
|
||||
|
||||
|
||||
async def _cmd_battery(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
devices = await _query_all_ups(providers_map)
|
||||
@_text_cmd
|
||||
async def _cmd_battery(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
devices = await _query_ups(provider)
|
||||
return {"devices": devices}
|
||||
|
||||
@@ -3,26 +3,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
CommandConfig, CommandTracker, EventLog,
|
||||
NotificationTracker, ServiceProvider, TelegramBot,
|
||||
CommandConfig, CommandTracker, ServiceProvider, TelegramBot,
|
||||
)
|
||||
from ..services import make_planka_provider
|
||||
from .base import ProviderCommandHandler
|
||||
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
|
||||
from ..services.http_session import get_http_session
|
||||
from .base import CommandResponse, ProviderCommandHandler
|
||||
from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider
|
||||
from .handler import _render_cmd_template
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_PLANKA_COMMANDS = {"status", "boards", "cards", "lists"}
|
||||
|
||||
|
||||
def _get_tracked_board_ids(
|
||||
provider: ServiceProvider,
|
||||
trackers: list,
|
||||
) -> list[str]:
|
||||
"""Get board IDs from tracked collection_ids for this provider."""
|
||||
if not provider.config.get("api_key"):
|
||||
return []
|
||||
return get_tracked_collection_ids(provider, trackers)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command dispatch table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {}
|
||||
|
||||
|
||||
def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]:
|
||||
"""Register a function in the text command dispatch table."""
|
||||
name = fn.__name__.removeprefix("_cmd_")
|
||||
_TEXT_COMMANDS[name] = fn
|
||||
return fn
|
||||
|
||||
|
||||
class PlankaCommandHandler(ProviderCommandHandler):
|
||||
"""Handles Planka-specific bot commands."""
|
||||
|
||||
@@ -43,69 +64,26 @@ class PlankaCommandHandler(ProviderCommandHandler):
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str,
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot,
|
||||
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> str | list[dict[str, Any]] | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "status", locale, ctx)
|
||||
if cmd == "boards":
|
||||
ctx = await _cmd_boards(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "boards", locale, ctx)
|
||||
if cmd == "cards":
|
||||
ctx = await _cmd_cards(providers_map, count)
|
||||
return _render_cmd_template(cmd_templates, "cards", locale, ctx)
|
||||
if cmd == "lists":
|
||||
ctx = await _cmd_lists(providers_map)
|
||||
return _render_cmd_template(cmd_templates, "lists", locale, ctx)
|
||||
return None
|
||||
tracker: CommandTracker,
|
||||
config: CommandConfig,
|
||||
) -> CommandResponse | None:
|
||||
fn = _TEXT_COMMANDS.get(cmd)
|
||||
if fn is None:
|
||||
return None
|
||||
ctx = await fn(provider, count)
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
|
||||
|
||||
|
||||
def _get_tracked_board_ids(
|
||||
providers_map: dict[int, ServiceProvider],
|
||||
trackers: list[NotificationTracker],
|
||||
) -> list[tuple[ServiceProvider, str]]:
|
||||
"""Get (provider, board_id) tuples from tracked collection_ids."""
|
||||
boards: list[tuple[ServiceProvider, str]] = []
|
||||
for tracker in trackers:
|
||||
provider = providers_map.get(tracker.provider_id)
|
||||
if not provider or provider.type != "planka":
|
||||
continue
|
||||
if not provider.config.get("api_key"):
|
||||
continue
|
||||
for board_id in (tracker.collection_ids or []):
|
||||
entry = (provider, board_id)
|
||||
if entry not in boards:
|
||||
boards.append(entry)
|
||||
# Also check filters.collections
|
||||
for board_id in (tracker.filters or {}).get("collections", []):
|
||||
entry = (provider, board_id)
|
||||
if entry not in boards:
|
||||
boards.append(entry)
|
||||
return boards[:20]
|
||||
@_text_cmd
|
||||
async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_boards = _get_tracked_board_ids(provider, trackers)
|
||||
|
||||
|
||||
async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
|
||||
|
||||
# Last event
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if tracker_ids:
|
||||
result = await session.exec(
|
||||
select(EventLog)
|
||||
.where(EventLog.tracker_id.in_(tracker_ids))
|
||||
.order_by(EventLog.created_at.desc()).limit(1)
|
||||
)
|
||||
last_event = result.first()
|
||||
else:
|
||||
last_event = None
|
||||
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
last_str = await get_last_event_str(tracker_ids)
|
||||
|
||||
return {
|
||||
"boards_count": len(tracked_boards),
|
||||
@@ -113,81 +91,69 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_boards(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_boards(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_boards = _get_tracked_board_ids(provider, trackers)
|
||||
|
||||
boards_data: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, board_id in tracked_boards:
|
||||
planka = make_planka_provider(http, provider)
|
||||
all_boards = await planka.client.get_boards()
|
||||
for b in all_boards:
|
||||
if str(b.get("id", "")) == board_id:
|
||||
boards_data.append({"name": b.get("name", board_id)})
|
||||
break
|
||||
else:
|
||||
boards_data.append({"name": board_id})
|
||||
http = await get_http_session()
|
||||
planka = make_planka_provider(http, provider)
|
||||
all_boards = await planka.client.get_boards()
|
||||
board_names = {str(b.get("id", "")): b.get("name", "") for b in all_boards}
|
||||
for board_id in tracked_boards:
|
||||
boards_data.append({"name": board_names.get(board_id, board_id)})
|
||||
|
||||
return {"boards": boards_data}
|
||||
|
||||
|
||||
async def _cmd_cards(
|
||||
providers_map: dict[int, ServiceProvider], count: int,
|
||||
) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_cards(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_boards = _get_tracked_board_ids(provider, trackers)
|
||||
|
||||
all_cards: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, board_id in tracked_boards:
|
||||
planka = make_planka_provider(http, provider)
|
||||
cards = await planka.client.get_board_cards(board_id, limit=count)
|
||||
lists = await planka.client.get_board_lists(board_id)
|
||||
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
|
||||
http = await get_http_session()
|
||||
planka = make_planka_provider(http, provider)
|
||||
boards = await planka.client.get_boards()
|
||||
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
|
||||
|
||||
boards = await planka.client.get_boards()
|
||||
board_name = board_id
|
||||
for b in boards:
|
||||
if str(b.get("id", "")) == board_id:
|
||||
board_name = b.get("name", board_id)
|
||||
break
|
||||
for board_id in tracked_boards:
|
||||
cards = await planka.client.get_board_cards(board_id, limit=count)
|
||||
lists = await planka.client.get_board_lists(board_id)
|
||||
lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists}
|
||||
board_name = board_names.get(board_id, board_id)
|
||||
|
||||
for card in cards:
|
||||
list_id = str(card.get("listId", ""))
|
||||
all_cards.append({
|
||||
"name": card.get("name", ""),
|
||||
"list_name": lists_by_id.get(list_id, ""),
|
||||
"board_name": board_name,
|
||||
})
|
||||
for card in cards:
|
||||
list_id = str(card.get("listId", ""))
|
||||
all_cards.append({
|
||||
"name": card.get("name", ""),
|
||||
"list_name": lists_by_id.get(list_id, ""),
|
||||
"board_name": board_name,
|
||||
})
|
||||
|
||||
return {"cards": all_cards[:count]}
|
||||
|
||||
|
||||
async def _cmd_lists(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracked_boards = _get_tracked_board_ids(providers_map, trackers)
|
||||
@_text_cmd
|
||||
async def _cmd_lists(provider: ServiceProvider, count: int) -> dict[str, Any]:
|
||||
trackers = await get_trackers_for_provider(provider.id)
|
||||
tracked_boards = _get_tracked_board_ids(provider, trackers)
|
||||
|
||||
all_lists: list[dict[str, Any]] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
for provider, board_id in tracked_boards:
|
||||
planka = make_planka_provider(http, provider)
|
||||
lists = await planka.client.get_board_lists(board_id)
|
||||
http = await get_http_session()
|
||||
planka = make_planka_provider(http, provider)
|
||||
boards = await planka.client.get_boards()
|
||||
board_names = {str(b.get("id", "")): b.get("name", "") for b in boards}
|
||||
|
||||
boards = await planka.client.get_boards()
|
||||
board_name = board_id
|
||||
for b in boards:
|
||||
if str(b.get("id", "")) == board_id:
|
||||
board_name = b.get("name", board_id)
|
||||
break
|
||||
for board_id in tracked_boards:
|
||||
lists = await planka.client.get_board_lists(board_id)
|
||||
board_name = board_names.get(board_id, board_id)
|
||||
|
||||
for lst in lists:
|
||||
all_lists.append({
|
||||
"name": lst.get("name", ""),
|
||||
"board_name": board_name,
|
||||
})
|
||||
for lst in lists:
|
||||
all_lists.append({
|
||||
"name": lst.get("name", ""),
|
||||
"board_name": board_name,
|
||||
})
|
||||
|
||||
return {"lists": all_lists}
|
||||
|
||||
@@ -6,7 +6,6 @@ import hmac
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
@@ -16,6 +15,7 @@ from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import TelegramBot, TelegramChat
|
||||
from ..services.telegram import save_chat_from_webhook
|
||||
from .base import CommandResponse
|
||||
from .handler import handle_command, send_media_group, send_reply
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -89,15 +89,13 @@ async def telegram_webhook(
|
||||
return {"ok": True, "skipped": "commands_disabled"}
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
cmd_response = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if cmd_response is not None:
|
||||
if isinstance(cmd_response, dict) and "media" in cmd_response:
|
||||
await send_reply(bot.token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
|
||||
await send_media_group(bot.token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
|
||||
elif isinstance(cmd_response, list):
|
||||
await send_media_group(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
else:
|
||||
await send_reply(bot.token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
responses = await handle_command(bot, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot.token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot.token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
return {"ok": True}
|
||||
|
||||
return {"ok": True, "skipped": "not_a_command"}
|
||||
@@ -105,13 +103,15 @@ async def telegram_webhook(
|
||||
|
||||
async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict:
|
||||
"""Register webhook URL with Telegram Bot API via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot_token)
|
||||
return await client.set_webhook(webhook_url, secret=secret)
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
return await client.set_webhook(webhook_url, secret=secret)
|
||||
|
||||
|
||||
async def unregister_webhook(bot_token: str) -> dict:
|
||||
"""Remove webhook from Telegram Bot API via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot_token)
|
||||
return await client.delete_webhook()
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
return await client.delete_webhook()
|
||||
|
||||
@@ -359,6 +359,7 @@ class NotificationTrackerState(SQLModel, table=True):
|
||||
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
|
||||
tracker_id: int = Field(
|
||||
foreign_key="notification_tracker.id",
|
||||
index=True,
|
||||
sa_column_kwargs={"name": "notification_tracker_id"},
|
||||
)
|
||||
collection_id: str
|
||||
@@ -458,7 +459,7 @@ class CommandTrackerListener(SQLModel, table=True):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
command_tracker_id: int = Field(
|
||||
foreign_key="command_tracker.id",
|
||||
|
||||
index=True,
|
||||
|
||||
)
|
||||
listener_type: str # e.g. "telegram_bot"
|
||||
|
||||
@@ -73,6 +73,8 @@ async def lifespan(app: FastAPI):
|
||||
await start_scheduler()
|
||||
yield
|
||||
# Graceful shutdown
|
||||
from .services.http_session import close_http_session
|
||||
await close_http_session()
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.running:
|
||||
scheduler.shutdown()
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
"""Shared service utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
from notify_bridge_core.providers.gitea import GiteaServiceProvider
|
||||
from notify_bridge_core.providers.planka import PlankaServiceProvider
|
||||
@@ -8,8 +14,23 @@ from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvid
|
||||
|
||||
from ..database.models import ServiceProvider
|
||||
|
||||
# Default timeout for all outgoing HTTP requests to external services.
|
||||
DEFAULT_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServiceProvider:
|
||||
|
||||
class CollectionProvider(Protocol):
|
||||
"""Protocol for providers that can list collections."""
|
||||
|
||||
async def list_collections(self) -> list[dict[str, Any]]: ...
|
||||
|
||||
|
||||
class TestableProvider(Protocol):
|
||||
"""Protocol for providers that support connection testing."""
|
||||
|
||||
async def test_connection(self) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
def make_immich_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> ImmichServiceProvider:
|
||||
"""Create an ImmichServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return ImmichServiceProvider(
|
||||
@@ -21,7 +42,7 @@ def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServi
|
||||
)
|
||||
|
||||
|
||||
def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaServiceProvider:
|
||||
def make_gitea_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GiteaServiceProvider:
|
||||
"""Create a GiteaServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return GiteaServiceProvider(
|
||||
@@ -32,7 +53,7 @@ def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaService
|
||||
)
|
||||
|
||||
|
||||
def make_planka_provider(http_session, provider: ServiceProvider) -> PlankaServiceProvider:
|
||||
def make_planka_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> PlankaServiceProvider:
|
||||
"""Create a PlankaServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return PlankaServiceProvider(
|
||||
@@ -55,7 +76,7 @@ def make_nut_provider(provider: ServiceProvider) -> NutServiceProvider:
|
||||
)
|
||||
|
||||
|
||||
def make_google_photos_provider(http_session, provider: ServiceProvider) -> GooglePhotosServiceProvider:
|
||||
def make_google_photos_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GooglePhotosServiceProvider:
|
||||
"""Create a GooglePhotosServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return GooglePhotosServiceProvider(
|
||||
@@ -65,3 +86,61 @@ def make_google_photos_provider(http_session, provider: ServiceProvider) -> Goog
|
||||
config.get("refresh_token", ""),
|
||||
provider.name,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider factory registry — maps provider type strings to factory callables
|
||||
# that create a provider with a ``list_collections`` method. Providers that
|
||||
# require an API credential skip creation when the credential is missing
|
||||
# (the factory returns None in that case).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_collection_provider(
|
||||
http_session: aiohttp.ClientSession,
|
||||
provider: ServiceProvider,
|
||||
) -> CollectionProvider | None:
|
||||
"""Create a CollectionProvider for the given DB provider, or None if unsupported."""
|
||||
ptype = provider.type
|
||||
config = provider.config or {}
|
||||
|
||||
if ptype == "immich":
|
||||
return make_immich_provider(http_session, provider)
|
||||
if ptype == "gitea":
|
||||
if not config.get("api_token"):
|
||||
return None
|
||||
return make_gitea_provider(http_session, provider)
|
||||
if ptype == "planka":
|
||||
if not config.get("api_key"):
|
||||
return None
|
||||
return make_planka_provider(http_session, provider)
|
||||
if ptype == "google_photos":
|
||||
return make_google_photos_provider(http_session, provider)
|
||||
# NUT provider needs no http_session
|
||||
if ptype == "nut":
|
||||
return make_nut_provider(provider) # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
|
||||
# Set of provider types that need an aiohttp session for collection listing.
|
||||
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
|
||||
|
||||
|
||||
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
|
||||
"""List collections for any supported provider type.
|
||||
|
||||
Returns an empty list for providers that don't support collections or
|
||||
are missing required credentials.
|
||||
"""
|
||||
if provider.type in _HTTP_COLLECTION_PROVIDERS:
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
svc = _make_collection_provider(http_session, provider)
|
||||
if svc is None:
|
||||
return []
|
||||
return await svc.list_collections()
|
||||
|
||||
# Non-HTTP providers (e.g. NUT)
|
||||
svc = _make_collection_provider(None, provider) # type: ignore[arg-type]
|
||||
if svc is None:
|
||||
return []
|
||||
return await svc.list_collections()
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -159,27 +158,28 @@ async def _execute_with_provider(
|
||||
)
|
||||
from notify_bridge_core.providers.immich.client import ImmichClient
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
)
|
||||
external_domain = provider_config.get("external_domain")
|
||||
if external_domain:
|
||||
client.external_domain = external_domain
|
||||
|
||||
# Verify connectivity
|
||||
if not await client.ping():
|
||||
return ActionResult(
|
||||
success=False,
|
||||
error=f"Cannot connect to Immich server ({provider_name})",
|
||||
)
|
||||
external_domain = provider_config.get("external_domain")
|
||||
if external_domain:
|
||||
client.external_domain = external_domain
|
||||
|
||||
# Verify connectivity
|
||||
if not await client.ping():
|
||||
return ActionResult(
|
||||
success=False,
|
||||
error=f"Cannot connect to Immich server ({provider_name})",
|
||||
)
|
||||
|
||||
executor = ImmichActionExecutor(client)
|
||||
if dry_run:
|
||||
return await executor.dry_run(action_type, rule_configs, action_config)
|
||||
return await executor.execute(action_type, rule_configs, action_config)
|
||||
executor = ImmichActionExecutor(client)
|
||||
if dry_run:
|
||||
return await executor.dry_run(action_type, rule_configs, action_config)
|
||||
return await executor.execute(action_type, rule_configs, action_config)
|
||||
|
||||
return ActionResult(
|
||||
success=False,
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Application-level shared aiohttp.ClientSession.
|
||||
|
||||
All outgoing HTTP requests in the server package should use the shared
|
||||
session returned by ``get_http_session()`` instead of creating
|
||||
per-request ``aiohttp.ClientSession`` instances. This keeps a single
|
||||
TCP connection pool alive for the lifetime of the process, avoiding
|
||||
the overhead of pool creation/teardown on every request.
|
||||
|
||||
Call ``close_http_session()`` once during application shutdown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
_session: aiohttp.ClientSession | None = None
|
||||
|
||||
|
||||
async def get_http_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared HTTP session."""
|
||||
global _session
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_http_session() -> None:
|
||||
"""Close the shared HTTP session (call on app shutdown)."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
@@ -3,8 +3,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -90,19 +88,21 @@ async def _send_telegram_broadcast(target: NotificationTarget, message: str, rec
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = TelegramClient(session, bot_token)
|
||||
for recv in receivers:
|
||||
chat_id = recv.get("chat_id")
|
||||
if not chat_id:
|
||||
continue
|
||||
result = await client.send_message(
|
||||
chat_id=str(chat_id),
|
||||
text=message,
|
||||
disable_web_page_preview=bool(disable_preview),
|
||||
)
|
||||
results.append(result)
|
||||
client = TelegramClient(http, bot_token)
|
||||
for recv in receivers:
|
||||
chat_id = recv.get("chat_id")
|
||||
if not chat_id:
|
||||
continue
|
||||
result = await client.send_message(
|
||||
chat_id=str(chat_id),
|
||||
text=message,
|
||||
disable_web_page_preview=bool(disable_preview),
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -113,15 +113,17 @@ async def _send_webhook_broadcast(target: NotificationTarget, message: str, rece
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for recv in receivers:
|
||||
url = recv.get("url")
|
||||
headers = recv.get("headers", {})
|
||||
if not url:
|
||||
continue
|
||||
client = WebhookClient(session, url, headers)
|
||||
results.append(await client.send({"message": message, "event_type": "notification"}))
|
||||
for recv in receivers:
|
||||
url = recv.get("url")
|
||||
headers = recv.get("headers", {})
|
||||
if not url:
|
||||
continue
|
||||
client = WebhookClient(http, url, headers)
|
||||
results.append(await client.send({"message": message, "event_type": "notification"}))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -178,22 +180,24 @@ async def _send_webhook_like_broadcast(target: NotificationTarget, message: str,
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if target.type == "discord":
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
client = DiscordClient(session)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
elif target.type == "slack":
|
||||
from notify_bridge_core.notifications.slack.client import SlackClient
|
||||
client = SlackClient(session)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
if target.type == "discord":
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
client = DiscordClient(http)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
elif target.type == "slack":
|
||||
from notify_bridge_core.notifications.slack.client import SlackClient
|
||||
client = SlackClient(http)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -207,18 +211,20 @@ async def _send_ntfy_broadcast(target: NotificationTarget, message: str, receive
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from notify_bridge_core.notifications.ntfy.client import NtfyClient
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = NtfyClient(session)
|
||||
for recv in receivers:
|
||||
topic = recv.get("topic")
|
||||
if topic:
|
||||
results.append(await client.send(
|
||||
server_url, topic, message,
|
||||
title="Notify Bridge",
|
||||
priority=recv.get("priority", 3),
|
||||
auth_token=auth_token,
|
||||
))
|
||||
client = NtfyClient(http)
|
||||
for recv in receivers:
|
||||
topic = recv.get("topic")
|
||||
if topic:
|
||||
results.append(await client.send(
|
||||
server_url, topic, message,
|
||||
title="Notify Bridge",
|
||||
priority=recv.get("priority", 3),
|
||||
auth_token=auth_token,
|
||||
))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -243,13 +249,15 @@ async def _send_matrix_broadcast(target: NotificationTarget, message: str, recei
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = MatrixClient(http, homeserver, access_token)
|
||||
for recv in receivers:
|
||||
room_id = recv.get("room_id")
|
||||
if room_id:
|
||||
results.append(await client.send_message(room_id, message, html_message=message))
|
||||
client = MatrixClient(http, homeserver, access_token)
|
||||
for recv in receivers:
|
||||
room_id = recv.get("room_id")
|
||||
if room_id:
|
||||
results.append(await client.send_message(room_id, message, html_message=message))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
|
||||
@@ -31,11 +31,50 @@ async def start_scheduler() -> None:
|
||||
from .telegram_poller import start_command_listener_polling
|
||||
await start_command_listener_polling()
|
||||
|
||||
# Schedule daily cleanup of old event log entries
|
||||
_schedule_event_cleanup()
|
||||
|
||||
# Start debounced command auto-sync scheduler
|
||||
from .command_sync import start_sync_scheduler
|
||||
start_sync_scheduler()
|
||||
|
||||
|
||||
def _schedule_event_cleanup() -> None:
|
||||
"""Schedule a daily job to delete EventLog entries older than 90 days."""
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
scheduler = get_scheduler()
|
||||
job_id = "cleanup_old_events"
|
||||
if scheduler.get_job(job_id):
|
||||
return
|
||||
scheduler.add_job(
|
||||
_cleanup_old_events,
|
||||
CronTrigger(hour=3, minute=0),
|
||||
id=job_id,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled daily event log cleanup at 03:00 UTC")
|
||||
|
||||
|
||||
async def _cleanup_old_events() -> None:
|
||||
"""Delete EventLog entries older than 90 days."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlmodel import delete
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import EventLog
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
|
||||
await session.commit()
|
||||
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
|
||||
|
||||
|
||||
async def _load_tracker_jobs() -> None:
|
||||
"""Load enabled trackers and schedule polling jobs."""
|
||||
from sqlmodel import select
|
||||
@@ -50,13 +89,16 @@ async def _load_tracker_jobs() -> None:
|
||||
result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True))
|
||||
trackers = result.all()
|
||||
|
||||
# Pre-load provider types for scheduler detection
|
||||
# Batch-load provider types for scheduler detection
|
||||
unique_provider_ids = list({t.provider_id for t in trackers})
|
||||
provider_types: dict[int, str] = {}
|
||||
for tracker in trackers:
|
||||
if tracker.provider_id not in provider_types:
|
||||
provider = await session.get(ServiceProviderModel, tracker.provider_id)
|
||||
if provider:
|
||||
provider_types[tracker.provider_id] = provider.type
|
||||
if unique_provider_ids:
|
||||
provider_result = await session.exec(
|
||||
select(ServiceProviderModel).where(
|
||||
ServiceProviderModel.id.in_(unique_provider_ids)
|
||||
)
|
||||
)
|
||||
provider_types = {p.id: p.type for p in provider_result.all()}
|
||||
|
||||
for tracker in trackers:
|
||||
job_id = f"tracker_{tracker.id}"
|
||||
@@ -86,6 +128,7 @@ async def _load_tracker_jobs() -> None:
|
||||
id=job_id,
|
||||
args=[tracker.id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval)
|
||||
|
||||
@@ -106,6 +149,7 @@ def _add_cron_job(
|
||||
id=job_id,
|
||||
args=[tracker_id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -47,10 +46,18 @@ async def _get_bot_ids_with_active_listeners() -> set[int]:
|
||||
listeners = result.all()
|
||||
|
||||
active_bot_ids: set[int] = set()
|
||||
for listener in listeners:
|
||||
tracker = await session.get(CommandTracker, listener.command_tracker_id)
|
||||
if tracker and tracker.enabled:
|
||||
active_bot_ids.add(listener.listener_id)
|
||||
tracker_ids = list({l.command_tracker_id for l in listeners})
|
||||
if tracker_ids:
|
||||
tracker_result = await session.exec(
|
||||
select(CommandTracker).where(
|
||||
CommandTracker.id.in_(tracker_ids),
|
||||
CommandTracker.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
enabled_tracker_ids = {t.id for t in tracker_result.all()}
|
||||
for listener in listeners:
|
||||
if listener.command_tracker_id in enabled_tracker_ids:
|
||||
active_bot_ids.add(listener.listener_id)
|
||||
|
||||
return active_bot_ids
|
||||
|
||||
@@ -145,21 +152,23 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
if not bot or bot.update_mode != "polling":
|
||||
unschedule_bot_polling(bot_id)
|
||||
return
|
||||
# Extract what we need before closing session
|
||||
# Copy attributes before session closes to avoid detached-instance errors
|
||||
from types import SimpleNamespace
|
||||
bot_token = bot.token
|
||||
bot_obj = bot
|
||||
bot_obj = SimpleNamespace(id=bot.id, name=bot.name, token=bot.token)
|
||||
|
||||
offset = _last_update_id.get(bot_id, 0)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return
|
||||
updates = result.get("result", [])
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return
|
||||
updates = result.get("result", [])
|
||||
except Exception as e:
|
||||
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
|
||||
return
|
||||
@@ -209,17 +218,13 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
cmd_response = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if cmd_response is not None:
|
||||
if isinstance(cmd_response, dict) and "media" in cmd_response:
|
||||
# Text + media: send text first, media as reply
|
||||
from ..commands.handler import send_reply as _reply
|
||||
await _reply(bot_token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
|
||||
await send_media_group(bot_token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
|
||||
elif isinstance(cmd_response, list):
|
||||
await send_media_group(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
else:
|
||||
await send_reply(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
except Exception:
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ objects and dispatches through the same path the watcher uses.
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -183,58 +182,59 @@ async def _build_immich_event(
|
||||
memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums"
|
||||
is_memory = test_type == "memory"
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
if not await immich.connect():
|
||||
return None
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
if not await immich.connect():
|
||||
return None
|
||||
|
||||
# Native Immich memories API path
|
||||
if is_memory and memory_source == "native":
|
||||
return await _build_native_memory_event(
|
||||
immich, ext_domain, provider_name, tracker_name,
|
||||
collection_ids, limit, asset_type, favorite_only, min_rating,
|
||||
)
|
||||
|
||||
# Album-based path: use shared collect_scheduled_assets
|
||||
albums: dict[str, ImmichAlbumData] = {}
|
||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||
for album_id in collection_ids:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums[album_id] = album
|
||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
||||
|
||||
assets, collections_extra = collect_scheduled_assets(
|
||||
albums, shared_links, ext_domain,
|
||||
limit=limit,
|
||||
asset_type=asset_type,
|
||||
favorite_only=favorite_only,
|
||||
min_rating=min_rating,
|
||||
is_memory=is_memory,
|
||||
# Native Immich memories API path
|
||||
if is_memory and memory_source == "native":
|
||||
return await _build_native_memory_event(
|
||||
immich, ext_domain, provider_name, tracker_name,
|
||||
collection_ids, limit, asset_type, favorite_only, min_rating,
|
||||
)
|
||||
|
||||
first_col = collections_extra[0] if collections_extra else {}
|
||||
return ServiceEvent(
|
||||
event_type=EventType.SCHEDULED_MESSAGE,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name=provider_name,
|
||||
collection_id=collection_ids[0] if collection_ids else "",
|
||||
collection_name=first_col.get("name", tracker_name),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
added_assets=assets,
|
||||
added_count=len(assets),
|
||||
extra={
|
||||
"collections": collections_extra,
|
||||
"albums": collections_extra,
|
||||
**(first_col if first_col else {}),
|
||||
},
|
||||
)
|
||||
# Album-based path: use shared collect_scheduled_assets
|
||||
albums: dict[str, ImmichAlbumData] = {}
|
||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||
for album_id in collection_ids:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums[album_id] = album
|
||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
||||
|
||||
assets, collections_extra = collect_scheduled_assets(
|
||||
albums, shared_links, ext_domain,
|
||||
limit=limit,
|
||||
asset_type=asset_type,
|
||||
favorite_only=favorite_only,
|
||||
min_rating=min_rating,
|
||||
is_memory=is_memory,
|
||||
)
|
||||
|
||||
first_col = collections_extra[0] if collections_extra else {}
|
||||
return ServiceEvent(
|
||||
event_type=EventType.SCHEDULED_MESSAGE,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name=provider_name,
|
||||
collection_id=collection_ids[0] if collection_ids else "",
|
||||
collection_name=first_col.get("name", tracker_name),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
added_assets=assets,
|
||||
added_count=len(assets),
|
||||
extra={
|
||||
"collections": collections_extra,
|
||||
"albums": collections_extra,
|
||||
**(first_col if first_col else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _build_native_memory_event(
|
||||
|
||||
@@ -6,7 +6,6 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -102,19 +101,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
if provider_type == "immich":
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
connected = await immich.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to provider"}
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
connected = await immich.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to provider"}
|
||||
|
||||
events, new_state = await immich.poll(collection_ids, state_dict)
|
||||
events, new_state = await immich.poll(collection_ids, state_dict)
|
||||
elif provider_type == "gitea":
|
||||
# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
|
||||
# The scheduler still calls check_tracker but there's nothing to poll.
|
||||
@@ -143,18 +143,22 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
events, new_state = await nut.poll(collection_ids, state_dict)
|
||||
elif provider_type == "google_photos":
|
||||
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("client_id", ""),
|
||||
provider_config.get("client_secret", ""),
|
||||
provider_config.get("refresh_token", ""),
|
||||
provider_name,
|
||||
)
|
||||
connected = await gp.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to Google Photos"}
|
||||
events, new_state = await gp.poll(collection_ids, state_dict)
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("client_id", ""),
|
||||
provider_config.get("client_secret", ""),
|
||||
provider_config.get("refresh_token", ""),
|
||||
provider_name,
|
||||
)
|
||||
connected = await gp.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to Google Photos"}
|
||||
events, new_state = await gp.poll(collection_ids, state_dict)
|
||||
elif provider_type == "webhook":
|
||||
# Webhook providers receive events via inbound HTTP; no polling needed.
|
||||
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
|
||||
else:
|
||||
return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user