refactor: comprehensive codebase review — security, performance, quality, UX
Security: - Fix NUT protocol command injection (validate names against safe regex) - Enable Jinja2 autoescape=True to prevent HTML injection via external data - Add WebhookProviderConfig validation model Performance: - Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions) - Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast) - asyncio.gather for Gitea commands and notification dispatcher - Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener - LRU cache for compiled Jinja2 templates - Daily EventLog cleanup job (90-day retention) - 30s HTTP timeout on all external calls - GROUP BY for target type counts (replaces 7 sequential queries) Code quality: - Extract get_owned_entity() helper (replaces 11 duplicate functions) - Extract slot_helpers.py (load_slots, save_slots, render_template_preview) - Extract command_utils.py (tracker lookup, last event, collection IDs) - Extract http_session.py (shared session lifecycle) - Provider connection validation dedup (3x → 1 helper) - Command dispatch tables replacing if/elif chains - Album+links fetch helper (fetch_albums_with_links) - Provider dispatch polymorphism (list_provider_collections) - Immutable _enrich_assets (no longer mutates in-place) - Fix _format_assets return type + handler unpacking Frontend: - Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru) - Mobile "More" nav panel with provider filter and search - Shared Button.svelte component (4 variants, 2 sizes) - Shared ErrorBanner.svelte component (8 pages updated) - SvelteKit goto() replacing window.location.href - Dashboard grid fixed for 4 cards, paginator opacity consistency Functionality: - max_instances=1 on scheduler jobs (prevents duplicate events) - Webhook provider in watcher (prevents error spam) - Fix stale SQLModel reference in poller - Gitea get_repo() direct API call
This commit is contained in:
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import Action, ActionRule, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,10 +60,9 @@ def _rule_response(rule: ActionRule) -> dict:
|
||||
async def _get_user_action(
|
||||
session: AsyncSession, action_id: int, user: User
|
||||
) -> Action:
|
||||
action = await session.get(Action, action_id)
|
||||
if not action or action.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Action not found")
|
||||
return action
|
||||
return await get_owned_entity(
|
||||
session, Action, action_id, user.id, not_found_msg="Action not found",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -12,12 +12,10 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User
|
||||
from .slot_helpers import load_slots, render_template_preview, save_slots
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,38 +42,11 @@ class CommandTemplateConfigUpdate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
|
||||
"""Load slots as {slot_name: {locale: template}}."""
|
||||
result = await session.exec(
|
||||
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id)
|
||||
)
|
||||
nested: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
nested.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
return nested
|
||||
return await load_slots(session, CommandTemplateSlot, config_id)
|
||||
|
||||
|
||||
async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]) -> None:
|
||||
"""Save slots from {slot_name: {locale: template}} format."""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(CommandTemplateSlot).where(
|
||||
CommandTemplateSlot.config_id == config_id,
|
||||
CommandTemplateSlot.slot_name == slot_name,
|
||||
CommandTemplateSlot.locale == locale,
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(CommandTemplateSlot(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
await save_slots(session, CommandTemplateSlot, config_id, slots)
|
||||
|
||||
|
||||
async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]:
|
||||
@@ -367,18 +338,4 @@ async def preview_raw(
|
||||
"wait": 15,
|
||||
}
|
||||
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(body.template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {"rendered": None, "error": e.message, "error_line": e.lineno}
|
||||
|
||||
try:
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(body.template)
|
||||
rendered = tmpl.render(**sample_ctx)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
return render_template_preview(body.template, sample_ctx)
|
||||
|
||||
@@ -17,6 +17,7 @@ from ..database.models import (
|
||||
TelegramBot,
|
||||
User,
|
||||
)
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -401,7 +402,7 @@ async def _listener_response(session: AsyncSession, l: CommandTrackerListener) -
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> CommandTracker:
|
||||
tracker = await session.get(CommandTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Command tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, CommandTracker, tracker_id, user_id,
|
||||
not_found_msg="Command tracker not found",
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import EmailBot, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -156,7 +157,6 @@ def _response(bot: EmailBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot:
|
||||
bot = await session.get(EmailBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Email bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, EmailBot, bot_id, user_id, not_found_msg="Email bot not found",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Shared helpers for API route modules."""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
T = TypeVar("T", bound=SQLModel)
|
||||
|
||||
|
||||
async def get_owned_entity(
|
||||
session: AsyncSession,
|
||||
model: type[T],
|
||||
entity_id: int,
|
||||
user_id: int,
|
||||
*,
|
||||
owner_field: str = "user_id",
|
||||
not_found_msg: str = "Not found",
|
||||
) -> T:
|
||||
"""Fetch an entity by PK and verify ownership, or raise 404."""
|
||||
entity = await session.get(model, entity_id)
|
||||
if not entity or getattr(entity, owner_field) != user_id:
|
||||
raise HTTPException(status_code=404, detail=not_found_msg)
|
||||
return entity
|
||||
@@ -10,6 +10,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import MatrixBot, User
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -108,33 +109,34 @@ async def test_matrix_bot(
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as http:
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
whoami = await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": f"Connection failed: {e}"}
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
# Verify token with /whoami
|
||||
whoami_url = f"{bot.homeserver_url.rstrip('/')}/_matrix/client/v3/account/whoami"
|
||||
headers = {"Authorization": f"Bearer {bot.access_token}"}
|
||||
try:
|
||||
async with http.get(whoami_url, headers=headers) as resp:
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
return {"success": False, "error": f"Auth failed: HTTP {resp.status} — {body[:200]}"}
|
||||
whoami = await resp.json()
|
||||
except aiohttp.ClientError as e:
|
||||
return {"success": False, "error": f"Connection failed: {e}"}
|
||||
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
result = {"success": True, "user_id": whoami.get("user_id", "")}
|
||||
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
|
||||
send_result = await client.send_message(
|
||||
room_id,
|
||||
"Test message from Notify Bridge",
|
||||
html_message="<b>Test message</b> from Notify Bridge",
|
||||
)
|
||||
result["send_result"] = send_result
|
||||
# Optionally send a test message
|
||||
if room_id:
|
||||
from notify_bridge_core.notifications.matrix.client import MatrixClient
|
||||
client = MatrixClient(http, bot.homeserver_url, bot.access_token)
|
||||
send_result = await client.send_message(
|
||||
room_id,
|
||||
"Test message from Notify Bridge",
|
||||
html_message="<b>Test message</b> from Notify Bridge",
|
||||
)
|
||||
result["send_result"] = send_result
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
def _response(bot: MatrixBot) -> dict:
|
||||
@@ -150,7 +152,6 @@ def _response(bot: MatrixBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> MatrixBot:
|
||||
bot = await session.get(MatrixBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Matrix bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, MatrixBot, bot_id, user_id, not_found_msg="Matrix bot not found",
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ..database.models import (
|
||||
)
|
||||
from ..services.notifier import send_test_notification
|
||||
from ..services.test_dispatch import dispatch_test_notification
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -277,7 +278,7 @@ async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) ->
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, NotificationTracker, tracker_id, user_id,
|
||||
not_found_msg="Tracker not found",
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..database.models import (
|
||||
User,
|
||||
)
|
||||
from ..services.scheduler import schedule_tracker, unschedule_tracker
|
||||
from .helpers import get_owned_entity
|
||||
from .notification_tracker_targets import _tt_response
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -205,7 +206,7 @@ async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> di
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
return await get_owned_entity(
|
||||
session, NotificationTracker, tracker_id, user_id,
|
||||
not_found_msg="Tracker not found",
|
||||
)
|
||||
|
||||
@@ -13,7 +13,12 @@ import aiohttp
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import ServiceProvider, User
|
||||
from ..services import make_immich_provider, make_gitea_provider, make_planka_provider, make_nut_provider, make_google_photos_provider
|
||||
from ..services import (
|
||||
make_immich_provider, make_gitea_provider, make_planka_provider,
|
||||
make_nut_provider, make_google_photos_provider, list_provider_collections,
|
||||
)
|
||||
from ..services.http_session import get_http_session
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -82,6 +87,20 @@ class GooglePhotosProviderConfig(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class PayloadMapping(BaseModel):
|
||||
variable: str
|
||||
jsonpath: str
|
||||
default: str | None = None
|
||||
|
||||
|
||||
class WebhookProviderConfig(BaseModel):
|
||||
auth_mode: str = "none"
|
||||
webhook_secret: str | None = None
|
||||
payload_mappings: list[PayloadMapping] = []
|
||||
event_type_path: str | None = None
|
||||
collection_path: str | None = None
|
||||
|
||||
|
||||
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"immich": ImmichProviderConfig,
|
||||
"gitea": GiteaProviderConfig,
|
||||
@@ -89,6 +108,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"scheduler": SchedulerProviderConfig,
|
||||
"nut": NutProviderConfig,
|
||||
"google_photos": GooglePhotosProviderConfig,
|
||||
"webhook": WebhookProviderConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -106,6 +126,70 @@ def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> Non
|
||||
)
|
||||
|
||||
|
||||
async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""Test provider connection and return the result dict.
|
||||
|
||||
For providers that lack optional credentials (gitea without api_token,
|
||||
planka without api_key), returns a success stub.
|
||||
"""
|
||||
http_session = await get_http_session()
|
||||
|
||||
if provider.type == "immich":
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.test_connection()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.test_connection()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.test_connection()
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.test_connection()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.test_connection()
|
||||
|
||||
if provider.type in ("scheduler", "webhook"):
|
||||
return {"ok": True, "message": "Virtual provider — always available"}
|
||||
|
||||
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
|
||||
|
||||
|
||||
async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""Test provider connection. Raise HTTPException on failure.
|
||||
|
||||
Returns the test_result dict on success (caller may inspect extra fields
|
||||
like ``external_domain``).
|
||||
"""
|
||||
try:
|
||||
test_result = await _test_provider_connection(provider)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
except OSError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
|
||||
)
|
||||
return test_result
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_providers(
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -128,96 +212,15 @@ async def create_provider(
|
||||
"""Add a new service provider (validates connection for known types)."""
|
||||
_validate_provider_config(body.type, body.config)
|
||||
|
||||
# Validate connection for known provider types
|
||||
try:
|
||||
if body.type == "immich":
|
||||
from notify_bridge_core.providers.immich import ImmichServiceProvider
|
||||
config = body.config
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = ImmichServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_key", ""),
|
||||
config.get("external_domain"), body.name,
|
||||
)
|
||||
test_result = await immich.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {body.type} provider"),
|
||||
)
|
||||
# Store external_domain from server config if available
|
||||
if test_result.get("external_domain"):
|
||||
config["external_domain"] = test_result["external_domain"]
|
||||
# Build a temporary ServiceProvider for connection testing
|
||||
temp_provider = ServiceProvider(
|
||||
id=0, user_id=0, type=body.type, name=body.name, config=body.config,
|
||||
)
|
||||
test_result = await _validate_provider_connection(temp_provider)
|
||||
|
||||
elif body.type == "gitea":
|
||||
config = body.config
|
||||
# api_token is optional (webhook_secret is required, but token only for repo listing)
|
||||
if config.get("api_token"):
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.gitea import GiteaServiceProvider
|
||||
gitea = GiteaServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_token", ""), body.name,
|
||||
)
|
||||
test_result = await gitea.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Gitea"),
|
||||
)
|
||||
|
||||
elif body.type == "planka":
|
||||
config = body.config
|
||||
if config.get("api_key"):
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.planka import PlankaServiceProvider
|
||||
planka = PlankaServiceProvider(
|
||||
http_session, config.get("url", ""), config.get("api_key", ""), body.name,
|
||||
)
|
||||
test_result = await planka.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Planka"),
|
||||
)
|
||||
|
||||
elif body.type == "nut":
|
||||
nut = make_nut_provider(ServiceProvider(
|
||||
id=0, user_id=0, type="nut", name=body.name, config=body.config,
|
||||
))
|
||||
test_result = await nut.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to NUT server"),
|
||||
)
|
||||
|
||||
elif body.type == "google_photos":
|
||||
config = body.config
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
from notify_bridge_core.providers.google_photos import GooglePhotosServiceProvider
|
||||
gp = GooglePhotosServiceProvider(
|
||||
http_session, config.get("client_id", ""), config.get("client_secret", ""),
|
||||
config.get("refresh_token", ""), body.name,
|
||||
)
|
||||
test_result = await gp.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Google Photos"),
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
except OSError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
|
||||
# Scheduler: no validation needed (virtual provider)
|
||||
# Store external_domain from Immich server config if available
|
||||
if test_result.get("external_domain"):
|
||||
body.config["external_domain"] = test_result["external_domain"]
|
||||
|
||||
provider = ServiceProvider(
|
||||
user_id=user.id,
|
||||
@@ -307,78 +310,10 @@ async def update_provider(
|
||||
provider.config = body.config
|
||||
|
||||
# Re-validate connection when config changes for known provider types
|
||||
if config_changed and provider.type == "immich":
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
test_result = await immich.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
|
||||
)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "gitea":
|
||||
if provider.config.get("api_token"):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
test_result = await gitea.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Gitea"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "planka":
|
||||
if provider.config.get("api_key"):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
test_result = await planka.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Planka"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
elif config_changed and provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
test_result = await nut.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to NUT server"),
|
||||
)
|
||||
elif config_changed and provider.type == "google_photos":
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
test_result = await gp.test_connection()
|
||||
if not test_result.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=test_result.get("message", "Cannot connect to Google Photos"),
|
||||
)
|
||||
except aiohttp.ClientError as err:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Connection error: {err}",
|
||||
)
|
||||
if config_changed:
|
||||
test_result = await _validate_provider_connection(provider)
|
||||
if test_result.get("external_domain"):
|
||||
provider.config = {**provider.config, "external_domain": test_result["external_domain"]}
|
||||
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
@@ -408,39 +343,7 @@ async def test_provider(
|
||||
):
|
||||
"""Check if a service provider is reachable."""
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.test_connection()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.test_connection()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.test_connection()
|
||||
|
||||
if provider.type == "scheduler":
|
||||
return {"ok": True, "message": "Virtual provider — always available"}
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.test_connection()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.test_connection()
|
||||
|
||||
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
|
||||
return await _test_provider_connection(provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/people")
|
||||
@@ -454,14 +357,14 @@ async def list_people(
|
||||
|
||||
if provider.type == "immich":
|
||||
from notify_bridge_core.providers.immich.client import ImmichClient
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider.config.get("url", ""),
|
||||
provider.config.get("api_key", ""),
|
||||
)
|
||||
people = await client.get_people()
|
||||
return [{"id": pid, "name": name} for pid, name in people.items()]
|
||||
http_session = await get_http_session()
|
||||
client = ImmichClient(
|
||||
http_session,
|
||||
provider.config.get("url", ""),
|
||||
provider.config.get("api_key", ""),
|
||||
)
|
||||
people = await client.get_people()
|
||||
return [{"id": pid, "name": name} for pid, name in people.items()]
|
||||
|
||||
return []
|
||||
|
||||
@@ -475,35 +378,7 @@ async def list_collections(
|
||||
"""Fetch collections from a service provider."""
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
return await immich.list_collections()
|
||||
|
||||
if provider.type == "gitea":
|
||||
if not provider.config.get("api_token"):
|
||||
return []
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gitea = make_gitea_provider(http_session, provider)
|
||||
return await gitea.list_collections()
|
||||
|
||||
if provider.type == "planka":
|
||||
if not provider.config.get("api_key"):
|
||||
return []
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
planka = make_planka_provider(http_session, provider)
|
||||
return await planka.list_collections()
|
||||
|
||||
if provider.type == "nut":
|
||||
nut = make_nut_provider(provider)
|
||||
return await nut.list_collections()
|
||||
|
||||
if provider.type == "google_photos":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.list_collections()
|
||||
|
||||
return []
|
||||
return await list_provider_collections(provider)
|
||||
|
||||
|
||||
@router.get("/{provider_id}/albums/{album_id}/shared-links")
|
||||
@@ -517,19 +392,19 @@ async def get_album_shared_links(
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
links = await immich.client.get_shared_links(album_id)
|
||||
return [
|
||||
{
|
||||
"id": link.id,
|
||||
"key": link.key,
|
||||
"has_password": link.has_password,
|
||||
"is_expired": link.is_expired,
|
||||
"is_accessible": link.is_accessible,
|
||||
}
|
||||
for link in links
|
||||
]
|
||||
http_session = await get_http_session()
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
links = await immich.client.get_shared_links(album_id)
|
||||
return [
|
||||
{
|
||||
"id": link.id,
|
||||
"key": link.key,
|
||||
"has_password": link.has_password,
|
||||
"is_expired": link.is_expired,
|
||||
"is_accessible": link.is_accessible,
|
||||
}
|
||||
for link in links
|
||||
]
|
||||
|
||||
return []
|
||||
|
||||
@@ -545,15 +420,13 @@ async def create_album_shared_link(
|
||||
provider = await _get_user_provider(session, provider_id, user.id)
|
||||
|
||||
if provider.type == "immich":
|
||||
async with aiohttp.ClientSession() as http_session:
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
success = await immich.client.create_shared_link(album_id)
|
||||
if success:
|
||||
return {"success": True}
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail="Failed to create shared link")
|
||||
http_session = await get_http_session()
|
||||
immich = make_immich_provider(http_session, provider)
|
||||
success = await immich.client.create_shared_link(album_id)
|
||||
if success:
|
||||
return {"success": True}
|
||||
raise HTTPException(status_code=400, detail="Failed to create shared link")
|
||||
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail="Provider type does not support shared links")
|
||||
|
||||
|
||||
@@ -580,7 +453,7 @@ async def _get_user_provider(
|
||||
session: AsyncSession, provider_id: int, user_id: int
|
||||
) -> ServiceProvider:
|
||||
"""Get a provider owned by the user, or raise 404."""
|
||||
provider = await session.get(ServiceProvider, provider_id)
|
||||
if not provider or provider.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return provider
|
||||
return await get_owned_entity(
|
||||
session, ServiceProvider, provider_id, user_id,
|
||||
not_found_msg="Provider not found",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Shared slot load/save and Jinja2 preview helpers for template config APIs."""
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
S = TypeVar("S", bound=SQLModel)
|
||||
|
||||
|
||||
async def load_slots(
|
||||
session: AsyncSession,
|
||||
slot_model: type[S],
|
||||
config_id: int,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Load all template slots for a config as {slot_name: {locale: template}}.
|
||||
|
||||
Works for both TemplateSlot and CommandTemplateSlot — they share the same
|
||||
column names (config_id, slot_name, locale, template).
|
||||
"""
|
||||
result = await session.exec(
|
||||
select(slot_model).where(slot_model.config_id == config_id) # type: ignore[attr-defined]
|
||||
)
|
||||
slots: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
slots.setdefault(s.slot_name, {})[s.locale] = s.template # type: ignore[attr-defined]
|
||||
return slots
|
||||
|
||||
|
||||
async def save_slots(
|
||||
session: AsyncSession,
|
||||
slot_model: type[S],
|
||||
config_id: int,
|
||||
slots: dict[str, dict[str, str]],
|
||||
) -> None:
|
||||
"""Create or update template slots for a config (locale-aware).
|
||||
|
||||
Works for both TemplateSlot and CommandTemplateSlot.
|
||||
"""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(slot_model).where(
|
||||
slot_model.config_id == config_id, # type: ignore[attr-defined]
|
||||
slot_model.slot_name == slot_name, # type: ignore[attr-defined]
|
||||
slot_model.locale == locale, # type: ignore[attr-defined]
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text # type: ignore[attr-defined]
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(slot_model(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
|
||||
|
||||
def render_template_preview(template: str, context: dict) -> dict:
|
||||
"""Two-pass Jinja2 render: syntax check, then strict render.
|
||||
|
||||
Returns a dict with either ``{"rendered": str}`` on success, or
|
||||
``{"rendered": None, "error": str, ...}`` on failure.
|
||||
"""
|
||||
# Pass 1: syntax check (default Undefined — catches parse errors only)
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {
|
||||
"rendered": None,
|
||||
"error": e.message,
|
||||
"error_line": e.lineno,
|
||||
}
|
||||
|
||||
# Pass 2: render with StrictUndefined to catch unknown variables
|
||||
try:
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(template)
|
||||
rendered = tmpl.render(**context)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
@@ -112,8 +112,16 @@ async def get_nav_counts(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Return entity counts for sidebar navigation badges."""
|
||||
counts = {}
|
||||
"""Return entity counts for sidebar navigation badges.
|
||||
|
||||
Note: queries run sequentially because SQLAlchemy AsyncSession is NOT safe
|
||||
for concurrent use within a single session (no asyncio.gather). We
|
||||
minimise round-trips by combining user + system counts and per-type
|
||||
target counts into single aggregate queries where possible.
|
||||
"""
|
||||
counts: dict[str, int] = {}
|
||||
|
||||
# --- 1) User-owned entity counts (one query per model) ---
|
||||
for model, key in [
|
||||
(ServiceProvider, "providers"),
|
||||
(NotificationTracker, "notification_trackers"),
|
||||
@@ -132,7 +140,7 @@ async def get_nav_counts(
|
||||
)).one()
|
||||
counts[key] = count
|
||||
|
||||
# System-owned entities (user_id=0) count as well
|
||||
# --- 2) Add system-owned counts (user_id=0) for shared entities ---
|
||||
for model, key in [
|
||||
(TemplateConfig, "template_configs"),
|
||||
(CommandTemplateConfig, "command_template_configs"),
|
||||
@@ -144,15 +152,22 @@ async def get_nav_counts(
|
||||
)).one()
|
||||
counts[key] += system_count
|
||||
|
||||
# Per-type target counts for nav badges
|
||||
for target_type in ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix"):
|
||||
type_count = (await session.exec(
|
||||
select(func.count()).select_from(NotificationTarget).where(
|
||||
NotificationTarget.user_id == user.id,
|
||||
NotificationTarget.type == target_type,
|
||||
)
|
||||
)).one()
|
||||
counts[f"targets_{target_type}"] = type_count
|
||||
# --- 3) Per-type target counts in a single query using conditional aggregation ---
|
||||
target_types = ("telegram", "webhook", "email", "discord", "slack", "ntfy", "matrix")
|
||||
type_counts_result = (await session.exec(
|
||||
select(
|
||||
NotificationTarget.type,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
NotificationTarget.user_id == user.id,
|
||||
NotificationTarget.type.in_(target_types),
|
||||
)
|
||||
.group_by(NotificationTarget.type)
|
||||
)).all()
|
||||
type_counts_map = dict(type_counts_result)
|
||||
for target_type in target_types:
|
||||
counts[f"targets_{target_type}"] = type_counts_map.get(target_type, 0)
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, TargetReceiver, User
|
||||
from ..services.notifier import send_to_receiver
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -170,7 +171,7 @@ def _response(r: TargetReceiver) -> dict:
|
||||
|
||||
|
||||
async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget:
|
||||
target = await session.get(NotificationTarget, target_id)
|
||||
if not target or target.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Target not found")
|
||||
return target
|
||||
return await get_owned_entity(
|
||||
session, NotificationTarget, target_id, user_id,
|
||||
not_found_msg="Target not found",
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User
|
||||
from ..services.notifier import send_test_notification
|
||||
from .helpers import get_owned_entity
|
||||
from .target_receivers import _receiver_key
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -306,8 +307,15 @@ async def _validate_broadcast_children(
|
||||
return
|
||||
if exclude_target_id and exclude_target_id in child_ids:
|
||||
raise HTTPException(status_code=400, detail="A broadcast target cannot include itself")
|
||||
|
||||
# Batch-load all children in a single IN query instead of N+1 individual fetches
|
||||
children = (await session.exec(
|
||||
select(NotificationTarget).where(NotificationTarget.id.in_(child_ids))
|
||||
)).all()
|
||||
children_by_id = {c.id: c for c in children}
|
||||
|
||||
for child_id in child_ids:
|
||||
child = await session.get(NotificationTarget, child_id)
|
||||
child = children_by_id.get(child_id)
|
||||
if not child or child.user_id != user_id:
|
||||
raise HTTPException(status_code=400, detail=f"Child target {child_id} not found")
|
||||
if child.type == "broadcast":
|
||||
@@ -378,7 +386,7 @@ def _safe_config(target: NotificationTarget) -> dict:
|
||||
async def _get_user_target(
|
||||
session: AsyncSession, target_id: int, user_id: int
|
||||
) -> NotificationTarget:
|
||||
target = await session.get(NotificationTarget, target_id)
|
||||
if not target or target.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Target not found")
|
||||
return target
|
||||
return await get_owned_entity(
|
||||
session, NotificationTarget, target_id, user_id,
|
||||
not_found_msg="Target not found",
|
||||
)
|
||||
|
||||
@@ -7,8 +7,6 @@ from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.notifications.telegram.client import TelegramClient
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
@@ -19,6 +17,7 @@ from ..database.models import AppSetting, NotificationTarget, TargetReceiver, Te
|
||||
from ..services.notifier import _get_test_message
|
||||
from ..services.telegram_poller import schedule_bot_polling, unschedule_bot_polling
|
||||
from .app_settings import get_setting
|
||||
from .helpers import get_owned_entity
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -290,10 +289,11 @@ async def test_chat(
|
||||
):
|
||||
"""Send a test message to a chat via the bot."""
|
||||
bot = await _get_user_bot(session, bot_id, user.id)
|
||||
from ..services.http_session import get_http_session
|
||||
message = _get_test_message(locale, "telegram")
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, bot.token)
|
||||
return await client.send_message(chat_id, message)
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, bot.token)
|
||||
return await client.send_message(chat_id, message)
|
||||
|
||||
|
||||
class ChatUpdate(BaseModel):
|
||||
@@ -344,41 +344,44 @@ async def delete_chat(
|
||||
|
||||
async def _get_webhook_info(token: str) -> dict | None:
|
||||
"""Call Telegram getWebhookInfo via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_webhook_info()
|
||||
return result.get("result") if result.get("success") else None
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_webhook_info()
|
||||
return result.get("result") if result.get("success") else None
|
||||
|
||||
|
||||
async def _get_me(token: str) -> dict | None:
|
||||
"""Call Telegram getMe via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_me()
|
||||
return result.get("result") if result.get("success") else None
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_me()
|
||||
return result.get("result") if result.get("success") else None
|
||||
|
||||
|
||||
async def _fetch_chats_from_telegram(token: str) -> list[dict]:
|
||||
"""Fetch chats from Telegram getUpdates via TelegramClient."""
|
||||
async with aiohttp.ClientSession() as http:
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_updates(limit=100)
|
||||
if not result.get("success"):
|
||||
return []
|
||||
from ..services.http_session import get_http_session
|
||||
http = await get_http_session()
|
||||
client = TelegramClient(http, token)
|
||||
result = await client.get_updates(limit=100)
|
||||
if not result.get("success"):
|
||||
return []
|
||||
|
||||
seen: dict[int, dict] = {}
|
||||
for update in result.get("result", []):
|
||||
msg = update.get("message", {})
|
||||
chat = msg.get("chat", {})
|
||||
chat_id = chat.get("id")
|
||||
if chat_id and chat_id not in seen:
|
||||
seen[chat_id] = {
|
||||
"id": chat_id,
|
||||
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
|
||||
"type": chat.get("type", "private"),
|
||||
"username": chat.get("username", ""),
|
||||
}
|
||||
return list(seen.values())
|
||||
seen: dict[int, dict] = {}
|
||||
for update in result.get("result", []):
|
||||
msg = update.get("message", {})
|
||||
chat = msg.get("chat", {})
|
||||
chat_id = chat.get("id")
|
||||
if chat_id and chat_id not in seen:
|
||||
seen[chat_id] = {
|
||||
"id": chat_id,
|
||||
"title": chat.get("title") or (chat.get("first_name", "") + (" " + chat.get("last_name", "")).strip()),
|
||||
"type": chat.get("type", "private"),
|
||||
"username": chat.get("username", ""),
|
||||
}
|
||||
return list(seen.values())
|
||||
|
||||
|
||||
def _chat_response(c: TelegramChat) -> dict:
|
||||
@@ -410,10 +413,9 @@ def _bot_response(b: TelegramBot) -> dict:
|
||||
|
||||
|
||||
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> TelegramBot:
|
||||
bot = await session.get(TelegramBot, bot_id)
|
||||
if not bot or bot.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Bot not found")
|
||||
return bot
|
||||
return await get_owned_entity(
|
||||
session, TelegramBot, bot_id, user_id, not_found_msg="Bot not found",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@ from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import TemplateConfig, TemplateSlot, User
|
||||
from ..services.sample_context import _SAMPLE_CONTEXT
|
||||
from .slot_helpers import load_slots, render_template_preview, save_slots
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,40 +49,13 @@ class TemplateConfigUpdate(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, dict[str, str]]:
|
||||
"""Load all template slots for a config as {slot_name: {locale: template}}."""
|
||||
result = await session.exec(
|
||||
select(TemplateSlot).where(TemplateSlot.config_id == config_id)
|
||||
)
|
||||
slots: dict[str, dict[str, str]] = {}
|
||||
for s in result.all():
|
||||
slots.setdefault(s.slot_name, {})[s.locale] = s.template
|
||||
return slots
|
||||
return await load_slots(session, TemplateSlot, config_id)
|
||||
|
||||
|
||||
async def _save_slots(
|
||||
session: AsyncSession, config_id: int, slots: dict[str, dict[str, str]]
|
||||
) -> None:
|
||||
"""Create or update template slots for a config (locale-aware)."""
|
||||
for slot_name, locale_map in slots.items():
|
||||
for locale, template_text in locale_map.items():
|
||||
result = await session.exec(
|
||||
select(TemplateSlot).where(
|
||||
TemplateSlot.config_id == config_id,
|
||||
TemplateSlot.slot_name == slot_name,
|
||||
TemplateSlot.locale == locale,
|
||||
)
|
||||
)
|
||||
existing = result.first()
|
||||
if existing:
|
||||
existing.template = template_text
|
||||
session.add(existing)
|
||||
else:
|
||||
session.add(TemplateSlot(
|
||||
config_id=config_id,
|
||||
slot_name=slot_name,
|
||||
locale=locale,
|
||||
template=template_text,
|
||||
))
|
||||
await save_slots(session, TemplateSlot, config_id, slots)
|
||||
|
||||
|
||||
async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]:
|
||||
@@ -155,7 +128,7 @@ async def get_template_variables(
|
||||
"photo_count": "Total photo count in album",
|
||||
"video_count": "Total video count in album",
|
||||
"owner": "Album owner name",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
"has_videos": "Whether added assets contain videos (boolean)",
|
||||
"has_photos": "Whether added assets contain photos (boolean)",
|
||||
"has_oversized_videos": "Whether any video exceeds the target's size limit (boolean)",
|
||||
@@ -206,7 +179,7 @@ async def get_template_variables(
|
||||
}
|
||||
scheduled_vars = {
|
||||
"date": "Current date string",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -284,7 +257,7 @@ def _webhook_variables() -> dict:
|
||||
"source_ip": "IP address of the webhook sender",
|
||||
"raw_payload": "Full JSON payload as dict (use raw_payload.field or raw_payload | tojson)",
|
||||
"timestamp": "When the webhook was received",
|
||||
"target_type": "Target type: 'telegram' or 'webhook'",
|
||||
"target_type": "Target type: telegram, webhook, email, discord, slack, ntfy, or matrix",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -529,7 +502,7 @@ async def preview_date_format(
|
||||
|
||||
class PreviewRequest(BaseModel):
|
||||
template: str
|
||||
target_type: str = "telegram" # "telegram" or "webhook"
|
||||
target_type: str = "telegram" # telegram, webhook, email, discord, slack, ntfy, matrix
|
||||
date_format: str = "%d.%m.%Y, %H:%M UTC"
|
||||
date_only_format: str = "%d.%m.%Y"
|
||||
|
||||
@@ -545,33 +518,12 @@ async def preview_raw(
|
||||
1. Parse with default Undefined (catches syntax errors)
|
||||
2. Render with StrictUndefined (catches unknown variables like {{ asset.a }})
|
||||
"""
|
||||
# Pass 1: syntax check
|
||||
from datetime import datetime
|
||||
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
|
||||
"date_format": body.date_format, "date_only_format": body.date_only_format}
|
||||
# Format common_date using the provided date_only_format
|
||||
try:
|
||||
env = SandboxedEnvironment(autoescape=False)
|
||||
env.from_string(body.template)
|
||||
except TemplateSyntaxError as e:
|
||||
return {
|
||||
"rendered": None,
|
||||
"error": e.message,
|
||||
"error_line": e.lineno,
|
||||
}
|
||||
|
||||
# Pass 2: render with strict undefined to catch unknown variables
|
||||
try:
|
||||
from datetime import datetime
|
||||
ctx = {**_SAMPLE_CONTEXT, "target_type": body.target_type,
|
||||
"date_format": body.date_format, "date_only_format": body.date_only_format}
|
||||
# Format common_date using the provided date_only_format
|
||||
try:
|
||||
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
|
||||
except (ValueError, TypeError):
|
||||
ctx["common_date"] = "19.03.2026"
|
||||
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
|
||||
tmpl = strict_env.from_string(body.template)
|
||||
rendered = tmpl.render(**ctx)
|
||||
return {"rendered": rendered}
|
||||
except UndefinedError as e:
|
||||
# Still a valid template syntactically, but references unknown variable
|
||||
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
|
||||
except Exception as e:
|
||||
return {"rendered": None, "error": str(e), "error_line": None}
|
||||
ctx["common_date"] = datetime(2026, 3, 19).strftime(body.date_only_format)
|
||||
except (ValueError, TypeError):
|
||||
ctx["common_date"] = "19.03.2026"
|
||||
return render_template_preview(body.template, ctx)
|
||||
|
||||
Reference in New Issue
Block a user