diff --git a/.gitea/workflows/release.yml b/.gitea/workflows/release.yml index 3f08ee2..239cc7e 100644 --- a/.gitea/workflows/release.yml +++ b/.gitea/workflows/release.yml @@ -43,6 +43,17 @@ jobs: cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max + - name: Trigger Portainer redeploy + continue-on-error: true + run: | + if [ -n "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" ]; then + echo "Triggering Portainer redeploy..." + curl -sf -X POST "${{ secrets.DOCKER_REDEPLOY_WEBHOOK_URL }}" \ + --max-time 30 || echo "::warning::Portainer webhook failed" + else + echo "DOCKER_REDEPLOY_WEBHOOK_URL not set — skipping auto-deploy" + fi + - name: Generate changelog id: changelog run: | @@ -56,7 +67,29 @@ jobs: - name: Create Gitea Release run: | - BODY=$(cat /tmp/changelog.txt | python3 -c "import sys,json; print(json.dumps(sys.stdin.read()))") + if [ -f RELEASE_NOTES.md ]; then + export RELEASE_NOTES=$(cat RELEASE_NOTES.md) + echo "Found RELEASE_NOTES.md" + else + export RELEASE_NOTES="" + echo "No RELEASE_NOTES.md found" + fi + + BODY=$(python3 -c " + import json, os, sys + + release_notes = os.environ.get('RELEASE_NOTES', '') + changelog = open('/tmp/changelog.txt').read().strip() + + sections = [] + if release_notes.strip(): + sections.append(release_notes.strip()) + if changelog: + sections.append('## Changelog\n\n' + changelog) + + print(json.dumps('\n\n'.join(sections))) + ") + curl -s -X POST \ "https://${{ env.REGISTRY }}/api/v1/repos/${{ env.IMAGE_NAME }}/releases" \ -H "Authorization: token ${{ secrets.RELEASE_TOKEN }}" \ diff --git a/frontend/src/lib/components/Button.svelte b/frontend/src/lib/components/Button.svelte new file mode 100644 index 0000000..fe647e7 --- /dev/null +++ b/frontend/src/lib/components/Button.svelte @@ -0,0 +1,85 @@ + + +{#if href && !disabled} + + {@render children()} + +{:else} + +{/if} + + diff --git a/frontend/src/lib/components/EntitySelect.svelte b/frontend/src/lib/components/EntitySelect.svelte index 49b8ec4..99e2412 100644 --- a/frontend/src/lib/components/EntitySelect.svelte +++ b/frontend/src/lib/components/EntitySelect.svelte @@ -1,5 +1,6 @@ + +{#if message} +
+ {message} +
+{/if} diff --git a/frontend/src/lib/components/IconGridSelect.svelte b/frontend/src/lib/components/IconGridSelect.svelte index 8d06f68..566ebf5 100644 --- a/frontend/src/lib/components/IconGridSelect.svelte +++ b/frontend/src/lib/components/IconGridSelect.svelte @@ -1,5 +1,6 @@ - + {#if showEmailForm} - {#if error}
{error}
{/if} +
diff --git a/frontend/src/routes/bots/MatrixBotTab.svelte b/frontend/src/routes/bots/MatrixBotTab.svelte index 3332155..d1d7e68 100644 --- a/frontend/src/routes/bots/MatrixBotTab.svelte +++ b/frontend/src/routes/bots/MatrixBotTab.svelte @@ -10,6 +10,8 @@ import ConfirmModal from '$lib/components/ConfirmModal.svelte'; import IconButton from '$lib/components/IconButton.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; + import Button from '$lib/components/Button.svelte'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; import type { MatrixBot } from '$lib/types'; let { onreload }: { onreload: () => Promise } = $props(); @@ -70,22 +72,21 @@ try { const res = await api(`/matrix-bots/${botId}/test`, { method: 'POST' }); if (res.success) snackSuccess(t('snack.matrixBotTestOk')); - else snackError(res.error || 'Failed'); + else snackError(res.error || t('matrixBot.operationFailed')); } catch (err: any) { snackError(err.message); } matrixTesting = { ...matrixTesting, [botId]: false }; } - + {#if showMatrixForm} - {#if error}
{error}
{/if} +
diff --git a/frontend/src/routes/bots/TelegramBotTab.svelte b/frontend/src/routes/bots/TelegramBotTab.svelte index 1f2b843..5e65221 100644 --- a/frontend/src/routes/bots/TelegramBotTab.svelte +++ b/frontend/src/routes/bots/TelegramBotTab.svelte @@ -12,6 +12,8 @@ import IconButton from '$lib/components/IconButton.svelte'; import EntitySelect from '$lib/components/EntitySelect.svelte'; import { snackSuccess, snackError, snackInfo } from '$lib/stores/snackbar.svelte'; + import Button from '$lib/components/Button.svelte'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; import type { TelegramBot, TelegramChat } from '$lib/types'; interface CommandTrackerSummary { id: number; name: string; icon?: string; enabled: boolean } @@ -186,7 +188,7 @@ try { const res = await api(`/telegram-bots/${botId}/sync-commands`, { method: 'POST' }); if (res.success) snackSuccess(t('telegramBot.commandsSynced')); - else snackError(res.error || 'Failed'); + else snackError(res.error || t('telegramBot.saveFailed')); } catch (err: any) { snackError(err.message); } modeChanging = { ...modeChanging, [botId]: false }; } @@ -218,7 +220,7 @@ snackSuccess(res.verified ? t('telegramBot.webhookVerified') : t('telegramBot.webhookRegistered')); await loadWebhookStatus(botId); } else { - snackError(res.error || 'Failed to register webhook'); + snackError(res.error || t('telegramBot.webhookFailed')); } } catch (err: any) { snackError(err.message); } modeChanging = { ...modeChanging, [botId]: false }; @@ -229,7 +231,7 @@ try { const res = await api(`/telegram-bots/${botId}/webhook/unregister`, { method: 'POST' }); if (res.success) { snackSuccess(t('telegramBot.webhookUnregistered')); await loadWebhookStatus(botId); } - else snackError(res.error || 'Failed'); + else snackError(res.error || t('telegramBot.saveFailed')); } catch (err: any) { snackError(err.message); } modeChanging = { ...modeChanging, [botId]: false }; } @@ -260,7 +262,7 @@ try { const res = await api(`/telegram-bots/${botId}/chats/${chatId}/test?locale=${getLocale()}`, { method: 'POST' }); if (res.success) snackSuccess(t('snack.targetTestSent')); - else snackError(res.error || 'Failed'); + else snackError(res.error || t('telegramBot.saveFailed')); } catch (err: any) { snackError(err.message); } chatTesting = { ...chatTesting, [key]: false }; } @@ -277,15 +279,14 @@ - + {#if showForm} - {#if error}
{error}
{/if} +
diff --git a/frontend/src/routes/command-configs/+page.svelte b/frontend/src/routes/command-configs/+page.svelte index 29ab61f..97b7e63 100644 --- a/frontend/src/routes/command-configs/+page.svelte +++ b/frontend/src/routes/command-configs/+page.svelte @@ -15,6 +15,7 @@ import IconGridSelect from '$lib/components/IconGridSelect.svelte'; import { providerTypeItems, providerTypeFilterItems, responseModeItems } from '$lib/grid-items'; import EntitySelect from '$lib/components/EntitySelect.svelte'; + import Button from '$lib/components/Button.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { highlightFromUrl } from '$lib/highlight'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; @@ -37,7 +38,7 @@ let cmdTemplateConfigs = $derived(commandTemplateConfigsCache.items); const templateItems = $derived(cmdTemplateConfigs .filter((c) => c.provider_type === form.provider_type) - .map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? ' (System)' : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type })) + .map((c) => ({ value: c.id, label: c.name + (c.user_id === 0 ? t('common.systemSuffix') : ''), icon: c.icon || 'mdiCodeBracesBox', desc: c.provider_type })) ); let loaded = $state(false); let showForm = $state(false); @@ -151,10 +152,9 @@ - + {#if !loaded}{:else} diff --git a/frontend/src/routes/login/+page.svelte b/frontend/src/routes/login/+page.svelte index e694309..1528fe9 100644 --- a/frontend/src/routes/login/+page.svelte +++ b/frontend/src/routes/login/+page.svelte @@ -32,7 +32,7 @@ await login(username, password); window.location.href = '/'; } catch (err: any) { - error = err.message || 'Login failed'; + error = err.message || t('auth.loginFailed'); } submitting = false; } diff --git a/frontend/src/routes/notification-trackers/+page.svelte b/frontend/src/routes/notification-trackers/+page.svelte index cc5e29b..a368328 100644 --- a/frontend/src/routes/notification-trackers/+page.svelte +++ b/frontend/src/routes/notification-trackers/+page.svelte @@ -17,6 +17,8 @@ import { providerDefaultIcon } from '$lib/grid-items'; import { globalProviderFilter } from '$lib/stores/provider-filter.svelte'; import { getDescriptor } from '$lib/providers'; + import Button from '$lib/components/Button.svelte'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; import type { Tracker, TrackerTarget, TrackingConfig, TemplateConfig, NotificationTarget } from '$lib/types'; import TrackerForm from './TrackerForm.svelte'; @@ -119,7 +121,7 @@ capabilitiesCache.fetch(), ]); } catch (err: any) { - loadError = err.message || 'Failed to load data'; + loadError = err.message || t('common.loadFailed'); snackError(loadError); } finally { loaded = true; highlightFromUrl(); } } @@ -212,7 +214,7 @@ } } } - if (created > 0) snackSuccess(`Created ${created} public link(s)`); + if (created > 0) snackSuccess(t('notificationTracker.createdLinks').replace('{count}', String(created))); linkWarning = null; linkCreating = false; await doSave(); @@ -361,17 +363,16 @@ - + {#if !loaded} {:else if loadError} -
{loadError}
+
{:else if showForm} - + {#if !loaded} @@ -158,9 +159,7 @@ {#if showForm}
- {#if error} -
{error}
- {/if} +
@@ -211,10 +210,9 @@

{t('providers.webhookUrlHint')}

{/if} - +
diff --git a/frontend/src/routes/providers/new/+page.svelte b/frontend/src/routes/providers/new/+page.svelte index 784dfb6..ee3407c 100644 --- a/frontend/src/routes/providers/new/+page.svelte +++ b/frontend/src/routes/providers/new/+page.svelte @@ -5,8 +5,11 @@ import Card from '$lib/components/Card.svelte'; import IconPicker from '$lib/components/IconPicker.svelte'; import IconGridSelect from '$lib/components/IconGridSelect.svelte'; + import { goto } from '$app/navigation'; import { providerTypeItems } from '$lib/grid-items'; import { getDescriptor, buildProviderFormDefaults } from '$lib/providers'; + import Button from '$lib/components/Button.svelte'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; let form = $state(buildProviderFormDefaults()); let error = $state(''); @@ -16,7 +19,7 @@ async function testAndSave() { const desc = descriptor; - if (!desc) { error = 'Select a provider type'; return; } + if (!desc) { error = t('providers.selectType'); return; } const { config, error: buildError } = desc.buildConfig(form, false); if (buildError) { error = t(buildError); snackError(error); return; } @@ -32,22 +35,22 @@ if (!result.ok) { await api(`/providers/${provider.id}`, { method: 'DELETE' }).catch(() => {}); createdId = null; - error = result.message || 'Connection test failed'; + error = result.message || t('providers.testFailed'); snackError(error); } else { snackSuccess(t('snack.providerSaved')); - window.location.href = '/providers'; + goto('/providers'); } } catch (e: any) { if (createdId) await api(`/providers/${createdId}`, { method: 'DELETE' }).catch(() => {}); - error = e.message || 'Test failed'; snackError(error); + error = e.message || t('providers.testFailed'); snackError(error); } finally { testing = false; } } async function saveWithoutTest() { const desc = descriptor; - if (!desc) { error = 'Select a provider type'; return; } + if (!desc) { error = t('providers.selectType'); return; } const { config, error: buildError } = desc.buildConfig(form, false); if (buildError) { error = t(buildError); snackError(error); return; } @@ -58,8 +61,8 @@ body: JSON.stringify({ type: form.type, name: form.name || desc.defaultName, icon: form.icon, config }), }); snackSuccess(t('snack.providerSaved')); - window.location.href = '/providers'; - } catch (e: any) { error = e.message || 'Save failed'; snackError(error); } + goto('/providers'); + } catch (e: any) { error = e.message || t('common.saveFailed'); snackError(error); } finally { saving = false; } } @@ -112,22 +115,18 @@
{/each} - {#if error} -

{error}

- {/if} +
- - + - + +
diff --git a/frontend/src/routes/settings/+page.svelte b/frontend/src/routes/settings/+page.svelte index 0953678..f552715 100644 --- a/frontend/src/routes/settings/+page.svelte +++ b/frontend/src/routes/settings/+page.svelte @@ -7,10 +7,12 @@ import Loading from '$lib/components/Loading.svelte'; import MdiIcon from '$lib/components/MdiIcon.svelte'; import Hint from '$lib/components/Hint.svelte'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; let loaded = $state(false); let saving = $state(false); + let error = $state(''); let settings = $state({ external_url: '', telegram_webhook_secret: '', @@ -20,16 +22,16 @@ onMount(async () => { try { settings = await api('/settings'); - } catch (err: any) { snackError(err.message); } + } catch (err: any) { error = err.message; snackError(err.message); } finally { loaded = true; } }); async function save() { - saving = true; + saving = true; error = ''; try { settings = await api('/settings', { method: 'PUT', body: JSON.stringify(settings) }); snackSuccess(t('settings.saved')); - } catch (err: any) { snackError(err.message); } + } catch (err: any) { error = err.message; snackError(err.message); } saving = false; } @@ -39,6 +41,7 @@ {#if !loaded} {:else} +
diff --git a/frontend/src/routes/setup/+page.svelte b/frontend/src/routes/setup/+page.svelte index 2d30dd6..7910aa1 100644 --- a/frontend/src/routes/setup/+page.svelte +++ b/frontend/src/routes/setup/+page.svelte @@ -25,7 +25,7 @@ try { await setup(username, password); window.location.href = '/'; - } catch (err: any) { error = err.message || 'Setup failed'; } + } catch (err: any) { error = err.message || t('auth.setupFailed'); } submitting = false; } diff --git a/frontend/src/routes/targets/+page.svelte b/frontend/src/routes/targets/+page.svelte index 979a9dd..e9ed685 100644 --- a/frontend/src/routes/targets/+page.svelte +++ b/frontend/src/routes/targets/+page.svelte @@ -15,6 +15,7 @@ import { chatActionItems } from '$lib/grid-items'; import { snackSuccess, snackError } from '$lib/stores/snackbar.svelte'; import { highlightFromUrl } from '$lib/highlight'; + import ErrorBanner from '$lib/components/ErrorBanner.svelte'; import type { NotificationTarget, TargetReceiver, TelegramChat } from '$lib/types'; import TargetForm from './TargetForm.svelte'; @@ -419,7 +420,7 @@ {#if !loaded}{:else} {#if loadError} -
{loadError}
+ {/if} {#if showForm} diff --git a/packages/core/src/notify_bridge_core/notifications/dispatcher.py b/packages/core/src/notify_bridge_core/notifications/dispatcher.py index d7ce5f7..b13ae25 100644 --- a/packages/core/src/notify_bridge_core/notifications/dispatcher.py +++ b/packages/core/src/notify_bridge_core/notifications/dispatcher.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from dataclasses import dataclass, field from typing import Any @@ -68,14 +69,17 @@ class NotificationDispatcher: Returns list of results (one per target). """ + raw_results = await asyncio.gather( + *[self._send_to_target(event, t) for t in targets], + return_exceptions=True, + ) results = [] - for target in targets: - try: - result = await self._send_to_target(event, target) - results.append(result) - except Exception as e: - _LOGGER.error("Failed to dispatch to target: %s", e) - results.append({"success": False, "error": str(e)}) + for raw in raw_results: + if isinstance(raw, Exception): + _LOGGER.error("Failed to dispatch to target: %s", raw) + results.append({"success": False, "error": str(raw)}) + else: + results.append(raw) return results def _resolve_template( diff --git a/packages/core/src/notify_bridge_core/providers/gitea/client.py b/packages/core/src/notify_bridge_core/providers/gitea/client.py index a3b7a79..1818cb8 100644 --- a/packages/core/src/notify_bridge_core/providers/gitea/client.py +++ b/packages/core/src/notify_bridge_core/providers/gitea/client.py @@ -85,6 +85,20 @@ class GiteaClient: return repos + async def get_repo(self, owner: str, repo: str) -> dict[str, Any] | None: + """Fetch a single repository by owner/repo name.""" + try: + async with self._session.get( + f"{self._url}/api/v1/repos/{owner}/{repo}", + headers=self._headers, + ) as response: + if response.status == 200: + return await response.json() + _LOGGER.warning("Failed to fetch repo %s/%s: HTTP %s", owner, repo, response.status) + except aiohttp.ClientError as err: + _LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, err) + return None + async def get_repo_issues( self, owner: str, repo: str, state: str = "open", limit: int = 10, ) -> list[dict[str, Any]]: diff --git a/packages/core/src/notify_bridge_core/providers/nut/client.py b/packages/core/src/notify_bridge_core/providers/nut/client.py index b63c71a..c1a8dbe 100644 --- a/packages/core/src/notify_bridge_core/providers/nut/client.py +++ b/packages/core/src/notify_bridge_core/providers/nut/client.py @@ -14,12 +14,28 @@ _DEFAULT_PORT = 3493 _READ_TIMEOUT = 10.0 _CONNECT_TIMEOUT = 5.0 +# Allowed characters for NUT protocol identifiers (UPS names, variable names). +# Prevents command injection via newlines or special characters. +_SAFE_NAME_RE = re.compile(r"^[\w.\-]+$") + # Regex to parse VAR lines: VAR "" _VAR_RE = re.compile(r'^VAR\s+(\S+)\s+(\S+)\s+"(.*)"$') # Regex to parse UPS lines: UPS "" _UPS_RE = re.compile(r'^UPS\s+(\S+)\s+"(.*)"$') +def _validate_name(value: str, label: str) -> None: + """Validate that *value* is a safe NUT protocol identifier. + + Raises ``NutClientError`` if *value* contains characters outside + ``[\\w.\\-]``, which could be used for protocol command injection. + """ + if not _SAFE_NAME_RE.match(value): + raise NutClientError( + f"Invalid {label}: {value!r} contains disallowed characters" + ) + + class NutClientError(Exception): """Error communicating with NUT server.""" @@ -91,6 +107,7 @@ class NutClient: async def list_var(self, ups_name: str) -> dict[str, str]: """Get all variables for a UPS device.""" + _validate_name(ups_name, "UPS name") lines = await self._list_command(f"LIST VAR {ups_name}") variables: dict[str, str] = {} for line in lines: @@ -101,6 +118,8 @@ class NutClient: async def get_var(self, ups_name: str, var_name: str) -> str: """Get a single variable value.""" + _validate_name(ups_name, "UPS name") + _validate_name(var_name, "variable name") response = await self._command(f"GET VAR {ups_name} {var_name}") m = _VAR_RE.match(response) if m: diff --git a/packages/core/src/notify_bridge_core/templates/renderer.py b/packages/core/src/notify_bridge_core/templates/renderer.py index 8b302dc..6fd3914 100644 --- a/packages/core/src/notify_bridge_core/templates/renderer.py +++ b/packages/core/src/notify_bridge_core/templates/renderer.py @@ -10,7 +10,7 @@ from jinja2.sandbox import SandboxedEnvironment _LOGGER = logging.getLogger(__name__) -_env = SandboxedEnvironment(autoescape=False) +_env = SandboxedEnvironment(autoescape=True) def render_template(template_str: str, context: dict[str, Any]) -> str: diff --git a/packages/server/src/notify_bridge_server/api/action_rules.py b/packages/server/src/notify_bridge_server/api/action_rules.py index 72c24fb..fa302c1 100644 --- a/packages/server/src/notify_bridge_server/api/action_rules.py +++ b/packages/server/src/notify_bridge_server/api/action_rules.py @@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import Action, ActionRule, User +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -59,10 +60,9 @@ def _rule_response(rule: ActionRule) -> dict: async def _get_user_action( session: AsyncSession, action_id: int, user: User ) -> Action: - action = await session.get(Action, action_id) - if not action or action.user_id != user.id: - raise HTTPException(status_code=404, detail="Action not found") - return action + return await get_owned_entity( + session, Action, action_id, user.id, not_found_msg="Action not found", + ) # --------------------------------------------------------------------------- diff --git a/packages/server/src/notify_bridge_server/api/command_template_configs.py b/packages/server/src/notify_bridge_server/api/command_template_configs.py index 5bf25bb..6c1f7b6 100644 --- a/packages/server/src/notify_bridge_server/api/command_template_configs.py +++ b/packages/server/src/notify_bridge_server/api/command_template_configs.py @@ -12,12 +12,10 @@ from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from jinja2.sandbox import SandboxedEnvironment -from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined - from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User +from .slot_helpers import load_slots, render_template_preview, save_slots _LOGGER = logging.getLogger(__name__) @@ -44,38 +42,11 @@ class CommandTemplateConfigUpdate(BaseModel): # --------------------------------------------------------------------------- async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]: - """Load slots as {slot_name: {locale: template}}.""" - result = await session.exec( - select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id) - ) - nested: dict[str, dict[str, str]] = {} - for s in result.all(): - nested.setdefault(s.slot_name, {})[s.locale] = s.template - return nested + return await load_slots(session, CommandTemplateSlot, config_id) async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None: - """Save slots from {slot_name: {locale: template}} format.""" - for slot_name, locale_map in slots.items(): - for locale, template_text in locale_map.items(): - result = await session.exec( - select(CommandTemplateSlot).where( - CommandTemplateSlot.config_id == config_id, - CommandTemplateSlot.slot_name == slot_name, - CommandTemplateSlot.locale == locale, - ) - ) - existing = result.first() - if existing: - existing.template = template_text - session.add(existing) - else: - session.add(CommandTemplateSlot( - config_id=config_id, - slot_name=slot_name, - locale=locale, - template=template_text, - )) + await save_slots(session, CommandTemplateSlot, config_id, slots) async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]: @@ -367,18 +338,4 @@ async def preview_raw( "wait": 15, } - try: - env = SandboxedEnvironment(autoescape=False) - env.from_string(body.template) - except TemplateSyntaxError as e: - return {"rendered": None, "error": e.message, "error_line": e.lineno} - - try: - strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined) - tmpl = strict_env.from_string(body.template) - rendered = tmpl.render(**sample_ctx) - return {"rendered": rendered} - except UndefinedError as e: - return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"} - except Exception as e: - return {"rendered": None, "error": str(e), "error_line": None} + return render_template_preview(body.template, sample_ctx) diff --git a/packages/server/src/notify_bridge_server/api/command_trackers.py b/packages/server/src/notify_bridge_server/api/command_trackers.py index 6891af0..262c60b 100644 --- a/packages/server/src/notify_bridge_server/api/command_trackers.py +++ b/packages/server/src/notify_bridge_server/api/command_trackers.py @@ -17,6 +17,7 @@ from ..database.models import ( TelegramBot, User, ) +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -401,7 +402,7 @@ async def _listener_response(session: AsyncSession, l: CommandTrackerListener) - async def _get_user_tracker( session: AsyncSession, tracker_id: int, user_id: int ) -> CommandTracker: - tracker = await session.get(CommandTracker, tracker_id) - if not tracker or tracker.user_id != user_id: - raise HTTPException(status_code=404, detail="Command tracker not found") - return tracker + return await get_owned_entity( + session, CommandTracker, tracker_id, user_id, + not_found_msg="Command tracker not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/email_bots.py b/packages/server/src/notify_bridge_server/api/email_bots.py index 4bf3c8a..a2a9443 100644 --- a/packages/server/src/notify_bridge_server/api/email_bots.py +++ b/packages/server/src/notify_bridge_server/api/email_bots.py @@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import EmailBot, User +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -156,7 +157,6 @@ def _response(bot: EmailBot) -> dict: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot: - bot = await session.get(EmailBot, bot_id) - if not bot or bot.user_id != user_id: - raise HTTPException(status_code=404, detail="Email bot not found") - return bot + return await get_owned_entity( + session, EmailBot, bot_id, user_id, not_found_msg="Email bot not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/helpers.py b/packages/server/src/notify_bridge_server/api/helpers.py new file mode 100644 index 0000000..17f7314 --- /dev/null +++ b/packages/server/src/notify_bridge_server/api/helpers.py @@ -0,0 +1,25 @@ +"""Shared helpers for API route modules.""" + +from typing import TypeVar + +from fastapi import HTTPException +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + +T = TypeVar("T", bound=SQLModel) + + +async def get_owned_entity( + session: AsyncSession, + model: type[T], + entity_id: int, + user_id: int, + *, + owner_field: str = "user_id", + not_found_msg: str = "Not found", +) -> T: + """Fetch an entity by PK and verify ownership, or raise 404.""" + entity = await session.get(model, entity_id) + if not entity or getattr(entity, owner_field) != user_id: + raise HTTPException(status_code=404, detail=not_found_msg) + return entity diff --git a/packages/server/src/notify_bridge_server/api/matrix_bots.py b/packages/server/src/notify_bridge_server/api/matrix_bots.py index 6d4b170..b1dd953 100644 --- a/packages/server/src/notify_bridge_server/api/matrix_bots.py +++ b/packages/server/src/notify_bridge_server/api/matrix_bots.py @@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import MatrixBot, User +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -108,33 +109,34 @@ async def test_matrix_bot( bot = await _get_user_bot(session, bot_id, user.id) import aiohttp - async with aiohttp.ClientSession() as http: - # Verify token with /whoami - whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami" - headers = {"Authorization": f"Bearer {bot.access_token}"} - try: - async with http.get(whoami_url, headers=headers) as resp: - if resp.status != 200: - body = await resp.text() - return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"} - whoami = await resp.json() - except aiohttp.ClientError as e: - return {"success": False, "error": f"Connection failed: {e}"} + from ..services.http_session import get_http_session + http = await get_http_session() + # Verify token with /whoami + whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami" + headers = {"Authorization": f"Bearer {bot.access_token}"} + try: + async with http.get(whoami_url, headers=headers) as resp: + if resp.status != 200: + body = await resp.text() + return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"} + whoami = await resp.json() + except aiohttp.ClientError as e: + return {"success": False, "error": f"Connection failed: {e}"} - result = {"success": True, "user_id": whoami.get("user_id", "")} + result = {"success": True, "user_id": whoami.get("user_id", "")} - # Optionally send a test message - if room_id: - from notify_bridge_core.notifications.matrix.client import MatrixClient - client = MatrixClient(http, bot.homeserver_url, bot.access_token) - send_result = await client.send_message( - room_id, - "Test message from Notify Bridge", - html_message="Test message from Notify Bridge", - ) - result["send_result"] = send_result + # Optionally send a test message + if room_id: + from notify_bridge_core.notifications.matrix.client import MatrixClient + client = MatrixClient(http, bot.homeserver_url, bot.access_token) + send_result = await client.send_message( + room_id, + "Test message from Notify Bridge", + html_message="Test message from Notify Bridge", + ) + result["send_result"] = send_result - return result + return result def _response(bot: MatrixBot) -> dict: @@ -150,7 +152,6 @@ def _response(bot: MatrixBot) -> dict: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot: - bot = await session.get(MatrixBot, bot_id) - if not bot or bot.user_id != user_id: - raise HTTPException(status_code=404, detail="Matrix bot not found") - return bot + return await get_owned_entity( + session, MatrixBot, bot_id, user_id, not_found_msg="Matrix bot not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py index 95acae3..5967d6f 100644 --- a/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py +++ b/packages/server/src/notify_bridge_server/api/notification_tracker_targets.py @@ -23,6 +23,7 @@ from ..database.models import ( ) from ..services.notifier import send_test_notification from ..services.test_dispatch import dispatch_test_notification +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -277,7 +278,7 @@ async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) -> async def _get_user_tracker( session: AsyncSession, tracker_id: int, user_id: int ) -> NotificationTracker: - tracker = await session.get(NotificationTracker, tracker_id) - if not tracker or tracker.user_id != user_id: - raise HTTPException(status_code=404, detail="Tracker not found") - return tracker + return await get_owned_entity( + session, NotificationTracker, tracker_id, user_id, + not_found_msg="Tracker not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/notification_trackers.py b/packages/server/src/notify_bridge_server/api/notification_trackers.py index 2edd383..bca3181 100644 --- a/packages/server/src/notify_bridge_server/api/notification_trackers.py +++ b/packages/server/src/notify_bridge_server/api/notification_trackers.py @@ -18,6 +18,7 @@ from ..database.models import ( User, ) from ..services.scheduler import schedule_tracker, unschedule_tracker +from .helpers import get_owned_entity from .notification_tracker_targets import _tt_response _LOGGER = logging.getLogger(__name__) @@ -205,7 +206,7 @@ async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> di async def _get_user_tracker( session: AsyncSession, tracker_id: int, user_id: int ) -> NotificationTracker: - tracker = await session.get(NotificationTracker, tracker_id) - if not tracker or tracker.user_id != user_id: - raise HTTPException(status_code=404, detail="Tracker not found") - return tracker + return await get_owned_entity( + session, NotificationTracker, tracker_id, user_id, + not_found_msg="Tracker not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/providers.py b/packages/server/src/notify_bridge_server/api/providers.py index 3f9f1c2..6a47e6e 100644 --- a/packages/server/src/notify_bridge_server/api/providers.py +++ b/packages/server/src/notify_bridge_server/api/providers.py @@ -13,7 +13,12 @@ import aiohttp from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import ServiceProvider, User -from ..services import make_immich_provider, make_gitea_provider, make_planka_provider, make_nut_provider, make_google_photos_provider +from ..services import ( + make_immich_provider, make_gitea_provider, make_planka_provider, + make_nut_provider, make_google_photos_provider, list_provider_collections, +) +from ..services.http_session import get_http_session +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -82,6 +87,20 @@ class GooglePhotosProviderConfig(BaseModel): refresh_token: str +class PayloadMapping(BaseModel): + variable: str + jsonpath: str + default: str | None = None + + +class WebhookProviderConfig(BaseModel): + auth_mode: str = "none" + webhook_secret: str | None = None + payload_mappings: list[PayloadMapping] = [] + event_type_path: str | None = None + collection_path: str | None = None + + _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = { "immich": ImmichProviderConfig, "gitea": GiteaProviderConfig, @@ -89,6 +108,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = { "scheduler": SchedulerProviderConfig, "nut": NutProviderConfig, "google_photos": GooglePhotosProviderConfig, + "webhook": WebhookProviderConfig, } @@ -106,6 +126,70 @@ def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> Non ) +async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]: + """Test provider connection and return the result dict. + + For providers that lack optional credentials (gitea without api_token, + planka without api_key), returns a success stub. + """ + http_session = await get_http_session() + + if provider.type == "immich": + immich = make_immich_provider(http_session, provider) + return await immich.test_connection() + + if provider.type == "gitea": + if not provider.config.get("api_token"): + return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"} + gitea = make_gitea_provider(http_session, provider) + return await gitea.test_connection() + + if provider.type == "planka": + if not provider.config.get("api_key"): + return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"} + planka = make_planka_provider(http_session, provider) + return await planka.test_connection() + + if provider.type == "nut": + nut = make_nut_provider(provider) + return await nut.test_connection() + + if provider.type == "google_photos": + gp = make_google_photos_provider(http_session, provider) + return await gp.test_connection() + + if provider.type in ("scheduler", "webhook"): + return {"ok": True, "message": "Virtual provider — always available"} + + return {"ok": False, "message": f"Unknown provider type: {provider.type}"} + + +async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]: + """Test provider connection. Raise HTTPException on failure. + + Returns the test_result dict on success (caller may inspect extra fields + like ``external_domain``). + """ + try: + test_result = await _test_provider_connection(provider) + except aiohttp.ClientError as err: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Connection error: {err}", + ) + except OSError as err: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Connection error: {err}", + ) + if not test_result.get("ok"): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=test_result.get("message", f"Cannot connect to {provider.type} provider"), + ) + return test_result + + @router.get("") async def list_providers( user: User = Depends(get_current_user), @@ -128,96 +212,15 @@ async def create_provider( """Add a new service provider (validates connection for known types).""" _validate_provider_config(body.type, body.config) - # Validate connection for known provider types - try: - if body.type == "immich": - from notify_bridge_core.providers.immich import ImmichServiceProvider - config = body.config - async with aiohttp.ClientSession() as http_session: - immich = ImmichServiceProvider( - http_session, config.get("url", ""), config.get("api_key", ""), - config.get("external_domain"), body.name, - ) - test_result = await immich.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", f"Cannot connect to {body.type} provider"), - ) - # Store external_domain from server config if available - if test_result.get("external_domain"): - config["external_domain"] = test_result["external_domain"] + # Build a temporary ServiceProvider for connection testing + temp_provider = ServiceProvider( + id=0, user_id=0, type=body.type, name=body.name, config=body.config, + ) + test_result = await _validate_provider_connection(temp_provider) - elif body.type == "gitea": - config = body.config - # api_token is optional (webhook_secret is required, but token only for repo listing) - if config.get("api_token"): - async with aiohttp.ClientSession() as http_session: - from notify_bridge_core.providers.gitea import GiteaServiceProvider - gitea = GiteaServiceProvider( - http_session, config.get("url", ""), config.get("api_token", ""), body.name, - ) - test_result = await gitea.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Gitea"), - ) - - elif body.type == "planka": - config = body.config - if config.get("api_key"): - async with aiohttp.ClientSession() as http_session: - from notify_bridge_core.providers.planka import PlankaServiceProvider - planka = PlankaServiceProvider( - http_session, config.get("url", ""), config.get("api_key", ""), body.name, - ) - test_result = await planka.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Planka"), - ) - - elif body.type == "nut": - nut = make_nut_provider(ServiceProvider( - id=0, user_id=0, type="nut", name=body.name, config=body.config, - )) - test_result = await nut.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to NUT server"), - ) - - elif body.type == "google_photos": - config = body.config - async with aiohttp.ClientSession() as http_session: - from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider - gp = GooglePhotosServiceProvider( - http_session, config.get("client_id", ""), config.get("client_secret", ""), - config.get("refresh_token", ""), body.name, - ) - test_result = await gp.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Google Photos"), - ) - except HTTPException: - raise - except aiohttp.ClientError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) - except OSError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) - - # Scheduler: no validation needed (virtual provider) + # Store external_domain from Immich server config if available + if test_result.get("external_domain"): + body.config["external_domain"] = test_result["external_domain"] provider = ServiceProvider( user_id=user.id, @@ -307,78 +310,10 @@ async def update_provider( provider.config = body.config # Re-validate connection when config changes for known provider types - if config_changed and provider.type == "immich": - try: - async with aiohttp.ClientSession() as http_session: - immich = make_immich_provider(http_session, provider) - test_result = await immich.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", f"Cannot connect to {provider.type} provider"), - ) - if test_result.get("external_domain"): - provider.config = {**provider.config, "external_domain": test_result["external_domain"]} - except aiohttp.ClientError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) - elif config_changed and provider.type == "gitea": - if provider.config.get("api_token"): - try: - async with aiohttp.ClientSession() as http_session: - gitea = make_gitea_provider(http_session, provider) - test_result = await gitea.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Gitea"), - ) - except aiohttp.ClientError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) - elif config_changed and provider.type == "planka": - if provider.config.get("api_key"): - try: - async with aiohttp.ClientSession() as http_session: - planka = make_planka_provider(http_session, provider) - test_result = await planka.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Planka"), - ) - except aiohttp.ClientError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) - elif config_changed and provider.type == "nut": - nut = make_nut_provider(provider) - test_result = await nut.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to NUT server"), - ) - elif config_changed and provider.type == "google_photos": - try: - async with aiohttp.ClientSession() as http_session: - gp = make_google_photos_provider(http_session, provider) - test_result = await gp.test_connection() - if not test_result.get("ok"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=test_result.get("message", "Cannot connect to Google Photos"), - ) - except aiohttp.ClientError as err: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Connection error: {err}", - ) + if config_changed: + test_result = await _validate_provider_connection(provider) + if test_result.get("external_domain"): + provider.config = {**provider.config, "external_domain": test_result["external_domain"]} session.add(provider) await session.commit() @@ -408,39 +343,7 @@ async def test_provider( ): """Check if a service provider is reachable.""" provider = await _get_user_provider(session, provider_id, user.id) - - if provider.type == "immich": - async with aiohttp.ClientSession() as http_session: - immich = make_immich_provider(http_session, provider) - return await immich.test_connection() - - if provider.type == "gitea": - if not provider.config.get("api_token"): - return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"} - async with aiohttp.ClientSession() as http_session: - gitea = make_gitea_provider(http_session, provider) - return await gitea.test_connection() - - if provider.type == "planka": - if not provider.config.get("api_key"): - return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"} - async with aiohttp.ClientSession() as http_session: - planka = make_planka_provider(http_session, provider) - return await planka.test_connection() - - if provider.type == "scheduler": - return {"ok": True, "message": "Virtual provider — always available"} - - if provider.type == "nut": - nut = make_nut_provider(provider) - return await nut.test_connection() - - if provider.type == "google_photos": - async with aiohttp.ClientSession() as http_session: - gp = make_google_photos_provider(http_session, provider) - return await gp.test_connection() - - return {"ok": False, "message": f"Unknown provider type: {provider.type}"} + return await _test_provider_connection(provider) @router.get("/{provider_id}/people") @@ -454,14 +357,14 @@ async def list_people( if provider.type == "immich": from notify_bridge_core.providers.immich.client import ImmichClient - async with aiohttp.ClientSession() as http_session: - client = ImmichClient( - http_session, - provider.config.get("url", ""), - provider.config.get("api_key", ""), - ) - people = await client.get_people() - return [{"id": pid, "name": name} for pid, name in people.items()] + http_session = await get_http_session() + client = ImmichClient( + http_session, + provider.config.get("url", ""), + provider.config.get("api_key", ""), + ) + people = await client.get_people() + return [{"id": pid, "name": name} for pid, name in people.items()] return [] @@ -475,35 +378,7 @@ async def list_collections( """Fetch collections from a service provider.""" provider = await _get_user_provider(session, provider_id, user.id) - if provider.type == "immich": - async with aiohttp.ClientSession() as http_session: - immich = make_immich_provider(http_session, provider) - return await immich.list_collections() - - if provider.type == "gitea": - if not provider.config.get("api_token"): - return [] - async with aiohttp.ClientSession() as http_session: - gitea = make_gitea_provider(http_session, provider) - return await gitea.list_collections() - - if provider.type == "planka": - if not provider.config.get("api_key"): - return [] - async with aiohttp.ClientSession() as http_session: - planka = make_planka_provider(http_session, provider) - return await planka.list_collections() - - if provider.type == "nut": - nut = make_nut_provider(provider) - return await nut.list_collections() - - if provider.type == "google_photos": - async with aiohttp.ClientSession() as http_session: - gp = make_google_photos_provider(http_session, provider) - return await gp.list_collections() - - return [] + return await list_provider_collections(provider) @router.get("/{provider_id}/albums/{album_id}/shared-links") @@ -517,19 +392,19 @@ async def get_album_shared_links( provider = await _get_user_provider(session, provider_id, user.id) if provider.type == "immich": - async with aiohttp.ClientSession() as http_session: - immich = make_immich_provider(http_session, provider) - links = await immich.client.get_shared_links(album_id) - return [ - { - "id": link.id, - "key": link.key, - "has_password": link.has_password, - "is_expired": link.is_expired, - "is_accessible": link.is_accessible, - } - for link in links - ] + http_session = await get_http_session() + immich = make_immich_provider(http_session, provider) + links = await immich.client.get_shared_links(album_id) + return [ + { + "id": link.id, + "key": link.key, + "has_password": link.has_password, + "is_expired": link.is_expired, + "is_accessible": link.is_accessible, + } + for link in links + ] return [] @@ -545,15 +420,13 @@ async def create_album_shared_link( provider = await _get_user_provider(session, provider_id, user.id) if provider.type == "immich": - async with aiohttp.ClientSession() as http_session: - immich = make_immich_provider(http_session, provider) - success = await immich.client.create_shared_link(album_id) - if success: - return {"success": True} - from fastapi import HTTPException - raise HTTPException(status_code=400, detail="Failed to create shared link") + http_session = await get_http_session() + immich = make_immich_provider(http_session, provider) + success = await immich.client.create_shared_link(album_id) + if success: + return {"success": True} + raise HTTPException(status_code=400, detail="Failed to create shared link") - from fastapi import HTTPException raise HTTPException(status_code=400, detail="Provider type does not support shared links") @@ -580,7 +453,7 @@ async def _get_user_provider( session: AsyncSession, provider_id: int, user_id: int ) -> ServiceProvider: """Get a provider owned by the user, or raise 404.""" - provider = await session.get(ServiceProvider, provider_id) - if not provider or provider.user_id != user_id: - raise HTTPException(status_code=404, detail="Provider not found") - return provider + return await get_owned_entity( + session, ServiceProvider, provider_id, user_id, + not_found_msg="Provider not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/slot_helpers.py b/packages/server/src/notify_bridge_server/api/slot_helpers.py new file mode 100644 index 0000000..1d975c0 --- /dev/null +++ b/packages/server/src/notify_bridge_server/api/slot_helpers.py @@ -0,0 +1,90 @@ +"""Shared slot load/save and Jinja2 preview helpers for template config APIs.""" + +from typing import TypeVar + +from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError +from jinja2.sandbox import SandboxedEnvironment +from sqlmodel import SQLModel, select +from sqlmodel.ext.asyncio.session import AsyncSession + +S = TypeVar("S", bound=SQLModel) + + +async def load_slots( + session: AsyncSession, + slot_model: type[S], + config_id: int, +) -> dict[str, dict[str, str]]: + """Load all template slots for a config as {slot_name: {locale: template}}. + + Works for both TemplateSlot and CommandTemplateSlot — they share the same + column names (config_id, slot_name, locale, template). + """ + result = await session.exec( + select(slot_model).where(slot_model.config_id == config_id) # type: ignore[attr-defined] + ) + slots: dict[str, dict[str, str]] = {} + for s in result.all(): + slots.setdefault(s.slot_name, {})[s.locale] = s.template # type: ignore[attr-defined] + return slots + + +async def save_slots( + session: AsyncSession, + slot_model: type[S], + config_id: int, + slots: dict[str, dict[str, str]], +) -> None: + """Create or update template slots for a config (locale-aware). + + Works for both TemplateSlot and CommandTemplateSlot. + """ + for slot_name, locale_map in slots.items(): + for locale, template_text in locale_map.items(): + result = await session.exec( + select(slot_model).where( + slot_model.config_id == config_id, # type: ignore[attr-defined] + slot_model.slot_name == slot_name, # type: ignore[attr-defined] + slot_model.locale == locale, # type: ignore[attr-defined] + ) + ) + existing = result.first() + if existing: + existing.template = template_text # type: ignore[attr-defined] + session.add(existing) + else: + session.add(slot_model( + config_id=config_id, + slot_name=slot_name, + locale=locale, + template=template_text, + )) + + +def render_template_preview(template: str, context: dict) -> dict: + """Two-pass Jinja2 render: syntax check, then strict render. + + Returns a dict with either ``{"rendered": str}`` on success, or + ``{"rendered": None, "error": str, ...}`` on failure. + """ + # Pass 1: syntax check (default Undefined — catches parse errors only) + try: + env = SandboxedEnvironment(autoescape=False) + env.from_string(template) + except TemplateSyntaxError as e: + return { + "rendered": None, + "error": e.message, + "error_line": e.lineno, + } + + # Pass 2: render with StrictUndefined to catch unknown variables + try: + strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined) + tmpl = strict_env.from_string(template) + rendered = tmpl.render(**context) + return {"rendered": rendered} + except UndefinedError as e: + return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"} + except Exception as e: + return {"rendered": None, "error": str(e), "error_line": None} diff --git a/packages/server/src/notify_bridge_server/api/status.py b/packages/server/src/notify_bridge_server/api/status.py index dbf89e6..9fb254e 100644 --- a/packages/server/src/notify_bridge_server/api/status.py +++ b/packages/server/src/notify_bridge_server/api/status.py @@ -112,8 +112,16 @@ async def get_nav_counts( user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session), ): - """Return entity counts for sidebar navigation badges.""" - counts = {} + """Return entity counts for sidebar navigation badges. + + Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe + for concurrent use within a single session (no asyncio.gather). We + minimise round-trips by combining user + system counts and per-type + target counts into single aggregate queries where possible. + """ + counts: dict[str, int] = {} + + # --- 1) User-owned entity counts (one query per model) --- for model, key in [ (ServiceProvider, "providers"), (NotificationTracker, "notification_trackers"), @@ -132,7 +140,7 @@ async def get_nav_counts( )).one() counts[key] = count - # System-owned entities (user_id=0) count as well + # --- 2) Add system-owned counts (user_id=0) for shared entities --- for model, key in [ (TemplateConfig, "template_configs"), (CommandTemplateConfig, "command_template_configs"), @@ -144,15 +152,22 @@ async def get_nav_counts( )).one() counts[key] += system_count - # Per-type target counts for nav badges - for target_type in ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"): - type_count = (await session.exec( - select(func.count()).select_from(NotificationTarget).where( - NotificationTarget.user_id == user.id, - NotificationTarget.type == target_type, - ) - )).one() - counts[f"targets_{target_type}"] = type_count + # --- 3) Per-type target counts in a single query using conditional aggregation --- + target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix") + type_counts_result = (await session.exec( + select( + NotificationTarget.type, + func.count(), + ) + .where( + NotificationTarget.user_id == user.id, + NotificationTarget.type.in_(target_types), + ) + .group_by(NotificationTarget.type) + )).all() + type_counts_map = dict(type_counts_result) + for target_type in target_types: + counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0) return counts diff --git a/packages/server/src/notify_bridge_server/api/target_receivers.py b/packages/server/src/notify_bridge_server/api/target_receivers.py index fd06ef9..2ae6919 100644 --- a/packages/server/src/notify_bridge_server/api/target_receivers.py +++ b/packages/server/src/notify_bridge_server/api/target_receivers.py @@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import NotificationTarget, TargetReceiver, User from ..services.notifier import send_to_receiver +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -170,7 +171,7 @@ def _response(r: TargetReceiver) -> dict: async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget: - target = await session.get(NotificationTarget, target_id) - if not target or target.user_id != user_id: - raise HTTPException(status_code=404, detail="Target not found") - return target + return await get_owned_entity( + session, NotificationTarget, target_id, user_id, + not_found_msg="Target not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/targets.py b/packages/server/src/notify_bridge_server/api/targets.py index 4fabb25..c6d7944 100644 --- a/packages/server/src/notify_bridge_server/api/targets.py +++ b/packages/server/src/notify_bridge_server/api/targets.py @@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User from ..services.notifier import send_test_notification +from .helpers import get_owned_entity from .target_receivers import _receiver_key _LOGGER = logging.getLogger(__name__) @@ -306,8 +307,15 @@ async def _validate_broadcast_children( return if exclude_target_id and exclude_target_id in child_ids: raise HTTPException(status_code=400, detail="A broadcast target cannot include itself") + + # Batch-load all children in a single IN query instead of N+1 individual fetches + children = (await session.exec( + select(NotificationTarget).where(NotificationTarget.id.in_(child_ids)) + )).all() + children_by_id = {c.id: c for c in children} + for child_id in child_ids: - child = await session.get(NotificationTarget, child_id) + child = children_by_id.get(child_id) if not child or child.user_id != user_id: raise HTTPException(status_code=400, detail=f"Child target {child_id} not found") if child.type == "broadcast": @@ -378,7 +386,7 @@ def _safe_config(target: NotificationTarget) -> dict: async def _get_user_target( session: AsyncSession, target_id: int, user_id: int ) -> NotificationTarget: - target = await session.get(NotificationTarget, target_id) - if not target or target.user_id != user_id: - raise HTTPException(status_code=404, detail="Target not found") - return target + return await get_owned_entity( + session, NotificationTarget, target_id, user_id, + not_found_msg="Target not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/telegram_bots.py b/packages/server/src/notify_bridge_server/api/telegram_bots.py index b13dda2..f0370c4 100644 --- a/packages/server/src/notify_bridge_server/api/telegram_bots.py +++ b/packages/server/src/notify_bridge_server/api/telegram_bots.py @@ -7,8 +7,6 @@ from pydantic import BaseModel from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -import aiohttp - from notify_bridge_core.notifications.telegram.client import TelegramClient from ..auth.dependencies import get_current_user @@ -19,6 +17,7 @@ from ..database.models import AppSetting, NotificationTarget, TargetReceiver, Te from ..services.notifier import _get_test_message from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling from .app_settings import get_setting +from .helpers import get_owned_entity _LOGGER = logging.getLogger(__name__) @@ -290,10 +289,11 @@ async def test_chat( ): """Send a test message to a chat via the bot.""" bot = await _get_user_bot(session, bot_id, user.id) + from ..services.http_session import get_http_session message = _get_test_message(locale, "telegram") - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, bot.token) - return await client.send_message(chat_id, message) + http = await get_http_session() + client = TelegramClient(http, bot.token) + return await client.send_message(chat_id, message) class ChatUpdate(BaseModel): @@ -344,41 +344,44 @@ async def delete_chat( async def _get_webhook_info(token: str) -> dict | None: """Call Telegram getWebhookInfo via TelegramClient.""" - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, token) - result = await client.get_webhook_info() - return result.get("result") if result.get("success") else None + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, token) + result = await client.get_webhook_info() + return result.get("result") if result.get("success") else None async def _get_me(token: str) -> dict | None: """Call Telegram getMe via TelegramClient.""" - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, token) - result = await client.get_me() - return result.get("result") if result.get("success") else None + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, token) + result = await client.get_me() + return result.get("result") if result.get("success") else None async def _fetch_chats_from_telegram(token: str) -> list[dict]: """Fetch chats from Telegram getUpdates via TelegramClient.""" - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, token) - result = await client.get_updates(limit=100) - if not result.get("success"): - return [] + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, token) + result = await client.get_updates(limit=100) + if not result.get("success"): + return [] - seen: dict[int, dict] = {} - for update in result.get("result", []): - msg = update.get("message", {}) - chat = msg.get("chat", {}) - chat_id = chat.get("id") - if chat_id and chat_id not in seen: - seen[chat_id] = { - "id": chat_id, - "title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()), - "type": chat.get("type", "private"), - "username": chat.get("username", ""), - } - return list(seen.values()) + seen: dict[int, dict] = {} + for update in result.get("result", []): + msg = update.get("message", {}) + chat = msg.get("chat", {}) + chat_id = chat.get("id") + if chat_id and chat_id not in seen: + seen[chat_id] = { + "id": chat_id, + "title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()), + "type": chat.get("type", "private"), + "username": chat.get("username", ""), + } + return list(seen.values()) def _chat_response(c: TelegramChat) -> dict: @@ -410,10 +413,9 @@ def _bot_response(b: TelegramBot) -> dict: async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot: - bot = await session.get(TelegramBot, bot_id) - if not bot or bot.user_id != user_id: - raise HTTPException(status_code=404, detail="Bot not found") - return bot + return await get_owned_entity( + session, TelegramBot, bot_id, user_id, not_found_msg="Bot not found", + ) diff --git a/packages/server/src/notify_bridge_server/api/template_configs.py b/packages/server/src/notify_bridge_server/api/template_configs.py index 5b6a8fa..c78db85 100644 --- a/packages/server/src/notify_bridge_server/api/template_configs.py +++ b/packages/server/src/notify_bridge_server/api/template_configs.py @@ -13,12 +13,12 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from jinja2.sandbox import SandboxedEnvironment -from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined from ..auth.dependencies import get_current_user from ..database.engine import get_session from ..database.models import TemplateConfig, TemplateSlot, User from ..services.sample_context import _SAMPLE_CONTEXT +from .slot_helpers import load_slots, render_template_preview, save_slots _LOGGER = logging.getLogger(__name__) @@ -49,40 +49,13 @@ class TemplateConfigUpdate(BaseModel): # --------------------------------------------------------------------------- async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]: - """Load all template slots for a config as {slot_name: {locale: template}}.""" - result = await session.exec( - select(TemplateSlot).where(TemplateSlot.config_id == config_id) - ) - slots: dict[str, dict[str, str]] = {} - for s in result.all(): - slots.setdefault(s.slot_name, {})[s.locale] = s.template - return slots + return await load_slots(session, TemplateSlot, config_id) async def _save_slots( session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]] ) -> None: - """Create or update template slots for a config (locale-aware).""" - for slot_name, locale_map in slots.items(): - for locale, template_text in locale_map.items(): - result = await session.exec( - select(TemplateSlot).where( - TemplateSlot.config_id == config_id, - TemplateSlot.slot_name == slot_name, - TemplateSlot.locale == locale, - ) - ) - existing = result.first() - if existing: - existing.template = template_text - session.add(existing) - else: - session.add(TemplateSlot( - config_id=config_id, - slot_name=slot_name, - locale=locale, - template=template_text, - )) + await save_slots(session, TemplateSlot, config_id, slots) async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]: @@ -155,7 +128,7 @@ async def get_template_variables( "photo_count": "Total photo count in album", "video_count": "Total video count in album", "owner": "Album owner name", - "target_type": "Target type: 'telegram' or 'webhook'", + "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix", "has_videos": "Whether added assets contain videos (boolean)", "has_photos": "Whether added assets contain photos (boolean)", "has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)", @@ -206,7 +179,7 @@ async def get_template_variables( } scheduled_vars = { "date": "Current date string", - "target_type": "Target type: 'telegram' or 'webhook'", + "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix", } return { @@ -284,7 +257,7 @@ def _webhook_variables() -> dict: "source_ip": "IP address of the webhook sender", "raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)", "timestamp": "When the webhook was received", - "target_type": "Target type: 'telegram' or 'webhook'", + "target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix", }, }, } @@ -529,7 +502,7 @@ async def preview_date_format( class PreviewRequest(BaseModel): template: str - target_type: str = "telegram" # "telegram" or "webhook" + target_type: str = "telegram" # telegram, webhook, email, discord, slack, ntfy, matrix date_format: str = "%d.%m.%Y, %H:%M UTC" date_only_format: str = "%d.%m.%Y" @@ -545,33 +518,12 @@ async def preview_raw( 1. Parse with default Undefined (catches syntax errors) 2. Render with StrictUndefined (catches unknown variables like {{ asset.a }}) """ - # Pass 1: syntax check + from datetime import datetime + ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type, + "date_format": body.date_format, "date_only_format": body.date_only_format} + # Format common_date using the provided date_only_format try: - env = SandboxedEnvironment(autoescape=False) - env.from_string(body.template) - except TemplateSyntaxError as e: - return { - "rendered": None, - "error": e.message, - "error_line": e.lineno, - } - - # Pass 2: render with strict undefined to catch unknown variables - try: - from datetime import datetime - ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type, - "date_format": body.date_format, "date_only_format": body.date_only_format} - # Format common_date using the provided date_only_format - try: - ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format) - except (ValueError, TypeError): - ctx["common_date"] = "19.03.2026" - strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined) - tmpl = strict_env.from_string(body.template) - rendered = tmpl.render(**ctx) - return {"rendered": rendered} - except UndefinedError as e: - # Still a valid template syntactically, but references unknown variable - return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"} - except Exception as e: - return {"rendered": None, "error": str(e), "error_line": None} + ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format) + except (ValueError, TypeError): + ctx["common_date"] = "19.03.2026" + return render_template_preview(body.template, ctx) diff --git a/packages/server/src/notify_bridge_server/commands/base.py b/packages/server/src/notify_bridge_server/commands/base.py index 915e86e..1bd1f46 100644 --- a/packages/server/src/notify_bridge_server/commands/base.py +++ b/packages/server/src/notify_bridge_server/commands/base.py @@ -3,9 +3,18 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any -from ..database.models import CommandTracker, CommandConfig, ServiceProvider, TelegramBot +from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot + + +@dataclass(frozen=True) +class CommandResponse: + """A single response from one tracker's command execution.""" + + text: str | None = None + media: list[dict[str, Any]] = field(default_factory=list) class ProviderCommandHandler(ABC): @@ -14,6 +23,8 @@ class ProviderCommandHandler(ABC): Each provider (Immich, Gitea, etc.) implements this interface to handle its own set of commands. The dispatch layer routes commands to the correct handler based on the provider type. + + Each handler call receives a single (tracker, config, provider) context. """ provider_type: str @@ -35,26 +46,28 @@ class ProviderCommandHandler(ABC): count: int, locale: str, response_mode: str, - providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], bot: TelegramBot, - ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], - ) -> str | list[dict[str, Any]] | None: - """Handle a provider-specific command. + tracker: CommandTracker, + config: CommandConfig, + ) -> CommandResponse | None: + """Handle a provider-specific command for a single tracker. Args: cmd: The command name (without '/'). args: Arguments after the command. count: Number of results to return. locale: User's locale ('en', 'ru'). - response_mode: 'media' or 'text'. - providers_map: Provider instances keyed by ID. - cmd_templates: Template slots {slot_name: {locale: template}}. + response_mode: 'media' or 'text' (from this tracker's config). + provider: The service provider instance for this tracker. + cmd_templates: Template slots for this tracker's command template config. bot: The Telegram bot instance. - ctx_tuples: Command context tuples for this provider type. + tracker: The command tracker being dispatched. + config: The command config for this tracker. Returns: - Text response, media list, or None if unhandled. + A CommandResponse, or None if unhandled. """ def get_rate_categories(self) -> dict[str, str]: diff --git a/packages/server/src/notify_bridge_server/commands/command_utils.py b/packages/server/src/notify_bridge_server/commands/command_utils.py new file mode 100644 index 0000000..287b318 --- /dev/null +++ b/packages/server/src/notify_bridge_server/commands/command_utils.py @@ -0,0 +1,67 @@ +"""Shared command handler utilities to reduce boilerplate across providers.""" + +from __future__ import annotations + +import logging + +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from ..database.engine import get_engine +from ..database.models import EventLog, NotificationTracker, ServiceProvider + +_LOGGER = logging.getLogger(__name__) + + +async def get_trackers_for_provider(provider_id: int) -> list[NotificationTracker]: + """Get notification trackers for a single provider.""" + from .handler import _get_notification_trackers_for_providers + + return await _get_notification_trackers_for_providers({provider_id}) + + +async def get_last_event_str(tracker_ids: list[int]) -> str: + """Get formatted timestamp of most recent event for given trackers. + + Returns a 'YYYY-MM-DD HH:MM' string, or '-' if no events exist. + """ + if not tracker_ids: + return "-" + engine = get_engine() + async with AsyncSession(engine) as session: + result = await session.exec( + select(EventLog) + .where(EventLog.tracker_id.in_(tracker_ids)) + .order_by(EventLog.created_at.desc()) + .limit(1) + ) + last_event = result.first() + return last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-" + + +def get_tracked_collection_ids( + provider: ServiceProvider, + trackers: list[NotificationTracker], + *, + max_items: int = 20, +) -> list[str]: + """Get deduplicated collection IDs from trackers for a provider. + + Iterates all trackers belonging to *provider*, collects IDs from both + ``collection_ids`` and ``filters.collections``, deduplicates while + preserving order, and caps at *max_items*. + """ + seen: set[str] = set() + result: list[str] = [] + for tracker in trackers: + if tracker.provider_id != provider.id: + continue + for cid in tracker.collection_ids or []: + if cid not in seen: + seen.add(cid) + result.append(cid) + for cid in (tracker.filters or {}).get("collections", []): + if cid not in seen: + seen.add(cid) + result.append(cid) + return result[:max_items] diff --git a/packages/server/src/notify_bridge_server/commands/gitea_handler.py b/packages/server/src/notify_bridge_server/commands/gitea_handler.py index 0ccc1e1..f6ed77d 100644 --- a/packages/server/src/notify_bridge_server/commands/gitea_handler.py +++ b/packages/server/src/notify_bridge_server/commands/gitea_handler.py @@ -2,27 +2,55 @@ from __future__ import annotations +import asyncio import logging +from collections.abc import Callable, Coroutine from typing import Any -import aiohttp -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from ..database.engine import get_engine from ..database.models import ( - CommandConfig, CommandTracker, EventLog, - NotificationTracker, ServiceProvider, TelegramBot, + CommandConfig, CommandTracker, ServiceProvider, TelegramBot, ) from ..services import make_gitea_provider -from .base import ProviderCommandHandler -from .handler import _render_cmd_template, _get_notification_trackers_for_providers +from ..services.http_session import get_http_session +from .base import CommandResponse, ProviderCommandHandler +from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider +from .handler import _render_cmd_template _LOGGER = logging.getLogger(__name__) _GITEA_COMMANDS = {"status", "repos", "issues", "prs", "commits"} +def _get_tracked_repos( + provider: ServiceProvider, + trackers: list, +) -> list[tuple[ServiceProvider, str, str]]: + """Get (provider, owner, repo) tuples from tracked collection_ids.""" + if not provider.config.get("api_token"): + return [] + collection_ids = get_tracked_collection_ids(provider, trackers) + repos: list[tuple[ServiceProvider, str, str]] = [] + for full_name in collection_ids: + parts = full_name.split("/", 1) + if len(parts) == 2: + repos.append((provider, parts[0], parts[1])) + return repos + + +# --------------------------------------------------------------------------- +# Command dispatch table +# --------------------------------------------------------------------------- + +_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {} + + +def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]: + """Register a function in the text command dispatch table.""" + name = fn.__name__.removeprefix("_cmd_") + _TEXT_COMMANDS[name] = fn + return fn + + class GiteaCommandHandler(ProviderCommandHandler): """Handles Gitea-specific bot commands.""" @@ -44,91 +72,35 @@ class GiteaCommandHandler(ProviderCommandHandler): count: int, locale: str, response_mode: str, - providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], bot: TelegramBot, - ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], - ) -> str | list[dict[str, Any]] | None: - if cmd == "status": - ctx = await _cmd_status(providers_map) - return _render_cmd_template(cmd_templates, "status", locale, ctx) - if cmd == "repos": - ctx = await _cmd_repos(providers_map) - return _render_cmd_template(cmd_templates, "repos", locale, ctx) - if cmd == "issues": - ctx = await _cmd_issues(providers_map, count) - return _render_cmd_template(cmd_templates, "issues", locale, ctx) - if cmd == "prs": - ctx = await _cmd_prs(providers_map, count) - return _render_cmd_template(cmd_templates, "prs", locale, ctx) - if cmd == "commits": - ctx = await _cmd_commits(providers_map, count) - return _render_cmd_template(cmd_templates, "commits", locale, ctx) - return None + tracker: CommandTracker, + config: CommandConfig, + ) -> CommandResponse | None: + fn = _TEXT_COMMANDS.get(cmd) + if fn is None: + return None + ctx = await fn(provider, count) + return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx)) -def _get_tracked_repos( - providers_map: dict[int, ServiceProvider], - trackers: list[NotificationTracker], -) -> list[tuple[ServiceProvider, str, str]]: - """Get (provider, owner, repo) tuples from tracked collection_ids.""" - repos: list[tuple[ServiceProvider, str, str]] = [] - for tracker in trackers: - provider = providers_map.get(tracker.provider_id) - if not provider or provider.type != "gitea": - continue - if not provider.config.get("api_token"): - continue - for full_name in (tracker.collection_ids or []): - parts = full_name.split("/", 1) - if len(parts) == 2: - repos.append((provider, parts[0], parts[1])) - # Also check filters.collections - for tracker in trackers: - provider = providers_map.get(tracker.provider_id) - if not provider or provider.type != "gitea": - continue - if not provider.config.get("api_token"): - continue - for full_name in (tracker.filters or {}).get("collections", []): - parts = full_name.split("/", 1) - if len(parts) == 2: - entry = (provider, parts[0], parts[1]) - if entry not in repos: - repos.append(entry) - return repos[:20] # Cap to prevent API hammering +@_text_cmd +async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_repos = _get_tracked_repos(provider, trackers) - -async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_repos = _get_tracked_repos(providers_map, trackers) - - # Get server version from first Gitea provider with token + # Get server version server_version = "unknown" - async with aiohttp.ClientSession() as http: - for provider in providers_map.values(): - if provider.type == "gitea" and provider.config.get("api_token"): - gitea = make_gitea_provider(http, provider) - version = await gitea.client.get_server_version() - if version: - server_version = version - break + if provider.config.get("api_token"): + http = await get_http_session() + gitea = make_gitea_provider(http, provider) + version = await gitea.client.get_server_version() + if version: + server_version = version - # Last event - engine = get_engine() - async with AsyncSession(engine) as session: - tracker_ids = [t.id for t in trackers] - if tracker_ids: - result = await session.exec( - select(EventLog) - .where(EventLog.tracker_id.in_(tracker_ids)) - .order_by(EventLog.created_at.desc()).limit(1) - ) - last_event = result.first() - else: - last_event = None - last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-" + tracker_ids = [t.id for t in trackers] + last_str = await get_last_event_str(tracker_ids) return { "repos_count": len(tracked_repos), @@ -137,116 +109,139 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An } -async def _cmd_repos(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_repos = _get_tracked_repos(providers_map, trackers) +@_text_cmd +async def _cmd_repos(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_repos = _get_tracked_repos(provider, trackers) repos_data: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, owner, repo in tracked_repos: - gitea = make_gitea_provider(http, provider) - try: - all_repos = await gitea.client.get_repos(limit=50) - for r in all_repos: - if r.get("full_name") == f"{owner}/{repo}": - repos_data.append({ - "full_name": r.get("full_name", ""), - "description": r.get("description", ""), - "stars": r.get("stars_count", 0), - "url": r.get("html_url", ""), - }) - break - else: - repos_data.append({ - "full_name": f"{owner}/{repo}", - "description": "", - "stars": 0, - "url": "", - }) - except Exception: - repos_data.append({ - "full_name": f"{owner}/{repo}", - "description": "?", - "stars": 0, - "url": "", - }) + http = await get_http_session() + + async def _fetch_repo(prov: ServiceProvider, owner: str, repo: str) -> dict[str, Any]: + gitea = make_gitea_provider(http, prov) + # Use direct get_repo endpoint instead of listing all repos + r = await gitea.client.get_repo(owner, repo) + if r: + return { + "full_name": r.get("full_name", ""), + "description": r.get("description", ""), + "stars": r.get("stars_count", 0), + "url": r.get("html_url", ""), + } + return { + "full_name": f"{owner}/{repo}", + "description": "", + "stars": 0, + "url": "", + } + + tasks = [_fetch_repo(prov, owner, repo) for prov, owner, repo in tracked_repos] + results = await asyncio.gather(*tasks, return_exceptions=True) + for (prov, owner, repo), result in zip(tracked_repos, results): + if isinstance(result, Exception): + _LOGGER.warning("Failed to fetch repo %s/%s: %s", owner, repo, result) + repos_data.append({ + "full_name": f"{owner}/{repo}", + "description": "?", + "stars": 0, + "url": "", + }) + else: + repos_data.append(result) return {"repos": repos_data} -async def _cmd_issues( - providers_map: dict[int, ServiceProvider], count: int, -) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_repos = _get_tracked_repos(providers_map, trackers) +@_text_cmd +async def _cmd_issues(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_repos = _get_tracked_repos(provider, trackers) all_issues: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, owner, repo in tracked_repos: - gitea = make_gitea_provider(http, provider) - issues = await gitea.client.get_repo_issues(owner, repo, limit=count) - for issue in issues: - all_issues.append({ - "repo": f"{owner}/{repo}", - "number": issue.get("number", 0), - "title": issue.get("title", ""), - "url": issue.get("html_url", ""), - "user": issue.get("user", {}).get("login", ""), - "state": issue.get("state", ""), - }) + http = await get_http_session() + + async def _fetch_issues(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]: + gitea = make_gitea_provider(http, prov) + return await gitea.client.get_repo_issues(owner, repo, limit=count) + + tasks = [_fetch_issues(prov, owner, repo) for prov, owner, repo in tracked_repos] + results = await asyncio.gather(*tasks, return_exceptions=True) + for (prov, owner, repo), result in zip(tracked_repos, results): + if isinstance(result, Exception): + _LOGGER.warning("Failed to fetch issues for %s/%s: %s", owner, repo, result) + continue + for issue in result: + all_issues.append({ + "repo": f"{owner}/{repo}", + "number": issue.get("number", 0), + "title": issue.get("title", ""), + "url": issue.get("html_url", ""), + "user": issue.get("user", {}).get("login", ""), + "state": issue.get("state", ""), + }) all_issues.sort(key=lambda i: i.get("number", 0), reverse=True) return {"issues": all_issues[:count]} -async def _cmd_prs( - providers_map: dict[int, ServiceProvider], count: int, -) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_repos = _get_tracked_repos(providers_map, trackers) +@_text_cmd +async def _cmd_prs(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_repos = _get_tracked_repos(provider, trackers) all_prs: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, owner, repo in tracked_repos: - gitea = make_gitea_provider(http, provider) - prs = await gitea.client.get_repo_pulls(owner, repo, limit=count) - for pr in prs: - all_prs.append({ - "repo": f"{owner}/{repo}", - "number": pr.get("number", 0), - "title": pr.get("title", ""), - "url": pr.get("html_url", ""), - "user": pr.get("user", {}).get("login", ""), - "state": pr.get("state", ""), - }) + http = await get_http_session() + + async def _fetch_prs(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]: + gitea = make_gitea_provider(http, prov) + return await gitea.client.get_repo_pulls(owner, repo, limit=count) + + tasks = [_fetch_prs(prov, owner, repo) for prov, owner, repo in tracked_repos] + results = await asyncio.gather(*tasks, return_exceptions=True) + for (prov, owner, repo), result in zip(tracked_repos, results): + if isinstance(result, Exception): + _LOGGER.warning("Failed to fetch PRs for %s/%s: %s", owner, repo, result) + continue + for pr in result: + all_prs.append({ + "repo": f"{owner}/{repo}", + "number": pr.get("number", 0), + "title": pr.get("title", ""), + "url": pr.get("html_url", ""), + "user": pr.get("user", {}).get("login", ""), + "state": pr.get("state", ""), + }) all_prs.sort(key=lambda p: p.get("number", 0), reverse=True) return {"prs": all_prs[:count]} -async def _cmd_commits( - providers_map: dict[int, ServiceProvider], count: int, -) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_repos = _get_tracked_repos(providers_map, trackers) +@_text_cmd +async def _cmd_commits(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_repos = _get_tracked_repos(provider, trackers) all_commits: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, owner, repo in tracked_repos: - gitea = make_gitea_provider(http, provider) - commits = await gitea.client.get_repo_commits(owner, repo, limit=count) - for c in commits: - commit_data = c.get("commit", {}) - all_commits.append({ - "repo": f"{owner}/{repo}", - "short_id": c.get("sha", "")[:7], - "message": commit_data.get("message", "").split("\n")[0][:80], - "author": commit_data.get("author", {}).get("name", ""), - "url": c.get("html_url", ""), - }) + http = await get_http_session() + + async def _fetch_commits(prov: ServiceProvider, owner: str, repo: str) -> list[dict[str, Any]]: + gitea = make_gitea_provider(http, prov) + return await gitea.client.get_repo_commits(owner, repo, limit=count) + + tasks = [_fetch_commits(prov, owner, repo) for prov, owner, repo in tracked_repos] + results = await asyncio.gather(*tasks, return_exceptions=True) + for (prov, owner, repo), result in zip(tracked_repos, results): + if isinstance(result, Exception): + _LOGGER.warning("Failed to fetch commits for %s/%s: %s", owner, repo, result) + continue + for c in result: + commit_data = c.get("commit", {}) + all_commits.append({ + "repo": f"{owner}/{repo}", + "short_id": c.get("sha", "")[:7], + "message": commit_data.get("message", "").split("\n")[0][:80], + "author": commit_data.get("author", {}).get("name", ""), + "url": c.get("html_url", ""), + }) return {"commits": all_commits[:count]} diff --git a/packages/server/src/notify_bridge_server/commands/handler.py b/packages/server/src/notify_bridge_server/commands/handler.py index b16ee82..d886457 100644 --- a/packages/server/src/notify_bridge_server/commands/handler.py +++ b/packages/server/src/notify_bridge_server/commands/handler.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging import time +from functools import lru_cache from typing import Any import aiohttp @@ -25,17 +26,21 @@ from ..database.models import ( ServiceProvider, TelegramBot, ) +from .base import CommandResponse from .parser import parse_command from .registry import get_rate_category _LOGGER = logging.getLogger(__name__) # Singleton Jinja2 environment for template rendering (Phase 4d) -_JINJA_ENV = SandboxedEnvironment(autoescape=False) +_JINJA_ENV = SandboxedEnvironment(autoescape=True) # Rate limit state with automatic TTL expiry (Phase 4e) _rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600) +# Maximum responses per command to avoid Telegram rate limits +_MAX_RESPONSES_PER_COMMAND = 5 + def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None: """Check rate limit. Returns seconds to wait, or None if OK.""" @@ -60,6 +65,12 @@ def _resolve_template( return locale_map.get(locale) or locale_map.get("en") +@lru_cache(maxsize=256) +def _compile_template(template_str: str): + """Cache compiled Jinja2 templates to avoid re-parsing identical strings.""" + return _JINJA_ENV.from_string(template_str) + + def _render_cmd_template( templates: dict[str, dict[str, str]], slot_name: str, locale: str, context: dict[str, Any], @@ -70,20 +81,28 @@ def _render_cmd_template( _LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale) return f"[No template: {slot_name}]" try: - tmpl = _JINJA_ENV.from_string(template_str) + tmpl = _compile_template(template_str) return tmpl.render(**context) except Exception as e: _LOGGER.warning("Failed to render command template '%s': %s", slot_name, e) return f"[Template error: {slot_name}]" +# --------------------------------------------------------------------------- +# Context resolution +# --------------------------------------------------------------------------- + async def _resolve_command_context( bot: TelegramBot, -) -> tuple[list[tuple[CommandTracker, CommandConfig, ServiceProvider]], dict[str, dict[str, str]]]: +) -> tuple[ + list[tuple[CommandTracker, CommandConfig, ServiceProvider]], + dict[int, dict[str, dict[str, str]]], +]: """Resolve all enabled command trackers, configs, and providers for a bot. - Returns (context_tuples, cmd_template_slots). - cmd_template_slots is {slot_name: {locale: template}}. + Returns: + (context_tuples, templates_by_config_id) + templates_by_config_id is {command_template_config_id: {slot_name: {locale: template}}}. """ engine = get_engine() async with AsyncSession(engine) as session: @@ -142,8 +161,8 @@ async def _resolve_command_context( continue tuples.append((tracker, config, provider)) - # Load command template slots — merge from all configs - cmd_template_slots: dict[str, dict[str, str]] = {} + # Load command template slots per config (not merged) + templates_by_config_id: dict[int, dict[str, dict[str, str]]] = {} seen_config_ids: set[int] = set() for _, config, _ in tuples: cfg_id = config.command_template_config_id @@ -154,98 +173,136 @@ async def _resolve_command_context( CommandTemplateSlot.config_id == cfg_id ) ) + slots: dict[str, dict[str, str]] = {} for s in slot_result.all(): - cmd_template_slots.setdefault(s.slot_name, {})[s.locale] = s.template + slots.setdefault(s.slot_name, {})[s.locale] = s.template + templates_by_config_id[cfg_id] = slots - return tuples, cmd_template_slots + return tuples, templates_by_config_id -def _merge_command_context( +def _templates_for_config( + templates_by_config_id: dict[int, dict[str, dict[str, str]]], + config: CommandConfig, +) -> dict[str, dict[str, str]]: + """Get template slots for a specific command config.""" + cfg_id = config.command_template_config_id + if cfg_id and cfg_id in templates_by_config_id: + return templates_by_config_id[cfg_id] + return {} + + +def _merge_all_templates( + templates_by_config_id: dict[int, dict[str, dict[str, str]]], +) -> dict[str, dict[str, str]]: + """Merge all template config slots into one dict (for universal commands).""" + merged: dict[str, dict[str, str]] = {} + for slots in templates_by_config_id.values(): + for slot_name, locale_map in slots.items(): + merged.setdefault(slot_name, {}).update(locale_map) + return merged + + +def _merge_enabled_commands( ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], -) -> tuple[list[str], str, int, dict[str, Any]]: - """Merge enabled_commands from all configs and pick defaults from first config.""" +) -> tuple[list[str], dict[str, Any]]: + """Merge enabled_commands (union) and rate_limits from all configs. + + Rate limits use the most restrictive (minimum) cooldown per category. + """ if not ctx: - return [], "media", 5, {} + return [], {} enabled: set[str] = set() + merged_limits: dict[str, int] = {} for _, config, _ in ctx: enabled.update(config.enabled_commands or []) + for category, cooldown in (config.rate_limits or {}).items(): + if category not in merged_limits: + merged_limits[category] = cooldown + else: + merged_limits[category] = min(merged_limits[category], cooldown) - first_config = ctx[0][1] - response_mode = first_config.response_mode or "media" - default_count = first_config.default_count or 5 - rate_limits = first_config.rate_limits or {} + return sorted(enabled), merged_limits - return sorted(enabled), response_mode, default_count, rate_limits +# --------------------------------------------------------------------------- +# Main dispatcher +# --------------------------------------------------------------------------- async def handle_command( bot: TelegramBot, chat_id: str, text: str, language_code: str = "", -) -> str | list[dict[str, Any]] | None: +) -> list[CommandResponse] | None: """Handle a bot command. Routes to provider-specific handlers. - Returns text response, media list, or None. + Returns a list of CommandResponse objects (one per tracker), or None. + Universal commands (/start, /help) return a single-element list. + Provider-specific commands dispatch per-tracker with per-tracker config. """ cmd, args, count_override = parse_command(text) if not cmd: return None - ctx_tuples, cmd_templates = await _resolve_command_context(bot) - enabled, response_mode, default_count, rate_limits = _merge_command_context(ctx_tuples) + ctx_tuples, templates_by_config_id = await _resolve_command_context(bot) + enabled, rate_limits = _merge_enabled_commands(ctx_tuples) locale = language_code[:2].lower() if language_code else "en" if locale not in ("en", "ru"): locale = "en" + # Merged templates for universal commands + merged_templates = _merge_all_templates(templates_by_config_id) + if cmd == "start": - return _render_cmd_template(cmd_templates, "start", locale, {"bot_name": bot.name}) + text_resp = _render_cmd_template(merged_templates, "start", locale, {"bot_name": bot.name}) + return [CommandResponse(text=text_resp)] if cmd not in enabled and cmd != "start": return None - # Rate limit check + # Rate limit check (once per command, shared across all trackers) wait = _check_rate_limit(bot.id, chat_id, cmd, rate_limits) if wait is not None: - return _render_cmd_template(cmd_templates, "rate_limited", locale, {"wait": wait}) + text_resp = _render_cmd_template(merged_templates, "rate_limited", locale, {"wait": wait}) + return [CommandResponse(text=text_resp)] - count = min(count_override or default_count, 20) - - # Build providers map from command context - providers_map: dict[int, ServiceProvider] = {} - for _, _, provider in ctx_tuples: - providers_map[provider.id] = provider - - # Universal commands + # Universal commands — single merged response if cmd == "help": - ctx = _cmd_help(enabled, locale, cmd_templates) - return _render_cmd_template(cmd_templates, "help", locale, ctx) + ctx = _cmd_help(enabled, locale, merged_templates) + text_resp = _render_cmd_template(merged_templates, "help", locale, ctx) + return [CommandResponse(text=text_resp)] - # Provider-specific dispatch + # Provider-specific dispatch — per-tracker from .dispatch import get_handler - # Group ctx_tuples by provider type - by_type: dict[str, list[tuple[CommandTracker, CommandConfig, ServiceProvider]]] = {} - for t in ctx_tuples: - ptype = t[2].type - by_type.setdefault(ptype, []).append(t) - - # Find which handler claims this command - for ptype, ptuples in by_type.items(): - handler = get_handler(ptype) - if handler and cmd in handler.get_provider_commands(): - # Build provider map filtered to this provider type - pmap = {p.id: p for _, _, p in ptuples} - result = await handler.handle( - cmd, args, count, locale, response_mode, - pmap, cmd_templates, bot, ptuples, + responses: list[CommandResponse] = [] + for tracker, config, provider in ctx_tuples: + if len(responses) >= _MAX_RESPONSES_PER_COMMAND: + _LOGGER.warning( + "Truncated command responses at %d for bot %d cmd /%s", + _MAX_RESPONSES_PER_COMMAND, bot.id, cmd, ) - if result is not None: - return result + break - return None + handler = get_handler(provider.type) + if not handler or cmd not in handler.get_provider_commands(): + continue + + tracker_templates = _templates_for_config(templates_by_config_id, config) + count = min(count_override or config.default_count or 5, 20) + response_mode = config.response_mode or "media" + + result = await handler.handle( + cmd, args, count, locale, response_mode, + provider, tracker_templates, bot, tracker, config, + ) + if result is not None: + responses.append(result) + + return responses if responses else None def _cmd_help( @@ -283,17 +340,13 @@ async def send_reply( session: aiohttp.ClientSession | None = None, ) -> None: """Send a text reply via TelegramClient.""" - async def _send(http: aiohttp.ClientSession) -> None: - client = TelegramClient(http, bot_token) - result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id) - if not result.get("success"): - _LOGGER.warning("Telegram reply failed: %s", result.get("error")) - - if session is not None: - await _send(session) - else: - async with aiohttp.ClientSession() as http: - await _send(http) + if session is None: + from ..services.http_session import get_http_session + session = await get_http_session() + client = TelegramClient(session, bot_token) + result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id) + if not result.get("success"): + _LOGGER.warning("Telegram reply failed: %s", result.get("error")) async def send_media_group( @@ -319,52 +372,50 @@ async def send_media_group( captions = [item.get("caption", "") for item in media_items if item.get("caption")] caption = "\n".join(captions) if captions else None - async def _send(http: aiohttp.ClientSession) -> None: - client = TelegramClient(http, bot_token) - result = await client.send_notification( - chat_id, assets=assets, caption=caption, - reply_to_message_id=reply_to_message_id, - chat_action=None, - ) - if not result.get("success"): - _LOGGER.warning("Telegram media group failed: %s", result.get("error")) - - if session is not None: - await _send(session) - else: - async with aiohttp.ClientSession() as http: - await _send(http) + if session is None: + from ..services.http_session import get_http_session + session = await get_http_session() + client = TelegramClient(session, bot_token) + result = await client.send_notification( + chat_id, assets=assets, caption=caption, + reply_to_message_id=reply_to_message_id, + chat_action=None, + ) + if not result.get("success"): + _LOGGER.warning("Telegram media group failed: %s", result.get("error")) async def register_commands_with_telegram(bot: TelegramBot) -> bool: """Register enabled commands with Telegram BotFather API via TelegramClient.""" - ctx_tuples, templates = await _resolve_command_context(bot) - enabled, _, _, _ = _merge_command_context(ctx_tuples) + ctx_tuples, templates_by_config_id = await _resolve_command_context(bot) + enabled, _ = _merge_enabled_commands(ctx_tuples) + templates = _merge_all_templates(templates_by_config_id) - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, bot.token) - success = False + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, bot.token) + success = False - # Register per-locale commands - for locale in ("en", "ru"): - commands = [] - for cmd in enabled: - desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd - commands.append({"command": cmd, "description": desc}) - result = await client.set_my_commands(commands, language_code=locale) - if result.get("success"): - success = True - else: - _LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error")) - - # Register default (no language_code) with EN descriptions - en_commands = [] + # Register per-locale commands + for locale in ("en", "ru"): + commands = [] for cmd in enabled: - desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd - en_commands.append({"command": cmd, "description": desc}) - result = await client.set_my_commands(en_commands) + desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd + commands.append({"command": cmd, "description": desc}) + result = await client.set_my_commands(commands, language_code=locale) if result.get("success"): - _LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username) success = True + else: + _LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error")) - return success + # Register default (no language_code) with EN descriptions + en_commands = [] + for cmd in enabled: + desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd + en_commands.append({"command": cmd, "description": desc}) + result = await client.set_my_commands(en_commands) + if result.get("success"): + _LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username) + success = True + + return success diff --git a/packages/server/src/notify_bridge_server/commands/immich/albums.py b/packages/server/src/notify_bridge_server/commands/immich/albums.py index f792913..11afc63 100644 --- a/packages/server/src/notify_bridge_server/commands/immich/albums.py +++ b/packages/server/src/notify_bridge_server/commands/immich/albums.py @@ -6,70 +6,48 @@ import asyncio import logging from typing import Any -import aiohttp - -from notify_bridge_core.providers.immich.asset_utils import get_public_url - -from ...database.models import ServiceProvider, TelegramBot +from ...database.models import ServiceProvider from ...services import make_immich_provider -from ..handler import _get_notification_trackers_for_providers, _render_cmd_template -from .common import _format_assets, build_asset_dict +from ...services.http_session import get_http_session +from ..command_utils import get_trackers_for_provider +from ..handler import _render_cmd_template +from .common import _format_assets, build_asset_dict, fetch_albums_with_links _LOGGER = logging.getLogger(__name__) async def _cmd_albums( - bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str, + provider: ServiceProvider, locale: str, ) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) + trackers = await get_trackers_for_provider(provider.id) if not trackers: return {"albums": []} - albums_data: list[dict] = [] - async with aiohttp.ClientSession() as http: - for tracker in trackers: - provider = providers_map.get(tracker.provider_id) - if not provider or provider.type != "immich": - continue - immich = make_immich_provider(http, provider) - album_ids = tracker.collection_ids or [] - if not album_ids: - continue + # Deduplicate album IDs while preserving order + seen: set[str] = set() + album_ids: list[str] = [] + for tracker in trackers: + for aid in tracker.collection_ids or []: + if aid not in seen: + seen.add(aid) + album_ids.append(aid) + if not album_ids: + return {"albums": []} - ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/") - album_results = await asyncio.gather( - *[immich.client.get_album(aid) for aid in album_ids], - return_exceptions=True, - ) - link_results = await asyncio.gather( - *[immich.client.get_shared_links(aid) for aid in album_ids], - return_exceptions=True, - ) - for album_id, result, links in zip(album_ids, album_results, link_results): - if isinstance(result, Exception): - _LOGGER.warning("Failed to fetch album %s: %s", album_id, result) - albums_data.append({ - "name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id, - }) - elif result: - pub_url = "" - if not isinstance(links, Exception) and ext_domain: - pub_url = get_public_url(ext_domain, links) or "" - albums_data.append({ - "name": result.name, "asset_count": result.asset_count, - "id": album_id, "public_url": pub_url, - }) + ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/") + http = await get_http_session() + immich = make_immich_provider(http, provider) + albums_data = await fetch_albums_with_links(immich.client, album_ids, ext_domain) return {"albums": albums_data} async def cmd_favorites( - bot: TelegramBot, providers_map: dict[int, ServiceProvider], + providers_map: dict[int, ServiceProvider], all_album_ids: list[str], count: int, locale: str, response_mode: str, client: Any, cmd_templates: dict[str, dict[str, str]], -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /favorites command with concurrent album fetching.""" album_ids = all_album_ids[:10] if not album_ids: @@ -104,28 +82,6 @@ async def cmd_summary( if not all_album_ids: return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []}) - album_results = await asyncio.gather( - *[client.get_album(aid) for aid in all_album_ids], - return_exceptions=True, - ) - link_results = await asyncio.gather( - *[client.get_shared_links(aid) for aid in all_album_ids], - return_exceptions=True, - ) ext = external_domain.rstrip("/") - - albums_data: list[dict] = [] - for album_id, result, links in zip(all_album_ids, album_results, link_results): - if isinstance(result, Exception): - _LOGGER.warning("Failed to fetch album %s: %s", album_id, result) - continue - if result: - pub_url = "" - if not isinstance(links, Exception) and ext: - pub_url = get_public_url(ext, links) or "" - albums_data.append({ - "name": result.name, "asset_count": result.asset_count, - "id": album_id, "public_url": pub_url, - }) - + albums_data = await fetch_albums_with_links(client, all_album_ids, ext, include_failed=False) return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data}) diff --git a/packages/server/src/notify_bridge_server/commands/immich/common.py b/packages/server/src/notify_bridge_server/commands/immich/common.py index 0c2906f..c637f6a 100644 --- a/packages/server/src/notify_bridge_server/commands/immich/common.py +++ b/packages/server/src/notify_bridge_server/commands/immich/common.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import logging from typing import Any -from ...services import make_immich_provider +from notify_bridge_core.providers.immich.asset_utils import get_public_url + from ..handler import _render_cmd_template _LOGGER = logging.getLogger(__name__) @@ -17,6 +19,53 @@ _IMMICH_COMMANDS = { } +async def fetch_albums_with_links( + client: Any, + album_ids: list[str], + ext_domain: str, + *, + include_failed: bool = True, +) -> list[dict[str, Any]]: + """Fetch albums and their shared links concurrently. + + Returns a list of album data dicts with keys: name, asset_count, id, + public_url, and ``_album`` (the raw album object for callers that need + asset-level access). + + When *include_failed* is True, albums that fail to fetch are included + with placeholder data (``"?"`` for counts). When False, they are + silently skipped. + """ + album_results = await asyncio.gather( + *[client.get_album(aid) for aid in album_ids], + return_exceptions=True, + ) + link_results = await asyncio.gather( + *[client.get_shared_links(aid) for aid in album_ids], + return_exceptions=True, + ) + + albums_data: list[dict[str, Any]] = [] + for album_id, result, links in zip(album_ids, album_results, link_results): + if isinstance(result, Exception): + _LOGGER.warning("Failed to fetch album %s: %s", album_id, result) + if include_failed: + albums_data.append({ + "name": f"{album_id[:8]}...", "asset_count": "?", + "id": album_id, "public_url": "", "_album": None, + }) + continue + if result: + pub_url = "" + if not isinstance(links, Exception) and ext_domain: + pub_url = get_public_url(ext_domain, links) or "" + albums_data.append({ + "name": result.name, "asset_count": result.asset_count, + "id": album_id, "public_url": pub_url, "_album": result, + }) + return albums_data + + def build_asset_dict( asset: Any, *, @@ -56,8 +105,14 @@ def _format_assets( assets: list[dict[str, Any]], cmd: str, query: str, locale: str, response_mode: str, client: Any, cmd_templates: dict[str, dict[str, str]], -) -> str | list[dict[str, Any]]: - """Format asset results as text or media payload.""" +) -> str | dict[str, Any]: + """Format asset results as text or a text-plus-media payload. + + Returns: + str: rendered text when *response_mode* is ``"text"`` (or no assets). + dict: ``{"text": ..., "media": [...]}`` when *response_mode* is + ``"media"`` and assets are present. + """ if not assets: return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query}) @@ -68,7 +123,7 @@ def _format_assets( }) if response_mode == "media": - media_items = [] + media_items: list[dict[str, Any]] = [] for asset in assets: asset_id = asset.get("id", "") media_items.append({ diff --git a/packages/server/src/notify_bridge_server/commands/immich/events.py b/packages/server/src/notify_bridge_server/commands/immich/events.py index 9d983b0..640bb95 100644 --- a/packages/server/src/notify_bridge_server/commands/immich/events.py +++ b/packages/server/src/notify_bridge_server/commands/immich/events.py @@ -13,23 +13,22 @@ from sqlmodel.ext.asyncio.session import AsyncSession from ...database.engine import get_engine from ...database.models import ( - EventLog, NotificationTarget, NotificationTrackerTarget, - ServiceProvider, TelegramBot, TrackingConfig, + EventLog, NotificationTracker, NotificationTrackerTarget, + ServiceProvider, TrackingConfig, ) -from notify_bridge_core.providers.immich.asset_utils import get_public_url -from ..handler import _get_notification_trackers_for_providers, _render_cmd_template -from .common import _format_assets, build_asset_dict +from ..command_utils import get_trackers_for_provider +from ..handler import _render_cmd_template +from .common import _format_assets, build_asset_dict, fetch_albums_with_links _LOGGER = logging.getLogger(__name__) async def _cmd_events( - bot: TelegramBot, providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, count: int, locale: str, ) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) + trackers = await get_trackers_for_provider(provider.id) tracker_ids = [t.id for t in trackers] if not tracker_ids: return {"events": []} @@ -57,32 +56,21 @@ async def cmd_latest( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], external_domain: str = "", -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /latest command with concurrent album fetching.""" album_ids = all_album_ids[:10] if not album_ids: return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates) - album_results = await asyncio.gather( - *[client.get_album(aid) for aid in album_ids], - return_exceptions=True, - ) - link_results = await asyncio.gather( - *[client.get_shared_links(aid) for aid in album_ids], - return_exceptions=True, - ) ext = external_domain.rstrip("/") + fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False) latest_assets: list[dict[str, Any]] = [] - for album_id, result, links in zip(album_ids, album_results, link_results): - if isinstance(result, Exception): - _LOGGER.warning("Failed to fetch album %s: %s", album_id, result) - continue - if result: - pub_url = "" - if not isinstance(links, Exception) and ext: - pub_url = get_public_url(ext, links) or "" - for aid, asset in list(result.assets.items())[:count]: + for album_data in fetched: + pub_url = album_data.get("public_url", "") + album_obj = album_data.get("_album") + if album_obj: + for aid, asset in list(album_obj.assets.items())[:count]: asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else "" latest_assets.append(build_asset_dict(asset, public_url=asset_pub)) @@ -95,32 +83,21 @@ async def cmd_random( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], external_domain: str = "", -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /random command with concurrent album fetching.""" album_ids = all_album_ids[:10] if not album_ids: return _format_assets([], "random", "", locale, response_mode, client, cmd_templates) - album_results = await asyncio.gather( - *[client.get_album(aid) for aid in album_ids], - return_exceptions=True, - ) - link_results = await asyncio.gather( - *[client.get_shared_links(aid) for aid in album_ids], - return_exceptions=True, - ) ext = external_domain.rstrip("/") + fetched = await fetch_albums_with_links(client, album_ids, ext, include_failed=False) random_assets: list[dict[str, Any]] = [] - for album_id, result, links in zip(album_ids, album_results, link_results): - if isinstance(result, Exception): - _LOGGER.warning("Failed to fetch album %s: %s", album_id, result) - continue - if result: - pub_url = "" - if not isinstance(links, Exception) and ext: - pub_url = get_public_url(ext, links) or "" - asset_list = list(result.assets.values()) + for album_data in fetched: + pub_url = album_data.get("public_url", "") + album_obj = album_data.get("_album") + if album_obj: + asset_list = list(album_obj.assets.values()) sampled = rng.sample(asset_list, min(count, len(asset_list))) for asset in sampled: asset_pub = f"{pub_url}/photos/{asset.id}" if pub_url else "" @@ -130,40 +107,40 @@ async def cmd_random( return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates) -async def _check_native_memory(bot: TelegramBot) -> bool: - """Check if any tracker-target linked to this bot uses native memory source.""" +async def _check_native_memory(provider_id: int) -> bool: + """Check if any notification tracker for this provider uses native memory source.""" engine = get_engine() async with AsyncSession(engine) as session: - result = await session.exec( - select(NotificationTarget).where( - NotificationTarget.type == "telegram", - NotificationTarget.user_id == bot.user_id, + tracker_result = await session.exec( + select(NotificationTracker).where( + NotificationTracker.provider_id == provider_id, ) ) - targets = result.all() - bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token} - if not bot_target_ids: + trackers = tracker_result.all() + tracker_ids = [t.id for t in trackers] + if not tracker_ids: return False tt_result = await session.exec( select(NotificationTrackerTarget).where( - NotificationTrackerTarget.target_id.in_(bot_target_ids) + NotificationTrackerTarget.tracker_id.in_(tracker_ids) ) ) - for tt in tt_result.all(): - if tt.tracking_config_id: - tc = await session.get(TrackingConfig, tt.tracking_config_id) - if tc and tc.memory_source == "native": - return True - return False + tc_ids = list({tt.tracking_config_id for tt in tt_result.all() if tt.tracking_config_id}) + if not tc_ids: + return False + tc_result = await session.exec( + select(TrackingConfig).where(TrackingConfig.id.in_(tc_ids)) + ) + return any(tc.memory_source == "native" for tc in tc_result.all()) async def cmd_memory( - bot: TelegramBot, client: Any, all_album_ids: list[str], count: int, + provider_id: int, client: Any, all_album_ids: list[str], count: int, locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /memory command with concurrent album fetching.""" - use_native = await _check_native_memory(bot) + use_native = await _check_native_memory(provider_id) today = datetime.now(timezone.utc) memory_assets: list[dict[str, Any]] = [] diff --git a/packages/server/src/notify_bridge_server/commands/immich/handler.py b/packages/server/src/notify_bridge_server/commands/immich/handler.py index a916ee4..24f8398 100644 --- a/packages/server/src/notify_bridge_server/commands/immich/handler.py +++ b/packages/server/src/notify_bridge_server/commands/immich/handler.py @@ -2,26 +2,21 @@ from __future__ import annotations -import asyncio import logging from typing import Any -import aiohttp -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from ...database.engine import get_engine from ...database.models import ( - CommandConfig, CommandTracker, EventLog, + CommandConfig, CommandTracker, ServiceProvider, TelegramBot, ) from ...services import make_immich_provider -from ..base import ProviderCommandHandler -from ..handler import _get_notification_trackers_for_providers, _render_cmd_template -from notify_bridge_core.providers.immich.asset_utils import get_public_url +from ...services.http_session import get_http_session +from ..base import CommandResponse, ProviderCommandHandler +from ..command_utils import get_last_event_str, get_trackers_for_provider +from ..handler import _render_cmd_template from .albums import _cmd_albums, cmd_favorites, cmd_summary -from .common import _IMMICH_COMMANDS +from .common import _IMMICH_COMMANDS, fetch_albums_with_links from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random from .search import cmd_find, cmd_person, cmd_place, cmd_search @@ -29,21 +24,15 @@ _LOGGER = logging.getLogger(__name__) async def _cmd_status( - bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str, + provider: ServiceProvider, locale: str, ) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) + trackers = await get_trackers_for_provider(provider.id) active = sum(1 for t in trackers if t.enabled) total = len(trackers) total_albums = sum(len(t.collection_ids or []) for t in trackers) - engine = get_engine() - async with AsyncSession(engine) as session: - result = await session.exec( - select(EventLog).order_by(EventLog.created_at.desc()).limit(1) - ) - last_event = result.first() - last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-" + tracker_ids = [t.id for t in trackers] + last_str = await get_last_event_str(tracker_ids) return { "trackers_active": active, "trackers_total": total, @@ -52,16 +41,13 @@ async def _cmd_status( async def _cmd_people( - providers_map: dict[int, ServiceProvider], locale: str, + provider: ServiceProvider, locale: str, ) -> dict[str, Any]: all_people: dict[str, str] = {} - async with aiohttp.ClientSession() as http: - for provider in providers_map.values(): - if provider.type != "immich": - continue - immich = make_immich_provider(http, provider) - people = await immich.client.get_people() - all_people.update(people) + http = await get_http_session() + immich = make_immich_provider(http, provider) + people = await immich.client.get_people() + all_people.update(people) names = sorted(all_people.values()) return {"people": names} @@ -87,106 +73,92 @@ class ImmichCommandHandler(ProviderCommandHandler): count: int, locale: str, response_mode: str, - providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], bot: TelegramBot, - ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], - ) -> str | list[dict[str, Any]] | None: + tracker: CommandTracker, + config: CommandConfig, + ) -> CommandResponse | None: if cmd == "status": - ctx = await _cmd_status(bot, providers_map, locale) - return _render_cmd_template(cmd_templates, "status", locale, ctx) + ctx = await _cmd_status(provider, locale) + return CommandResponse(text=_render_cmd_template(cmd_templates, "status", locale, ctx)) if cmd == "albums": - ctx = await _cmd_albums(bot, providers_map, locale) - return _render_cmd_template(cmd_templates, "albums", locale, ctx) + ctx = await _cmd_albums(provider, locale) + return CommandResponse(text=_render_cmd_template(cmd_templates, "albums", locale, ctx)) if cmd == "events": - ctx = await _cmd_events(bot, providers_map, count, locale) - return _render_cmd_template(cmd_templates, "events", locale, ctx) + ctx = await _cmd_events(provider, count, locale) + return CommandResponse(text=_render_cmd_template(cmd_templates, "events", locale, ctx)) if cmd == "people": - ctx = await _cmd_people(providers_map, locale) - return _render_cmd_template(cmd_templates, "people", locale, ctx) + ctx = await _cmd_people(provider, locale) + return CommandResponse(text=_render_cmd_template(cmd_templates, "people", locale, ctx)) if cmd in ("search", "find", "person", "place", "latest", "random", "favorites", "summary", "memory"): return await _cmd_immich( - bot, cmd, args, count, locale, response_mode, - providers_map, cmd_templates, + cmd, args, count, locale, response_mode, + provider, cmd_templates, ) return None async def _cmd_immich( - bot: TelegramBot, cmd: str, args: str, count: int, locale: str, - response_mode: str, providers_map: dict[int, ServiceProvider], + cmd: str, args: str, count: int, locale: str, + response_mode: str, provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], -) -> str | list[dict[str, Any]]: +) -> CommandResponse | None: """Handle commands that need Immich API access and may return media.""" - if not providers_map: - return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args}) - - provider_ids = set(providers_map.keys()) - notification_trackers = await _get_notification_trackers_for_providers(provider_ids) + notification_trackers = await get_trackers_for_provider(provider.id) all_album_ids: list[str] = [] for t in notification_trackers: all_album_ids.extend(t.collection_ids or []) - provider: ServiceProvider | None = None - for p in providers_map.values(): - if p.type == "immich": - provider = p - break - if not provider: - return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args}) - ext_domain = (provider.config.get("external_domain") or provider.config.get("url", "")).rstrip("/") - async with aiohttp.ClientSession() as http: - immich = make_immich_provider(http, provider) - client = immich.client + http = await get_http_session() + immich = make_immich_provider(http, provider) + client = immich.client - # Build asset_id → public_url map from tracked albums' shared links - asset_public_urls: dict[str, str] = {} - if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"): - link_results = await asyncio.gather( - *[client.get_shared_links(aid) for aid in all_album_ids], - return_exceptions=True, - ) - album_results = await asyncio.gather( - *[client.get_album(aid) for aid in all_album_ids], - return_exceptions=True, - ) - for album_id, links, album in zip(all_album_ids, link_results, album_results): - if isinstance(links, Exception) or isinstance(album, Exception): - continue - pub_url = get_public_url(ext_domain, links) - if pub_url and album: - for asset_id in album.assets: - asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}" + # Build asset_id → public_url map from tracked albums' shared links + asset_public_urls: dict[str, str] = {} + if ext_domain and all_album_ids and cmd in ("search", "find", "person", "place", "favorites"): + fetched = await fetch_albums_with_links(client, all_album_ids, ext_domain, include_failed=False) + for album_data in fetched: + pub_url = album_data.get("public_url", "") + album_obj = album_data.get("_album") + if pub_url and album_obj: + for asset_id in album_obj.assets: + asset_public_urls[asset_id] = f"{pub_url}/photos/{asset_id}" - if cmd == "search": - return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + # Wrap single-provider in a map for functions that still expect it + providers_map = {provider.id: provider} - if cmd == "find": - return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + result: str | dict[str, Any] | None = None - if cmd == "person": - return await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + if cmd == "search": + result = await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + elif cmd == "find": + result = await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + elif cmd == "person": + result = await cmd_person(client, args, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + elif cmd == "place": + result = await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) + elif cmd == "favorites": + result = await cmd_favorites(providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates) + elif cmd == "latest": + result = await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain) + elif cmd == "random": + result = await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain) + elif cmd == "summary": + result = await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain) + elif cmd == "memory": + result = await cmd_memory(provider.id, client, all_album_ids, count, locale, response_mode, cmd_templates) - if cmd == "place": - return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates, asset_public_urls=asset_public_urls) - - if cmd == "favorites": - return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates) - - if cmd == "latest": - return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain) - - if cmd == "random": - return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates, external_domain=ext_domain) - - if cmd == "summary": - return await cmd_summary(client, all_album_ids, locale, cmd_templates, external_domain=ext_domain) - - if cmd == "memory": - return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates) - - return None + if result is None: + return None + # _format_assets returns {"text": ..., "media": [...]} for media mode + if isinstance(result, dict): + return CommandResponse( + text=result.get("text"), + media=result.get("media", []), + ) + return CommandResponse(text=result) diff --git a/packages/server/src/notify_bridge_server/commands/immich/search.py b/packages/server/src/notify_bridge_server/commands/immich/search.py index 7881ff2..e3e6573 100644 --- a/packages/server/src/notify_bridge_server/commands/immich/search.py +++ b/packages/server/src/notify_bridge_server/commands/immich/search.py @@ -9,14 +9,15 @@ from .common import _format_assets def _enrich_assets(assets: list[dict[str, Any]], asset_public_urls: dict[str, str]) -> list[dict[str, Any]]: - """Add public_url to assets from the pre-built map.""" + """Add public_url to assets from the pre-built map. Returns new list without mutating inputs.""" if not asset_public_urls: return assets - for asset in assets: - aid = asset.get("id", "") - if aid and aid in asset_public_urls and not asset.get("public_url"): - asset["public_url"] = asset_public_urls[aid] - return assets + return [ + {**asset, "public_url": asset_public_urls.get(asset.get("id", ""), "")} + if asset.get("id", "") in asset_public_urls and not asset.get("public_url") + else asset + for asset in assets + ] async def cmd_search( @@ -24,7 +25,7 @@ async def cmd_search( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], asset_public_urls: dict[str, str] | None = None, -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /search command.""" if not args: return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""}) @@ -38,7 +39,7 @@ async def cmd_find( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], asset_public_urls: dict[str, str] | None = None, -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /find command.""" if not args: return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""}) @@ -52,7 +53,7 @@ async def cmd_person( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], asset_public_urls: dict[str, str] | None = None, -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /person command.""" if not args: return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""}) @@ -74,7 +75,7 @@ async def cmd_place( locale: str, response_mode: str, cmd_templates: dict[str, dict[str, str]], asset_public_urls: dict[str, str] | None = None, -) -> str | list[dict[str, Any]]: +) -> str | dict[str, Any]: """Handle /place command.""" if not args: return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""}) diff --git a/packages/server/src/notify_bridge_server/commands/nut_handler.py b/packages/server/src/notify_bridge_server/commands/nut_handler.py index d8a579c..e0bfa4a 100644 --- a/packages/server/src/notify_bridge_server/commands/nut_handler.py +++ b/packages/server/src/notify_bridge_server/commands/nut_handler.py @@ -3,17 +3,31 @@ from __future__ import annotations import logging +from collections.abc import Callable, Coroutine from typing import Any from ..database.models import CommandConfig, CommandTracker, ServiceProvider, TelegramBot from ..services import make_nut_provider -from .base import ProviderCommandHandler +from .base import CommandResponse, ProviderCommandHandler from .handler import _render_cmd_template _LOGGER = logging.getLogger(__name__) _NUT_COMMANDS = {"status", "devices", "battery"} +# --------------------------------------------------------------------------- +# Command dispatch table +# --------------------------------------------------------------------------- + +_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {} + + +def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]: + """Register a function in the text command dispatch table.""" + name = fn.__name__.removeprefix("_cmd_") + _TEXT_COMMANDS[name] = fn + return fn + class NutCommandHandler(ProviderCommandHandler): """Handles NUT-specific bot commands.""" @@ -33,80 +47,73 @@ class NutCommandHandler(ProviderCommandHandler): count: int, locale: str, response_mode: str, - providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], bot: TelegramBot, - ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], - ) -> str | list[dict[str, Any]] | None: - if cmd == "status": - ctx = await _cmd_status(providers_map) - return _render_cmd_template(cmd_templates, "status", locale, ctx) - if cmd == "devices": - ctx = await _cmd_devices(providers_map) - return _render_cmd_template(cmd_templates, "devices", locale, ctx) - if cmd == "battery": - ctx = await _cmd_battery(providers_map) - return _render_cmd_template(cmd_templates, "battery", locale, ctx) - return None + tracker: CommandTracker, + config: CommandConfig, + ) -> CommandResponse | None: + fn = _TEXT_COMMANDS.get(cmd) + if fn is None: + return None + ctx = await fn(provider, count) + return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx)) -async def _query_all_ups( - providers_map: dict[int, ServiceProvider], +async def _query_ups( + provider: ServiceProvider, ) -> list[dict[str, Any]]: - """Connect to all NUT providers and query UPS data.""" + """Connect to a NUT provider and query UPS data.""" from notify_bridge_core.providers.nut.models import NutUpsData results: list[dict[str, Any]] = [] - for provider in providers_map.values(): - if provider.type != "nut": - continue - nut = make_nut_provider(provider) + nut = make_nut_provider(provider) + try: + client = nut._make_client() + await client.connect() try: - client = nut._make_client() - await client.connect() - try: - devices = await client.list_ups() - for dev in devices: - variables = await client.list_var(dev.name) - data = NutUpsData.from_variables(dev.name, variables) - results.append({ - "name": data.name, - "description": data.description, - "model": data.model, - "manufacturer": data.manufacturer, - "status": data.status, - "battery_charge": int(data.battery_charge) if data.battery_charge is not None else None, - "battery_runtime": data.battery_runtime_formatted, - "ups_load": int(data.ups_load) if data.ups_load is not None else None, - "input_voltage": str(data.input_voltage) if data.input_voltage is not None else None, - "output_voltage": str(data.output_voltage) if data.output_voltage is not None else None, - }) - finally: - await client.disconnect() - except Exception as exc: - _LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc) + devices = await client.list_ups() + for dev in devices: + variables = await client.list_var(dev.name) + data = NutUpsData.from_variables(dev.name, variables) + results.append({ + "name": data.name, + "description": data.description, + "model": data.model, + "manufacturer": data.manufacturer, + "status": data.status, + "battery_charge": int(data.battery_charge) if data.battery_charge is not None else None, + "battery_runtime": data.battery_runtime_formatted, + "ups_load": int(data.ups_load) if data.ups_load is not None else None, + "input_voltage": str(data.input_voltage) if data.input_voltage is not None else None, + "output_voltage": str(data.output_voltage) if data.output_voltage is not None else None, + }) + finally: + await client.disconnect() + except Exception as exc: + _LOGGER.warning("Failed to query NUT provider %s: %s", provider.name, exc) return results -async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - devices = await _query_all_ups(providers_map) +@_text_cmd +async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]: + devices = await _query_ups(provider) return {"devices": devices} -async def _cmd_devices(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: +@_text_cmd +async def _cmd_devices(provider: ServiceProvider, count: int) -> dict[str, Any]: devices: list[dict[str, Any]] = [] - for provider in providers_map.values(): - if provider.type != "nut": - continue - nut = make_nut_provider(provider) - try: - device_list = await nut.list_collections() - devices.extend(device_list) - except Exception as exc: - _LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc) + nut = make_nut_provider(provider) + try: + device_list = await nut.list_collections() + devices.extend(device_list) + except Exception as exc: + _LOGGER.warning("Failed to list devices from %s: %s", provider.name, exc) return {"devices": devices} -async def _cmd_battery(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - devices = await _query_all_ups(providers_map) +@_text_cmd +async def _cmd_battery(provider: ServiceProvider, count: int) -> dict[str, Any]: + devices = await _query_ups(provider) return {"devices": devices} diff --git a/packages/server/src/notify_bridge_server/commands/planka_handler.py b/packages/server/src/notify_bridge_server/commands/planka_handler.py index ad10f8c..ae00565 100644 --- a/packages/server/src/notify_bridge_server/commands/planka_handler.py +++ b/packages/server/src/notify_bridge_server/commands/planka_handler.py @@ -3,26 +3,47 @@ from __future__ import annotations import logging +from collections.abc import Callable, Coroutine from typing import Any -import aiohttp -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from ..database.engine import get_engine from ..database.models import ( - CommandConfig, CommandTracker, EventLog, - NotificationTracker, ServiceProvider, TelegramBot, + CommandConfig, CommandTracker, ServiceProvider, TelegramBot, ) from ..services import make_planka_provider -from .base import ProviderCommandHandler -from .handler import _render_cmd_template, _get_notification_trackers_for_providers +from ..services.http_session import get_http_session +from .base import CommandResponse, ProviderCommandHandler +from .command_utils import get_last_event_str, get_tracked_collection_ids, get_trackers_for_provider +from .handler import _render_cmd_template _LOGGER = logging.getLogger(__name__) _PLANKA_COMMANDS = {"status", "boards", "cards", "lists"} +def _get_tracked_board_ids( + provider: ServiceProvider, + trackers: list, +) -> list[str]: + """Get board IDs from tracked collection_ids for this provider.""" + if not provider.config.get("api_key"): + return [] + return get_tracked_collection_ids(provider, trackers) + + +# --------------------------------------------------------------------------- +# Command dispatch table +# --------------------------------------------------------------------------- + +_TEXT_COMMANDS: dict[str, Callable[..., Coroutine[Any, Any, dict[str, Any]]]] = {} + + +def _text_cmd(fn: Callable[..., Coroutine[Any, Any, dict[str, Any]]]) -> Callable[..., Coroutine[Any, Any, dict[str, Any]]]: + """Register a function in the text command dispatch table.""" + name = fn.__name__.removeprefix("_cmd_") + _TEXT_COMMANDS[name] = fn + return fn + + class PlankaCommandHandler(ProviderCommandHandler): """Handles Planka-specific bot commands.""" @@ -43,69 +64,26 @@ class PlankaCommandHandler(ProviderCommandHandler): count: int, locale: str, response_mode: str, - providers_map: dict[int, ServiceProvider], + provider: ServiceProvider, cmd_templates: dict[str, dict[str, str]], bot: TelegramBot, - ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]], - ) -> str | list[dict[str, Any]] | None: - if cmd == "status": - ctx = await _cmd_status(providers_map) - return _render_cmd_template(cmd_templates, "status", locale, ctx) - if cmd == "boards": - ctx = await _cmd_boards(providers_map) - return _render_cmd_template(cmd_templates, "boards", locale, ctx) - if cmd == "cards": - ctx = await _cmd_cards(providers_map, count) - return _render_cmd_template(cmd_templates, "cards", locale, ctx) - if cmd == "lists": - ctx = await _cmd_lists(providers_map) - return _render_cmd_template(cmd_templates, "lists", locale, ctx) - return None + tracker: CommandTracker, + config: CommandConfig, + ) -> CommandResponse | None: + fn = _TEXT_COMMANDS.get(cmd) + if fn is None: + return None + ctx = await fn(provider, count) + return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx)) -def _get_tracked_board_ids( - providers_map: dict[int, ServiceProvider], - trackers: list[NotificationTracker], -) -> list[tuple[ServiceProvider, str]]: - """Get (provider, board_id) tuples from tracked collection_ids.""" - boards: list[tuple[ServiceProvider, str]] = [] - for tracker in trackers: - provider = providers_map.get(tracker.provider_id) - if not provider or provider.type != "planka": - continue - if not provider.config.get("api_key"): - continue - for board_id in (tracker.collection_ids or []): - entry = (provider, board_id) - if entry not in boards: - boards.append(entry) - # Also check filters.collections - for board_id in (tracker.filters or {}).get("collections", []): - entry = (provider, board_id) - if entry not in boards: - boards.append(entry) - return boards[:20] +@_text_cmd +async def _cmd_status(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_boards = _get_tracked_board_ids(provider, trackers) - -async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_boards = _get_tracked_board_ids(providers_map, trackers) - - # Last event - engine = get_engine() - async with AsyncSession(engine) as session: - tracker_ids = [t.id for t in trackers] - if tracker_ids: - result = await session.exec( - select(EventLog) - .where(EventLog.tracker_id.in_(tracker_ids)) - .order_by(EventLog.created_at.desc()).limit(1) - ) - last_event = result.first() - else: - last_event = None - last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-" + tracker_ids = [t.id for t in trackers] + last_str = await get_last_event_str(tracker_ids) return { "boards_count": len(tracked_boards), @@ -113,81 +91,69 @@ async def _cmd_status(providers_map: dict[int, ServiceProvider]) -> dict[str, An } -async def _cmd_boards(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_boards = _get_tracked_board_ids(providers_map, trackers) +@_text_cmd +async def _cmd_boards(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_boards = _get_tracked_board_ids(provider, trackers) boards_data: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, board_id in tracked_boards: - planka = make_planka_provider(http, provider) - all_boards = await planka.client.get_boards() - for b in all_boards: - if str(b.get("id", "")) == board_id: - boards_data.append({"name": b.get("name", board_id)}) - break - else: - boards_data.append({"name": board_id}) + http = await get_http_session() + planka = make_planka_provider(http, provider) + all_boards = await planka.client.get_boards() + board_names = {str(b.get("id", "")): b.get("name", "") for b in all_boards} + for board_id in tracked_boards: + boards_data.append({"name": board_names.get(board_id, board_id)}) return {"boards": boards_data} -async def _cmd_cards( - providers_map: dict[int, ServiceProvider], count: int, -) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_boards = _get_tracked_board_ids(providers_map, trackers) +@_text_cmd +async def _cmd_cards(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_boards = _get_tracked_board_ids(provider, trackers) all_cards: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, board_id in tracked_boards: - planka = make_planka_provider(http, provider) - cards = await planka.client.get_board_cards(board_id, limit=count) - lists = await planka.client.get_board_lists(board_id) - lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists} + http = await get_http_session() + planka = make_planka_provider(http, provider) + boards = await planka.client.get_boards() + board_names = {str(b.get("id", "")): b.get("name", "") for b in boards} - boards = await planka.client.get_boards() - board_name = board_id - for b in boards: - if str(b.get("id", "")) == board_id: - board_name = b.get("name", board_id) - break + for board_id in tracked_boards: + cards = await planka.client.get_board_cards(board_id, limit=count) + lists = await planka.client.get_board_lists(board_id) + lists_by_id = {str(lst.get("id", "")): lst.get("name", "") for lst in lists} + board_name = board_names.get(board_id, board_id) - for card in cards: - list_id = str(card.get("listId", "")) - all_cards.append({ - "name": card.get("name", ""), - "list_name": lists_by_id.get(list_id, ""), - "board_name": board_name, - }) + for card in cards: + list_id = str(card.get("listId", "")) + all_cards.append({ + "name": card.get("name", ""), + "list_name": lists_by_id.get(list_id, ""), + "board_name": board_name, + }) return {"cards": all_cards[:count]} -async def _cmd_lists(providers_map: dict[int, ServiceProvider]) -> dict[str, Any]: - provider_ids = set(providers_map.keys()) - trackers = await _get_notification_trackers_for_providers(provider_ids) - tracked_boards = _get_tracked_board_ids(providers_map, trackers) +@_text_cmd +async def _cmd_lists(provider: ServiceProvider, count: int) -> dict[str, Any]: + trackers = await get_trackers_for_provider(provider.id) + tracked_boards = _get_tracked_board_ids(provider, trackers) all_lists: list[dict[str, Any]] = [] - async with aiohttp.ClientSession() as http: - for provider, board_id in tracked_boards: - planka = make_planka_provider(http, provider) - lists = await planka.client.get_board_lists(board_id) + http = await get_http_session() + planka = make_planka_provider(http, provider) + boards = await planka.client.get_boards() + board_names = {str(b.get("id", "")): b.get("name", "") for b in boards} - boards = await planka.client.get_boards() - board_name = board_id - for b in boards: - if str(b.get("id", "")) == board_id: - board_name = b.get("name", board_id) - break + for board_id in tracked_boards: + lists = await planka.client.get_board_lists(board_id) + board_name = board_names.get(board_id, board_id) - for lst in lists: - all_lists.append({ - "name": lst.get("name", ""), - "board_name": board_name, - }) + for lst in lists: + all_lists.append({ + "name": lst.get("name", ""), + "board_name": board_name, + }) return {"lists": all_lists} diff --git a/packages/server/src/notify_bridge_server/commands/webhook.py b/packages/server/src/notify_bridge_server/commands/webhook.py index f7e514f..da33a29 100644 --- a/packages/server/src/notify_bridge_server/commands/webhook.py +++ b/packages/server/src/notify_bridge_server/commands/webhook.py @@ -6,7 +6,6 @@ import hmac import logging from typing import Any -import aiohttp from fastapi import APIRouter, Depends, Header, HTTPException, Request from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -16,6 +15,7 @@ from notify_bridge_core.notifications.telegram.client import TelegramClient from ..database.engine import get_session from ..database.models import TelegramBot, TelegramChat from ..services.telegram import save_chat_from_webhook +from .base import CommandResponse from .handler import handle_command, send_media_group, send_reply _LOGGER = logging.getLogger(__name__) @@ -89,15 +89,13 @@ async def telegram_webhook( return {"ok": True, "skipped": "commands_disabled"} effective_lang = chat_row.language_override or msg_language message_id = message.get("message_id") - cmd_response = await handle_command(bot, chat_id, text, language_code=effective_lang) - if cmd_response is not None: - if isinstance(cmd_response, dict) and "media" in cmd_response: - await send_reply(bot.token, chat_id, cmd_response["text"], reply_to_message_id=message_id) - await send_media_group(bot.token, chat_id, cmd_response["media"], reply_to_message_id=message_id) - elif isinstance(cmd_response, list): - await send_media_group(bot.token, chat_id, cmd_response, reply_to_message_id=message_id) - else: - await send_reply(bot.token, chat_id, cmd_response, reply_to_message_id=message_id) + responses = await handle_command(bot, chat_id, text, language_code=effective_lang) + if responses: + for resp in responses: + if resp.text: + await send_reply(bot.token, chat_id, resp.text, reply_to_message_id=message_id) + if resp.media: + await send_media_group(bot.token, chat_id, resp.media, reply_to_message_id=message_id) return {"ok": True} return {"ok": True, "skipped": "not_a_command"} @@ -105,13 +103,15 @@ async def telegram_webhook( async def register_webhook(bot_token: str, webhook_url: str, secret: str | None = None) -> dict: """Register webhook URL with Telegram Bot API via TelegramClient.""" - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, bot_token) - return await client.set_webhook(webhook_url, secret=secret) + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, bot_token) + return await client.set_webhook(webhook_url, secret=secret) async def unregister_webhook(bot_token: str) -> dict: """Remove webhook from Telegram Bot API via TelegramClient.""" - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, bot_token) - return await client.delete_webhook() + from ..services.http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, bot_token) + return await client.delete_webhook() diff --git a/packages/server/src/notify_bridge_server/database/models.py b/packages/server/src/notify_bridge_server/database/models.py index 4be0b6b..767f8d0 100644 --- a/packages/server/src/notify_bridge_server/database/models.py +++ b/packages/server/src/notify_bridge_server/database/models.py @@ -359,6 +359,7 @@ class NotificationTrackerState(SQLModel, table=True): # Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id tracker_id: int = Field( foreign_key="notification_tracker.id", + index=True, sa_column_kwargs={"name": "notification_tracker_id"}, ) collection_id: str @@ -458,7 +459,7 @@ class CommandTrackerListener(SQLModel, table=True): id: int | None = Field(default=None, primary_key=True) command_tracker_id: int = Field( foreign_key="command_tracker.id", - + index=True, ) listener_type: str # e.g. "telegram_bot" diff --git a/packages/server/src/notify_bridge_server/main.py b/packages/server/src/notify_bridge_server/main.py index 02254f4..c1c4d43 100644 --- a/packages/server/src/notify_bridge_server/main.py +++ b/packages/server/src/notify_bridge_server/main.py @@ -73,6 +73,8 @@ async def lifespan(app: FastAPI): await start_scheduler() yield # Graceful shutdown + from .services.http_session import close_http_session + await close_http_session() scheduler = get_scheduler() if scheduler.running: scheduler.shutdown() diff --git a/packages/server/src/notify_bridge_server/services/__init__.py b/packages/server/src/notify_bridge_server/services/__init__.py index 3f46259..3bf5a9b 100644 --- a/packages/server/src/notify_bridge_server/services/__init__.py +++ b/packages/server/src/notify_bridge_server/services/__init__.py @@ -1,5 +1,11 @@ """Shared service utilities.""" +from __future__ import annotations + +from typing import Any, Protocol + +import aiohttp + from notify_bridge_core.providers.immich import ImmichServiceProvider from notify_bridge_core.providers.gitea import GiteaServiceProvider from notify_bridge_core.providers.planka import PlankaServiceProvider @@ -8,8 +14,23 @@ from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvid from ..database.models import ServiceProvider +# Default timeout for all outgoing HTTP requests to external services. +DEFAULT_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30) -def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServiceProvider: + +class CollectionProvider(Protocol): + """Protocol for providers that can list collections.""" + + async def list_collections(self) -> list[dict[str, Any]]: ... + + +class TestableProvider(Protocol): + """Protocol for providers that support connection testing.""" + + async def test_connection(self) -> dict[str, Any]: ... + + +def make_immich_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> ImmichServiceProvider: """Create an ImmichServiceProvider from a DB provider model.""" config = provider.config or {} return ImmichServiceProvider( @@ -21,7 +42,7 @@ def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServi ) -def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaServiceProvider: +def make_gitea_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GiteaServiceProvider: """Create a GiteaServiceProvider from a DB provider model.""" config = provider.config or {} return GiteaServiceProvider( @@ -32,7 +53,7 @@ def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaService ) -def make_planka_provider(http_session, provider: ServiceProvider) -> PlankaServiceProvider: +def make_planka_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> PlankaServiceProvider: """Create a PlankaServiceProvider from a DB provider model.""" config = provider.config or {} return PlankaServiceProvider( @@ -55,7 +76,7 @@ def make_nut_provider(provider: ServiceProvider) -> NutServiceProvider: ) -def make_google_photos_provider(http_session, provider: ServiceProvider) -> GooglePhotosServiceProvider: +def make_google_photos_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GooglePhotosServiceProvider: """Create a GooglePhotosServiceProvider from a DB provider model.""" config = provider.config or {} return GooglePhotosServiceProvider( @@ -65,3 +86,61 @@ def make_google_photos_provider(http_session, provider: ServiceProvider) -> Goog config.get("refresh_token", ""), provider.name, ) + + +# --------------------------------------------------------------------------- +# Provider factory registry — maps provider type strings to factory callables +# that create a provider with a ``list_collections`` method. Providers that +# require an API credential skip creation when the credential is missing +# (the factory returns None in that case). +# --------------------------------------------------------------------------- + +def _make_collection_provider( + http_session: aiohttp.ClientSession, + provider: ServiceProvider, +) -> CollectionProvider | None: + """Create a CollectionProvider for the given DB provider, or None if unsupported.""" + ptype = provider.type + config = provider.config or {} + + if ptype == "immich": + return make_immich_provider(http_session, provider) + if ptype == "gitea": + if not config.get("api_token"): + return None + return make_gitea_provider(http_session, provider) + if ptype == "planka": + if not config.get("api_key"): + return None + return make_planka_provider(http_session, provider) + if ptype == "google_photos": + return make_google_photos_provider(http_session, provider) + # NUT provider needs no http_session + if ptype == "nut": + return make_nut_provider(provider) # type: ignore[return-value] + return None + + +# Set of provider types that need an aiohttp session for collection listing. +_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"} + + +async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]: + """List collections for any supported provider type. + + Returns an empty list for providers that don't support collections or + are missing required credentials. + """ + if provider.type in _HTTP_COLLECTION_PROVIDERS: + from .http_session import get_http_session + http_session = await get_http_session() + svc = _make_collection_provider(http_session, provider) + if svc is None: + return [] + return await svc.list_collections() + + # Non-HTTP providers (e.g. NUT) + svc = _make_collection_provider(None, provider) # type: ignore[arg-type] + if svc is None: + return [] + return await svc.list_collections() diff --git a/packages/server/src/notify_bridge_server/services/action_runner.py b/packages/server/src/notify_bridge_server/services/action_runner.py index 97abfb5..c6acbfd 100644 --- a/packages/server/src/notify_bridge_server/services/action_runner.py +++ b/packages/server/src/notify_bridge_server/services/action_runner.py @@ -6,7 +6,6 @@ import logging from datetime import datetime, timezone from typing import Any -import aiohttp from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -159,27 +158,28 @@ async def _execute_with_provider( ) from notify_bridge_core.providers.immich.client import ImmichClient - async with aiohttp.ClientSession() as http_session: - client = ImmichClient( - http_session, - provider_config.get("url", ""), - provider_config.get("api_key", ""), + from .http_session import get_http_session + http_session = await get_http_session() + client = ImmichClient( + http_session, + provider_config.get("url", ""), + provider_config.get("api_key", ""), + ) + external_domain = provider_config.get("external_domain") + if external_domain: + client.external_domain = external_domain + + # Verify connectivity + if not await client.ping(): + return ActionResult( + success=False, + error=f"Cannot connect to Immich server ({provider_name})", ) - external_domain = provider_config.get("external_domain") - if external_domain: - client.external_domain = external_domain - # Verify connectivity - if not await client.ping(): - return ActionResult( - success=False, - error=f"Cannot connect to Immich server ({provider_name})", - ) - - executor = ImmichActionExecutor(client) - if dry_run: - return await executor.dry_run(action_type, rule_configs, action_config) - return await executor.execute(action_type, rule_configs, action_config) + executor = ImmichActionExecutor(client) + if dry_run: + return await executor.dry_run(action_type, rule_configs, action_config) + return await executor.execute(action_type, rule_configs, action_config) return ActionResult( success=False, diff --git a/packages/server/src/notify_bridge_server/services/http_session.py b/packages/server/src/notify_bridge_server/services/http_session.py new file mode 100644 index 0000000..bf0021d --- /dev/null +++ b/packages/server/src/notify_bridge_server/services/http_session.py @@ -0,0 +1,33 @@ +"""Application-level shared aiohttp.ClientSession. + +All outgoing HTTP requests in the server package should use the shared +session returned by ``get_http_session()`` instead of creating +per-request ``aiohttp.ClientSession`` instances. This keeps a single +TCP connection pool alive for the lifetime of the process, avoiding +the overhead of pool creation/teardown on every request. + +Call ``close_http_session()`` once during application shutdown. +""" + +from __future__ import annotations + +import aiohttp + +_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30) +_session: aiohttp.ClientSession | None = None + + +async def get_http_session() -> aiohttp.ClientSession: + """Get or create the shared HTTP session.""" + global _session + if _session is None or _session.closed: + _session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT) + return _session + + +async def close_http_session() -> None: + """Close the shared HTTP session (call on app shutdown).""" + global _session + if _session is not None and not _session.closed: + await _session.close() + _session = None diff --git a/packages/server/src/notify_bridge_server/services/notifier.py b/packages/server/src/notify_bridge_server/services/notifier.py index 557cdae..7856c50 100644 --- a/packages/server/src/notify_bridge_server/services/notifier.py +++ b/packages/server/src/notify_bridge_server/services/notifier.py @@ -3,8 +3,6 @@ import logging from typing import Any -import aiohttp - from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -90,19 +88,21 @@ async def _send_telegram_broadcast(target: NotificationTarget, message: str, rec if not receivers: return {"success": False, "error": "No receivers configured"} + from .http_session import get_http_session + http = await get_http_session() + results: list[dict] = [] - async with aiohttp.ClientSession() as session: - client = TelegramClient(session, bot_token) - for recv in receivers: - chat_id = recv.get("chat_id") - if not chat_id: - continue - result = await client.send_message( - chat_id=str(chat_id), - text=message, - disable_web_page_preview=bool(disable_preview), - ) - results.append(result) + client = TelegramClient(http, bot_token) + for recv in receivers: + chat_id = recv.get("chat_id") + if not chat_id: + continue + result = await client.send_message( + chat_id=str(chat_id), + text=message, + disable_web_page_preview=bool(disable_preview), + ) + results.append(result) return _aggregate(results) @@ -113,15 +113,17 @@ async def _send_webhook_broadcast(target: NotificationTarget, message: str, rece if not receivers: return {"success": False, "error": "No receivers configured"} + from .http_session import get_http_session + http = await get_http_session() + results: list[dict] = [] - async with aiohttp.ClientSession() as session: - for recv in receivers: - url = recv.get("url") - headers = recv.get("headers", {}) - if not url: - continue - client = WebhookClient(session, url, headers) - results.append(await client.send({"message": message, "event_type": "notification"})) + for recv in receivers: + url = recv.get("url") + headers = recv.get("headers", {}) + if not url: + continue + client = WebhookClient(http, url, headers) + results.append(await client.send({"message": message, "event_type": "notification"})) return _aggregate(results) @@ -178,22 +180,24 @@ async def _send_webhook_like_broadcast(target: NotificationTarget, message: str, if not receivers: return {"success": False, "error": "No receivers configured"} + from .http_session import get_http_session + http = await get_http_session() + results: list[dict] = [] - async with aiohttp.ClientSession() as session: - if target.type == "discord": - from notify_bridge_core.notifications.discord.client import DiscordClient - client = DiscordClient(session) - for recv in receivers: - url = recv.get("webhook_url") - if url: - results.append(await client.send(url, message, username=target.config.get("username"))) - elif target.type == "slack": - from notify_bridge_core.notifications.slack.client import SlackClient - client = SlackClient(session) - for recv in receivers: - url = recv.get("webhook_url") - if url: - results.append(await client.send(url, message, username=target.config.get("username"))) + if target.type == "discord": + from notify_bridge_core.notifications.discord.client import DiscordClient + client = DiscordClient(http) + for recv in receivers: + url = recv.get("webhook_url") + if url: + results.append(await client.send(url, message, username=target.config.get("username"))) + elif target.type == "slack": + from notify_bridge_core.notifications.slack.client import SlackClient + client = SlackClient(http) + for recv in receivers: + url = recv.get("webhook_url") + if url: + results.append(await client.send(url, message, username=target.config.get("username"))) return _aggregate(results) @@ -207,18 +211,20 @@ async def _send_ntfy_broadcast(target: NotificationTarget, message: str, receive return {"success": False, "error": "No receivers configured"} from notify_bridge_core.notifications.ntfy.client import NtfyClient + from .http_session import get_http_session + http = await get_http_session() + results: list[dict] = [] - async with aiohttp.ClientSession() as session: - client = NtfyClient(session) - for recv in receivers: - topic = recv.get("topic") - if topic: - results.append(await client.send( - server_url, topic, message, - title="Notify Bridge", - priority=recv.get("priority", 3), - auth_token=auth_token, - )) + client = NtfyClient(http) + for recv in receivers: + topic = recv.get("topic") + if topic: + results.append(await client.send( + server_url, topic, message, + title="Notify Bridge", + priority=recv.get("priority", 3), + auth_token=auth_token, + )) return _aggregate(results) @@ -243,13 +249,15 @@ async def _send_matrix_broadcast(target: NotificationTarget, message: str, recei if not receivers: return {"success": False, "error": "No receivers configured"} + from .http_session import get_http_session + http = await get_http_session() + results: list[dict] = [] - async with aiohttp.ClientSession() as http: - client = MatrixClient(http, homeserver, access_token) - for recv in receivers: - room_id = recv.get("room_id") - if room_id: - results.append(await client.send_message(room_id, message, html_message=message)) + client = MatrixClient(http, homeserver, access_token) + for recv in receivers: + room_id = recv.get("room_id") + if room_id: + results.append(await client.send_message(room_id, message, html_message=message)) return _aggregate(results) diff --git a/packages/server/src/notify_bridge_server/services/scheduler.py b/packages/server/src/notify_bridge_server/services/scheduler.py index 33937cd..876a1e0 100644 --- a/packages/server/src/notify_bridge_server/services/scheduler.py +++ b/packages/server/src/notify_bridge_server/services/scheduler.py @@ -31,11 +31,50 @@ async def start_scheduler() -> None: from .telegram_poller import start_command_listener_polling await start_command_listener_polling() + # Schedule daily cleanup of old event log entries + _schedule_event_cleanup() + # Start debounced command auto-sync scheduler from .command_sync import start_sync_scheduler start_sync_scheduler() +def _schedule_event_cleanup() -> None: + """Schedule a daily job to delete EventLog entries older than 90 days.""" + from apscheduler.triggers.cron import CronTrigger + + scheduler = get_scheduler() + job_id = "cleanup_old_events" + if scheduler.get_job(job_id): + return + scheduler.add_job( + _cleanup_old_events, + CronTrigger(hour=3, minute=0), + id=job_id, + replace_existing=True, + max_instances=1, + ) + _LOGGER.info("Scheduled daily event log cleanup at 03:00 UTC") + + +async def _cleanup_old_events() -> None: + """Delete EventLog entries older than 90 days.""" + from datetime import datetime, timedelta, timezone + + from sqlmodel import delete + from sqlmodel.ext.asyncio.session import AsyncSession + + from ..database.engine import get_engine + from ..database.models import EventLog + + cutoff = datetime.now(timezone.utc) - timedelta(days=90) + engine = get_engine() + async with AsyncSession(engine) as session: + await session.exec(delete(EventLog).where(EventLog.created_at < cutoff)) + await session.commit() + _LOGGER.info("Cleaned up event log entries older than %s", cutoff.date()) + + async def _load_tracker_jobs() -> None: """Load enabled trackers and schedule polling jobs.""" from sqlmodel import select @@ -50,13 +89,16 @@ async def _load_tracker_jobs() -> None: result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True)) trackers = result.all() - # Pre-load provider types for scheduler detection + # Batch-load provider types for scheduler detection + unique_provider_ids = list({t.provider_id for t in trackers}) provider_types: dict[int, str] = {} - for tracker in trackers: - if tracker.provider_id not in provider_types: - provider = await session.get(ServiceProviderModel, tracker.provider_id) - if provider: - provider_types[tracker.provider_id] = provider.type + if unique_provider_ids: + provider_result = await session.exec( + select(ServiceProviderModel).where( + ServiceProviderModel.id.in_(unique_provider_ids) + ) + ) + provider_types = {p.id: p.type for p in provider_result.all()} for tracker in trackers: job_id = f"tracker_{tracker.id}" @@ -86,6 +128,7 @@ async def _load_tracker_jobs() -> None: id=job_id, args=[tracker.id], replace_existing=True, + max_instances=1, ) _LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval) @@ -106,6 +149,7 @@ def _add_cron_job( id=job_id, args=[tracker_id], replace_existing=True, + max_instances=1, ) _LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression) diff --git a/packages/server/src/notify_bridge_server/services/telegram_poller.py b/packages/server/src/notify_bridge_server/services/telegram_poller.py index 83e01af..51127b0 100644 --- a/packages/server/src/notify_bridge_server/services/telegram_poller.py +++ b/packages/server/src/notify_bridge_server/services/telegram_poller.py @@ -13,7 +13,6 @@ from __future__ import annotations import logging from typing import Any -import aiohttp from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -47,10 +46,18 @@ async def _get_bot_ids_with_active_listeners() -> set[int]: listeners = result.all() active_bot_ids: set[int] = set() - for listener in listeners: - tracker = await session.get(CommandTracker, listener.command_tracker_id) - if tracker and tracker.enabled: - active_bot_ids.add(listener.listener_id) + tracker_ids = list({l.command_tracker_id for l in listeners}) + if tracker_ids: + tracker_result = await session.exec( + select(CommandTracker).where( + CommandTracker.id.in_(tracker_ids), + CommandTracker.enabled == True, # noqa: E712 + ) + ) + enabled_tracker_ids = {t.id for t in tracker_result.all()} + for listener in listeners: + if listener.command_tracker_id in enabled_tracker_ids: + active_bot_ids.add(listener.listener_id) return active_bot_ids @@ -145,21 +152,23 @@ async def _poll_bot(bot_id: int) -> None: if not bot or bot.update_mode != "polling": unschedule_bot_polling(bot_id) return - # Extract what we need before closing session + # Copy attributes before session closes to avoid detached-instance errors + from types import SimpleNamespace bot_token = bot.token - bot_obj = bot + bot_obj = SimpleNamespace(id=bot.id, name=bot.name, token=bot.token) offset = _last_update_id.get(bot_id, 0) try: - async with aiohttp.ClientSession() as http: - client = TelegramClient(http, bot_token) - result = await client.get_updates( - offset=offset + 1 if offset else None, limit=50, - ) - if not result.get("success"): - return - updates = result.get("result", []) + from .http_session import get_http_session + http = await get_http_session() + client = TelegramClient(http, bot_token) + result = await client.get_updates( + offset=offset + 1 if offset else None, limit=50, + ) + if not result.get("success"): + return + updates = result.get("result", []) except Exception as e: _LOGGER.debug("Polling error for bot %d: %s", bot_id, e) return @@ -209,17 +218,13 @@ async def _poll_bot(bot_id: int) -> None: continue effective_lang = chat_row.language_override or msg_language message_id = message.get("message_id") - cmd_response = await handle_command(bot_obj, chat_id, text, language_code=effective_lang) - if cmd_response is not None: - if isinstance(cmd_response, dict) and "media" in cmd_response: - # Text + media: send text first, media as reply - from ..commands.handler import send_reply as _reply - await _reply(bot_token, chat_id, cmd_response["text"], reply_to_message_id=message_id) - await send_media_group(bot_token, chat_id, cmd_response["media"], reply_to_message_id=message_id) - elif isinstance(cmd_response, list): - await send_media_group(bot_token, chat_id, cmd_response, reply_to_message_id=message_id) - else: - await send_reply(bot_token, chat_id, cmd_response, reply_to_message_id=message_id) + responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang) + if responses: + for resp in responses: + if resp.text: + await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id) + if resp.media: + await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id) except Exception: _LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True) diff --git a/packages/server/src/notify_bridge_server/services/test_dispatch.py b/packages/server/src/notify_bridge_server/services/test_dispatch.py index e310a17..6360ff9 100644 --- a/packages/server/src/notify_bridge_server/services/test_dispatch.py +++ b/packages/server/src/notify_bridge_server/services/test_dispatch.py @@ -7,7 +7,6 @@ objects and dispatches through the same path the watcher uses. import logging from typing import Any -import aiohttp from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -183,58 +182,59 @@ async def _build_immich_event( memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums" is_memory = test_type == "memory" - async with aiohttp.ClientSession() as http_session: - immich = ImmichServiceProvider( - http_session, - provider_config.get("url", ""), - provider_config.get("api_key", ""), - provider_config.get("external_domain"), - provider_name, - ) - if not await immich.connect(): - return None + from .http_session import get_http_session + http_session = await get_http_session() + immich = ImmichServiceProvider( + http_session, + provider_config.get("url", ""), + provider_config.get("api_key", ""), + provider_config.get("external_domain"), + provider_name, + ) + if not await immich.connect(): + return None - # Native Immich memories API path - if is_memory and memory_source == "native": - return await _build_native_memory_event( - immich, ext_domain, provider_name, tracker_name, - collection_ids, limit, asset_type, favorite_only, min_rating, - ) - - # Album-based path: use shared collect_scheduled_assets - albums: dict[str, ImmichAlbumData] = {} - shared_links: dict[str, list[SharedLinkInfo]] = {} - for album_id in collection_ids: - album = await immich.client.get_album(album_id) - if album: - albums[album_id] = album - shared_links[album_id] = await immich.client.get_shared_links(album_id) - - assets, collections_extra = collect_scheduled_assets( - albums, shared_links, ext_domain, - limit=limit, - asset_type=asset_type, - favorite_only=favorite_only, - min_rating=min_rating, - is_memory=is_memory, + # Native Immich memories API path + if is_memory and memory_source == "native": + return await _build_native_memory_event( + immich, ext_domain, provider_name, tracker_name, + collection_ids, limit, asset_type, favorite_only, min_rating, ) - first_col = collections_extra[0] if collections_extra else {} - return ServiceEvent( - event_type=EventType.SCHEDULED_MESSAGE, - provider_type=ServiceProviderType.IMMICH, - provider_name=provider_name, - collection_id=collection_ids[0] if collection_ids else "", - collection_name=first_col.get("name", tracker_name), - timestamp=datetime.now(timezone.utc), - added_assets=assets, - added_count=len(assets), - extra={ - "collections": collections_extra, - "albums": collections_extra, - **(first_col if first_col else {}), - }, - ) + # Album-based path: use shared collect_scheduled_assets + albums: dict[str, ImmichAlbumData] = {} + shared_links: dict[str, list[SharedLinkInfo]] = {} + for album_id in collection_ids: + album = await immich.client.get_album(album_id) + if album: + albums[album_id] = album + shared_links[album_id] = await immich.client.get_shared_links(album_id) + + assets, collections_extra = collect_scheduled_assets( + albums, shared_links, ext_domain, + limit=limit, + asset_type=asset_type, + favorite_only=favorite_only, + min_rating=min_rating, + is_memory=is_memory, + ) + + first_col = collections_extra[0] if collections_extra else {} + return ServiceEvent( + event_type=EventType.SCHEDULED_MESSAGE, + provider_type=ServiceProviderType.IMMICH, + provider_name=provider_name, + collection_id=collection_ids[0] if collection_ids else "", + collection_name=first_col.get("name", tracker_name), + timestamp=datetime.now(timezone.utc), + added_assets=assets, + added_count=len(assets), + extra={ + "collections": collections_extra, + "albums": collections_extra, + **(first_col if first_col else {}), + }, + ) async def _build_native_memory_event( diff --git a/packages/server/src/notify_bridge_server/services/watcher.py b/packages/server/src/notify_bridge_server/services/watcher.py index befdfab..23e1f9c 100644 --- a/packages/server/src/notify_bridge_server/services/watcher.py +++ b/packages/server/src/notify_bridge_server/services/watcher.py @@ -6,7 +6,6 @@ import asyncio import logging from typing import Any -import aiohttp from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -102,19 +101,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]: if provider_type == "immich": from notify_bridge_core.providers.immich import ImmichServiceProvider - async with aiohttp.ClientSession() as http_session: - immich = ImmichServiceProvider( - http_session, - provider_config.get("url", ""), - provider_config.get("api_key", ""), - provider_config.get("external_domain"), - provider_name, - ) - connected = await immich.connect() - if not connected: - return {"status": "error", "reason": "failed to connect to provider"} + from .http_session import get_http_session + http_session = await get_http_session() + immich = ImmichServiceProvider( + http_session, + provider_config.get("url", ""), + provider_config.get("api_key", ""), + provider_config.get("external_domain"), + provider_name, + ) + connected = await immich.connect() + if not connected: + return {"status": "error", "reason": "failed to connect to provider"} - events, new_state = await immich.poll(collection_ids, state_dict) + events, new_state = await immich.poll(collection_ids, state_dict) elif provider_type == "gitea": # Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint. # The scheduler still calls check_tracker but there's nothing to poll. @@ -143,18 +143,22 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]: events, new_state = await nut.poll(collection_ids, state_dict) elif provider_type == "google_photos": from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider - async with aiohttp.ClientSession() as http_session: - gp = GooglePhotosServiceProvider( - http_session, - provider_config.get("client_id", ""), - provider_config.get("client_secret", ""), - provider_config.get("refresh_token", ""), - provider_name, - ) - connected = await gp.connect() - if not connected: - return {"status": "error", "reason": "failed to connect to Google Photos"} - events, new_state = await gp.poll(collection_ids, state_dict) + from .http_session import get_http_session + http_session = await get_http_session() + gp = GooglePhotosServiceProvider( + http_session, + provider_config.get("client_id", ""), + provider_config.get("client_secret", ""), + provider_config.get("refresh_token", ""), + provider_name, + ) + connected = await gp.connect() + if not connected: + return {"status": "error", "reason": "failed to connect to Google Photos"} + events, new_state = await gp.poll(collection_ids, state_dict) + elif provider_type == "webhook": + # Webhook providers receive events via inbound HTTP; no polling needed. + return {"status": "ok", "events_detected": 0, "collections_checked": 0} else: return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}