refactor: comprehensive codebase review — security, performance, quality, UX
Security: - Fix NUT protocol command injection (validate names against safe regex) - Enable Jinja2 autoescape=True to prevent HTML injection via external data - Add WebhookProviderConfig validation model Performance: - Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions) - Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast) - asyncio.gather for Gitea commands and notification dispatcher - Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener - LRU cache for compiled Jinja2 templates - Daily EventLog cleanup job (90-day retention) - 30s HTTP timeout on all external calls - GROUP BY for target type counts (replaces 7 sequential queries) Code quality: - Extract get_owned_entity() helper (replaces 11 duplicate functions) - Extract slot_helpers.py (load_slots, save_slots, render_template_preview) - Extract command_utils.py (tracker lookup, last event, collection IDs) - Extract http_session.py (shared session lifecycle) - Provider connection validation dedup (3x → 1 helper) - Command dispatch tables replacing if/elif chains - Album+links fetch helper (fetch_albums_with_links) - Provider dispatch polymorphism (list_provider_collections) - Immutable _enrich_assets (no longer mutates in-place) - Fix _format_assets return type + handler unpacking Frontend: - Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru) - Mobile "More" nav panel with provider filter and search - Shared Button.svelte component (4 variants, 2 sizes) - Shared ErrorBanner.svelte component (8 pages updated) - SvelteKit goto() replacing window.location.href - Dashboard grid fixed for 4 cards, paginator opacity consistency Functionality: - max_instances=1 on scheduler jobs (prevents duplicate events) - Webhook provider in watcher (prevents error spam) - Fix stale SQLModel reference in poller - Gitea get_repo() direct API call
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
"""Shared service utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
from notify_bridge_core.providers.gitea import GiteaServiceProvider
|
||||
from notify_bridge_core.providers.planka import PlankaServiceProvider
|
||||
@@ -8,8 +14,23 @@ from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvid
|
||||
|
||||
from ..database.models import ServiceProvider
|
||||
|
||||
# Default timeout for all outgoing HTTP requests to external services.
|
||||
DEFAULT_HTTP_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
|
||||
def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServiceProvider:
|
||||
|
||||
class CollectionProvider(Protocol):
|
||||
"""Protocol for providers that can list collections."""
|
||||
|
||||
async def list_collections(self) -> list[dict[str, Any]]: ...
|
||||
|
||||
|
||||
class TestableProvider(Protocol):
|
||||
"""Protocol for providers that support connection testing."""
|
||||
|
||||
async def test_connection(self) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
def make_immich_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> ImmichServiceProvider:
|
||||
"""Create an ImmichServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return ImmichServiceProvider(
|
||||
@@ -21,7 +42,7 @@ def make_immich_provider(http_session, provider: ServiceProvider) -> ImmichServi
|
||||
)
|
||||
|
||||
|
||||
def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaServiceProvider:
|
||||
def make_gitea_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GiteaServiceProvider:
|
||||
"""Create a GiteaServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return GiteaServiceProvider(
|
||||
@@ -32,7 +53,7 @@ def make_gitea_provider(http_session, provider: ServiceProvider) -> GiteaService
|
||||
)
|
||||
|
||||
|
||||
def make_planka_provider(http_session, provider: ServiceProvider) -> PlankaServiceProvider:
|
||||
def make_planka_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> PlankaServiceProvider:
|
||||
"""Create a PlankaServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return PlankaServiceProvider(
|
||||
@@ -55,7 +76,7 @@ def make_nut_provider(provider: ServiceProvider) -> NutServiceProvider:
|
||||
)
|
||||
|
||||
|
||||
def make_google_photos_provider(http_session, provider: ServiceProvider) -> GooglePhotosServiceProvider:
|
||||
def make_google_photos_provider(http_session: aiohttp.ClientSession, provider: ServiceProvider) -> GooglePhotosServiceProvider:
|
||||
"""Create a GooglePhotosServiceProvider from a DB provider model."""
|
||||
config = provider.config or {}
|
||||
return GooglePhotosServiceProvider(
|
||||
@@ -65,3 +86,61 @@ def make_google_photos_provider(http_session, provider: ServiceProvider) -> Goog
|
||||
config.get("refresh_token", ""),
|
||||
provider.name,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider factory registry — maps provider type strings to factory callables
|
||||
# that create a provider with a ``list_collections`` method. Providers that
|
||||
# require an API credential skip creation when the credential is missing
|
||||
# (the factory returns None in that case).
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_collection_provider(
|
||||
http_session: aiohttp.ClientSession,
|
||||
provider: ServiceProvider,
|
||||
) -> CollectionProvider | None:
|
||||
"""Create a CollectionProvider for the given DB provider, or None if unsupported."""
|
||||
ptype = provider.type
|
||||
config = provider.config or {}
|
||||
|
||||
if ptype == "immich":
|
||||
return make_immich_provider(http_session, provider)
|
||||
if ptype == "gitea":
|
||||
if not config.get("api_token"):
|
||||
return None
|
||||
return make_gitea_provider(http_session, provider)
|
||||
if ptype == "planka":
|
||||
if not config.get("api_key"):
|
||||
return None
|
||||
return make_planka_provider(http_session, provider)
|
||||
if ptype == "google_photos":
|
||||
return make_google_photos_provider(http_session, provider)
|
||||
# NUT provider needs no http_session
|
||||
if ptype == "nut":
|
||||
return make_nut_provider(provider) # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
|
||||
# Set of provider types that need an aiohttp session for collection listing.
|
||||
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
|
||||
|
||||
|
||||
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
|
||||
"""List collections for any supported provider type.
|
||||
|
||||
Returns an empty list for providers that don't support collections or
|
||||
are missing required credentials.
|
||||
"""
|
||||
if provider.type in _HTTP_COLLECTION_PROVIDERS:
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
svc = _make_collection_provider(http_session, provider)
|
||||
if svc is None:
|
||||
return []
|
||||
return await svc.list_collections()
|
||||
|
||||
# Non-HTTP providers (e.g. NUT)
|
||||
svc = _make_collection_provider(None, provider) # type: ignore[arg-type]
|
||||
if svc is None:
|
||||
return []
|
||||
return await svc.list_collections()
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -159,27 +158,28 @@ async def _execute_with_provider(
|
||||
)
|
||||
from notify_bridge_core.providers.immich.client import ImmichClient
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
)
|
||||
external_domain = provider_config.get("external_domain")
|
||||
if external_domain:
|
||||
client.external_domain = external_domain
|
||||
|
||||
# Verify connectivity
|
||||
if not await client.ping():
|
||||
return ActionResult(
|
||||
success=False,
|
||||
error=f"Cannot connect to Immich server ({provider_name})",
|
||||
)
|
||||
external_domain = provider_config.get("external_domain")
|
||||
if external_domain:
|
||||
client.external_domain = external_domain
|
||||
|
||||
# Verify connectivity
|
||||
if not await client.ping():
|
||||
return ActionResult(
|
||||
success=False,
|
||||
error=f"Cannot connect to Immich server ({provider_name})",
|
||||
)
|
||||
|
||||
executor = ImmichActionExecutor(client)
|
||||
if dry_run:
|
||||
return await executor.dry_run(action_type, rule_configs, action_config)
|
||||
return await executor.execute(action_type, rule_configs, action_config)
|
||||
executor = ImmichActionExecutor(client)
|
||||
if dry_run:
|
||||
return await executor.dry_run(action_type, rule_configs, action_config)
|
||||
return await executor.execute(action_type, rule_configs, action_config)
|
||||
|
||||
return ActionResult(
|
||||
success=False,
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Application-level shared aiohttp.ClientSession.
|
||||
|
||||
All outgoing HTTP requests in the server package should use the shared
|
||||
session returned by ``get_http_session()`` instead of creating
|
||||
per-request ``aiohttp.ClientSession`` instances. This keeps a single
|
||||
TCP connection pool alive for the lifetime of the process, avoiding
|
||||
the overhead of pool creation/teardown on every request.
|
||||
|
||||
Call ``close_http_session()`` once during application shutdown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
_DEFAULT_TIMEOUT = aiohttp.ClientTimeout(total=30)
|
||||
_session: aiohttp.ClientSession | None = None
|
||||
|
||||
|
||||
async def get_http_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared HTTP session."""
|
||||
global _session
|
||||
if _session is None or _session.closed:
|
||||
_session = aiohttp.ClientSession(timeout=_DEFAULT_TIMEOUT)
|
||||
return _session
|
||||
|
||||
|
||||
async def close_http_session() -> None:
|
||||
"""Close the shared HTTP session (call on app shutdown)."""
|
||||
global _session
|
||||
if _session is not None and not _session.closed:
|
||||
await _session.close()
|
||||
_session = None
|
||||
@@ -3,8 +3,6 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -90,19 +88,21 @@ async def _send_telegram_broadcast(target: NotificationTarget, message: str, rec
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = TelegramClient(session, bot_token)
|
||||
for recv in receivers:
|
||||
chat_id = recv.get("chat_id")
|
||||
if not chat_id:
|
||||
continue
|
||||
result = await client.send_message(
|
||||
chat_id=str(chat_id),
|
||||
text=message,
|
||||
disable_web_page_preview=bool(disable_preview),
|
||||
)
|
||||
results.append(result)
|
||||
client = TelegramClient(http, bot_token)
|
||||
for recv in receivers:
|
||||
chat_id = recv.get("chat_id")
|
||||
if not chat_id:
|
||||
continue
|
||||
result = await client.send_message(
|
||||
chat_id=str(chat_id),
|
||||
text=message,
|
||||
disable_web_page_preview=bool(disable_preview),
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -113,15 +113,17 @@ async def _send_webhook_broadcast(target: NotificationTarget, message: str, rece
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for recv in receivers:
|
||||
url = recv.get("url")
|
||||
headers = recv.get("headers", {})
|
||||
if not url:
|
||||
continue
|
||||
client = WebhookClient(session, url, headers)
|
||||
results.append(await client.send({"message": message, "event_type": "notification"}))
|
||||
for recv in receivers:
|
||||
url = recv.get("url")
|
||||
headers = recv.get("headers", {})
|
||||
if not url:
|
||||
continue
|
||||
client = WebhookClient(http, url, headers)
|
||||
results.append(await client.send({"message": message, "event_type": "notification"}))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -178,22 +180,24 @@ async def _send_webhook_like_broadcast(target: NotificationTarget, message: str,
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if target.type == "discord":
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
client = DiscordClient(session)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
elif target.type == "slack":
|
||||
from notify_bridge_core.notifications.slack.client import SlackClient
|
||||
client = SlackClient(session)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
if target.type == "discord":
|
||||
from notify_bridge_core.notifications.discord.client import DiscordClient
|
||||
client = DiscordClient(http)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
elif target.type == "slack":
|
||||
from notify_bridge_core.notifications.slack.client import SlackClient
|
||||
client = SlackClient(http)
|
||||
for recv in receivers:
|
||||
url = recv.get("webhook_url")
|
||||
if url:
|
||||
results.append(await client.send(url, message, username=target.config.get("username")))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -207,18 +211,20 @@ async def _send_ntfy_broadcast(target: NotificationTarget, message: str, receive
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from notify_bridge_core.notifications.ntfy.client import NtfyClient
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = NtfyClient(session)
|
||||
for recv in receivers:
|
||||
topic = recv.get("topic")
|
||||
if topic:
|
||||
results.append(await client.send(
|
||||
server_url, topic, message,
|
||||
title="Notify Bridge",
|
||||
priority=recv.get("priority", 3),
|
||||
auth_token=auth_token,
|
||||
))
|
||||
client = NtfyClient(http)
|
||||
for recv in receivers:
|
||||
topic = recv.get("topic")
|
||||
if topic:
|
||||
results.append(await client.send(
|
||||
server_url, topic, message,
|
||||
title="Notify Bridge",
|
||||
priority=recv.get("priority", 3),
|
||||
auth_token=auth_token,
|
||||
))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
@@ -243,13 +249,15 @@ async def _send_matrix_broadcast(target: NotificationTarget, message: str, recei
|
||||
if not receivers:
|
||||
return {"success": False, "error": "No receivers configured"}
|
||||
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
|
||||
results: list[dict] = []
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = MatrixClient(http, homeserver, access_token)
|
||||
for recv in receivers:
|
||||
room_id = recv.get("room_id")
|
||||
if room_id:
|
||||
results.append(await client.send_message(room_id, message, html_message=message))
|
||||
client = MatrixClient(http, homeserver, access_token)
|
||||
for recv in receivers:
|
||||
room_id = recv.get("room_id")
|
||||
if room_id:
|
||||
results.append(await client.send_message(room_id, message, html_message=message))
|
||||
|
||||
return _aggregate(results)
|
||||
|
||||
|
||||
@@ -31,11 +31,50 @@ async def start_scheduler() -> None:
|
||||
from .telegram_poller import start_command_listener_polling
|
||||
await start_command_listener_polling()
|
||||
|
||||
# Schedule daily cleanup of old event log entries
|
||||
_schedule_event_cleanup()
|
||||
|
||||
# Start debounced command auto-sync scheduler
|
||||
from .command_sync import start_sync_scheduler
|
||||
start_sync_scheduler()
|
||||
|
||||
|
||||
def _schedule_event_cleanup() -> None:
|
||||
"""Schedule a daily job to delete EventLog entries older than 90 days."""
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
|
||||
scheduler = get_scheduler()
|
||||
job_id = "cleanup_old_events"
|
||||
if scheduler.get_job(job_id):
|
||||
return
|
||||
scheduler.add_job(
|
||||
_cleanup_old_events,
|
||||
CronTrigger(hour=3, minute=0),
|
||||
id=job_id,
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled daily event log cleanup at 03:00 UTC")
|
||||
|
||||
|
||||
async def _cleanup_old_events() -> None:
|
||||
"""Delete EventLog entries older than 90 days."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlmodel import delete
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import EventLog
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
await session.exec(delete(EventLog).where(EventLog.created_at < cutoff))
|
||||
await session.commit()
|
||||
_LOGGER.info("Cleaned up event log entries older than %s", cutoff.date())
|
||||
|
||||
|
||||
async def _load_tracker_jobs() -> None:
|
||||
"""Load enabled trackers and schedule polling jobs."""
|
||||
from sqlmodel import select
|
||||
@@ -50,13 +89,16 @@ async def _load_tracker_jobs() -> None:
|
||||
result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True))
|
||||
trackers = result.all()
|
||||
|
||||
# Pre-load provider types for scheduler detection
|
||||
# Batch-load provider types for scheduler detection
|
||||
unique_provider_ids = list({t.provider_id for t in trackers})
|
||||
provider_types: dict[int, str] = {}
|
||||
for tracker in trackers:
|
||||
if tracker.provider_id not in provider_types:
|
||||
provider = await session.get(ServiceProviderModel, tracker.provider_id)
|
||||
if provider:
|
||||
provider_types[tracker.provider_id] = provider.type
|
||||
if unique_provider_ids:
|
||||
provider_result = await session.exec(
|
||||
select(ServiceProviderModel).where(
|
||||
ServiceProviderModel.id.in_(unique_provider_ids)
|
||||
)
|
||||
)
|
||||
provider_types = {p.id: p.type for p in provider_result.all()}
|
||||
|
||||
for tracker in trackers:
|
||||
job_id = f"tracker_{tracker.id}"
|
||||
@@ -86,6 +128,7 @@ async def _load_tracker_jobs() -> None:
|
||||
id=job_id,
|
||||
args=[tracker.id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) every %ds", tracker.id, tracker.name, tracker.scan_interval)
|
||||
|
||||
@@ -106,6 +149,7 @@ def _add_cron_job(
|
||||
id=job_id,
|
||||
args=[tracker_id],
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_LOGGER.info("Scheduled tracker %d (%s) with cron: %s", tracker_id, tracker_name, cron_expression)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -47,10 +46,18 @@ async def _get_bot_ids_with_active_listeners() -> set[int]:
|
||||
listeners = result.all()
|
||||
|
||||
active_bot_ids: set[int] = set()
|
||||
for listener in listeners:
|
||||
tracker = await session.get(CommandTracker, listener.command_tracker_id)
|
||||
if tracker and tracker.enabled:
|
||||
active_bot_ids.add(listener.listener_id)
|
||||
tracker_ids = list({l.command_tracker_id for l in listeners})
|
||||
if tracker_ids:
|
||||
tracker_result = await session.exec(
|
||||
select(CommandTracker).where(
|
||||
CommandTracker.id.in_(tracker_ids),
|
||||
CommandTracker.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
enabled_tracker_ids = {t.id for t in tracker_result.all()}
|
||||
for listener in listeners:
|
||||
if listener.command_tracker_id in enabled_tracker_ids:
|
||||
active_bot_ids.add(listener.listener_id)
|
||||
|
||||
return active_bot_ids
|
||||
|
||||
@@ -145,21 +152,23 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
if not bot or bot.update_mode != "polling":
|
||||
unschedule_bot_polling(bot_id)
|
||||
return
|
||||
# Extract what we need before closing session
|
||||
# Copy attributes before session closes to avoid detached-instance errors
|
||||
from types import SimpleNamespace
|
||||
bot_token = bot.token
|
||||
bot_obj = bot
|
||||
bot_obj = SimpleNamespace(id=bot.id, name=bot.name, token=bot.token)
|
||||
|
||||
offset = _last_update_id.get(bot_id, 0)
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return
|
||||
updates = result.get("result", [])
|
||||
from .http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot_token)
|
||||
result = await client.get_updates(
|
||||
offset=offset + 1 if offset else None, limit=50,
|
||||
)
|
||||
if not result.get("success"):
|
||||
return
|
||||
updates = result.get("result", [])
|
||||
except Exception as e:
|
||||
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
|
||||
return
|
||||
@@ -209,17 +218,13 @@ async def _poll_bot(bot_id: int) -> None:
|
||||
continue
|
||||
effective_lang = chat_row.language_override or msg_language
|
||||
message_id = message.get("message_id")
|
||||
cmd_response = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if cmd_response is not None:
|
||||
if isinstance(cmd_response, dict) and "media" in cmd_response:
|
||||
# Text + media: send text first, media as reply
|
||||
from ..commands.handler import send_reply as _reply
|
||||
await _reply(bot_token, chat_id, cmd_response["text"], reply_to_message_id=message_id)
|
||||
await send_media_group(bot_token, chat_id, cmd_response["media"], reply_to_message_id=message_id)
|
||||
elif isinstance(cmd_response, list):
|
||||
await send_media_group(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
else:
|
||||
await send_reply(bot_token, chat_id, cmd_response, reply_to_message_id=message_id)
|
||||
responses = await handle_command(bot_obj, chat_id, text, language_code=effective_lang)
|
||||
if responses:
|
||||
for resp in responses:
|
||||
if resp.text:
|
||||
await send_reply(bot_token, chat_id, resp.text, reply_to_message_id=message_id)
|
||||
if resp.media:
|
||||
await send_media_group(bot_token, chat_id, resp.media, reply_to_message_id=message_id)
|
||||
except Exception:
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ objects and dispatches through the same path the watcher uses.
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -183,58 +182,59 @@ async def _build_immich_event(
|
||||
memory_source = getattr(tracking_config, "memory_source", "albums") if tracking_config else "albums"
|
||||
is_memory = test_type == "memory"
|
||||
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
if not await immich.connect():
|
||||
return None
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
if not await immich.connect():
|
||||
return None
|
||||
|
||||
# Native Immich memories API path
|
||||
if is_memory and memory_source == "native":
|
||||
return await _build_native_memory_event(
|
||||
immich, ext_domain, provider_name, tracker_name,
|
||||
collection_ids, limit, asset_type, favorite_only, min_rating,
|
||||
)
|
||||
|
||||
# Album-based path: use shared collect_scheduled_assets
|
||||
albums: dict[str, ImmichAlbumData] = {}
|
||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||
for album_id in collection_ids:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums[album_id] = album
|
||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
||||
|
||||
assets, collections_extra = collect_scheduled_assets(
|
||||
albums, shared_links, ext_domain,
|
||||
limit=limit,
|
||||
asset_type=asset_type,
|
||||
favorite_only=favorite_only,
|
||||
min_rating=min_rating,
|
||||
is_memory=is_memory,
|
||||
# Native Immich memories API path
|
||||
if is_memory and memory_source == "native":
|
||||
return await _build_native_memory_event(
|
||||
immich, ext_domain, provider_name, tracker_name,
|
||||
collection_ids, limit, asset_type, favorite_only, min_rating,
|
||||
)
|
||||
|
||||
first_col = collections_extra[0] if collections_extra else {}
|
||||
return ServiceEvent(
|
||||
event_type=EventType.SCHEDULED_MESSAGE,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name=provider_name,
|
||||
collection_id=collection_ids[0] if collection_ids else "",
|
||||
collection_name=first_col.get("name", tracker_name),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
added_assets=assets,
|
||||
added_count=len(assets),
|
||||
extra={
|
||||
"collections": collections_extra,
|
||||
"albums": collections_extra,
|
||||
**(first_col if first_col else {}),
|
||||
},
|
||||
)
|
||||
# Album-based path: use shared collect_scheduled_assets
|
||||
albums: dict[str, ImmichAlbumData] = {}
|
||||
shared_links: dict[str, list[SharedLinkInfo]] = {}
|
||||
for album_id in collection_ids:
|
||||
album = await immich.client.get_album(album_id)
|
||||
if album:
|
||||
albums[album_id] = album
|
||||
shared_links[album_id] = await immich.client.get_shared_links(album_id)
|
||||
|
||||
assets, collections_extra = collect_scheduled_assets(
|
||||
albums, shared_links, ext_domain,
|
||||
limit=limit,
|
||||
asset_type=asset_type,
|
||||
favorite_only=favorite_only,
|
||||
min_rating=min_rating,
|
||||
is_memory=is_memory,
|
||||
)
|
||||
|
||||
first_col = collections_extra[0] if collections_extra else {}
|
||||
return ServiceEvent(
|
||||
event_type=EventType.SCHEDULED_MESSAGE,
|
||||
provider_type=ServiceProviderType.IMMICH,
|
||||
provider_name=provider_name,
|
||||
collection_id=collection_ids[0] if collection_ids else "",
|
||||
collection_name=first_col.get("name", tracker_name),
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
added_assets=assets,
|
||||
added_count=len(assets),
|
||||
extra={
|
||||
"collections": collections_extra,
|
||||
"albums": collections_extra,
|
||||
**(first_col if first_col else {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _build_native_memory_event(
|
||||
|
||||
@@ -6,7 +6,6 @@ import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
@@ -102,19 +101,20 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
if provider_type == "immich":
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
connected = await immich.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to provider"}
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
immich = ImmichServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("url", ""),
|
||||
provider_config.get("api_key", ""),
|
||||
provider_config.get("external_domain"),
|
||||
provider_name,
|
||||
)
|
||||
connected = await immich.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to provider"}
|
||||
|
||||
events, new_state = await immich.poll(collection_ids, state_dict)
|
||||
events, new_state = await immich.poll(collection_ids, state_dict)
|
||||
elif provider_type == "gitea":
|
||||
# Gitea is webhook-based — events arrive via /api/webhooks/gitea endpoint.
|
||||
# The scheduler still calls check_tracker but there's nothing to poll.
|
||||
@@ -143,18 +143,22 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
events, new_state = await nut.poll(collection_ids, state_dict)
|
||||
elif provider_type == "google_photos":
|
||||
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("client_id", ""),
|
||||
provider_config.get("client_secret", ""),
|
||||
provider_config.get("refresh_token", ""),
|
||||
provider_name,
|
||||
)
|
||||
connected = await gp.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to Google Photos"}
|
||||
events, new_state = await gp.poll(collection_ids, state_dict)
|
||||
from .http_session import get_http_session
|
||||
http_session = await get_http_session()
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session,
|
||||
provider_config.get("client_id", ""),
|
||||
provider_config.get("client_secret", ""),
|
||||
provider_config.get("refresh_token", ""),
|
||||
provider_name,
|
||||
)
|
||||
connected = await gp.connect()
|
||||
if not connected:
|
||||
return {"status": "error", "reason": "failed to connect to Google Photos"}
|
||||
events, new_state = await gp.poll(collection_ids, state_dict)
|
||||
elif provider_type == "webhook":
|
||||
# Webhook providers receive events via inbound HTTP; no polling needed.
|
||||
return {"status": "ok", "events_detected": 0, "collections_checked": 0}
|
||||
else:
|
||||
return {"status": "error", "reason": f"unsupported provider type: {provider_type}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user