Files
notify-bridge/packages/server/src/notify_bridge_server/api/providers.py
T
alexei.dolgolyov b803d004e1 refactor: comprehensive codebase review — security, performance, quality, UX
Security:
- Fix NUT protocol command injection (validate names against safe regex)
- Enable Jinja2 autoescape=True to prevent HTML injection via external data
- Add WebhookProviderConfig validation model

Performance:
- Shared aiohttp.ClientSession singleton (replaces 40+ per-request sessions)
- Fix 4 N+1 queries with batch IN loads (poller, scheduler, memory, broadcast)
- asyncio.gather for Gitea commands and notification dispatcher
- Add DB indexes on NotificationTrackerState.tracker_id, CommandTrackerListener
- LRU cache for compiled Jinja2 templates
- Daily EventLog cleanup job (90-day retention)
- 30s HTTP timeout on all external calls
- GROUP BY for target type counts (replaces 7 sequential queries)

Code quality:
- Extract get_owned_entity() helper (replaces 11 duplicate functions)
- Extract slot_helpers.py (load_slots, save_slots, render_template_preview)
- Extract command_utils.py (tracker lookup, last event, collection IDs)
- Extract http_session.py (shared session lifecycle)
- Provider connection validation dedup (3x → 1 helper)
- Command dispatch tables replacing if/elif chains
- Album+links fetch helper (fetch_albums_with_links)
- Provider dispatch polymorphism (list_provider_collections)
- Immutable _enrich_assets (no longer mutates in-place)
- Fix _format_assets return type + handler unpacking

Frontend:
- Fix 18+ hardcoded English strings → t() with new i18n keys (en + ru)
- Mobile "More" nav panel with provider filter and search
- Shared Button.svelte component (4 variants, 2 sizes)
- Shared ErrorBanner.svelte component (8 pages updated)
- SvelteKit goto() replacing window.location.href
- Dashboard grid fixed for 4 cards, paginator opacity consistency

Functionality:
- max_instances=1 on scheduler jobs (prevents duplicate events)
- Webhook provider in watcher (prevents error spam)
- Fix stale SQLModel reference in poller
- Gitea get_repo() direct API call
2026-03-28 13:22:26 +03:00

460 lines
14 KiB
Python

"""Service provider management API routes."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ValidationError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any
import aiohttp
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import ServiceProvider, User
from ..services import (
make_immich_provider, make_gitea_provider, make_planka_provider,
make_nut_provider, make_google_photos_provider, list_provider_collections,
)
from ..services.http_session import get_http_session
from .helpers import get_owned_entity
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Router for every endpoint defined in this file; all paths are prefixed with
# /api/providers and tagged "providers".
router = APIRouter(prefix="/api/providers", tags=["providers"])
class ProviderCreate(BaseModel):
    """Request body accepted by ``create_provider`` (``POST /api/providers``)."""

    # Provider kind; used to look up the config schema in _PROVIDER_CONFIG_MODELS.
    type: str
    # Name stored on the provider record.
    name: str
    # Icon stored on the provider record; empty string when omitted.
    icon: str = ""
    # Provider-specific settings, validated against the schema for ``type``.
    config: dict[str, Any] = {}
class ProviderUpdate(BaseModel):
    """Request body accepted by ``update_provider`` (``PUT /api/providers/{provider_id}``).

    Every field is optional; ``update_provider`` leaves a field untouched when
    its value here is ``None``. There is no ``type`` field, so the provider type
    cannot be changed through this model.
    """

    name: str | None = None
    icon: str | None = None
    # When given, replaces the stored config as a whole.
    config: dict[str, Any] | None = None
class ProviderResponse(BaseModel):
    """Shape of a provider as exposed to API clients.

    NOTE(review): the routes visible in this file return the plain dict built
    by ``_provider_response`` (same keys) and do not instantiate this model.
    """

    id: int
    type: str
    name: str
    icon: str
    config: dict[str, Any]
    # String timestamp; ``_provider_response`` produces it with ``isoformat()``.
    created_at: str
# -- Per-provider config validation models --
class ImmichProviderConfig(BaseModel):
    """Config schema for providers of type ``"immich"``."""

    url: str
    api_key: str
    # Optional; create/update also fill this in from the connection test result
    # when that result carries an ``external_domain`` value.
    external_domain: str | None = None
class GiteaProviderConfig(BaseModel):
    """Config schema for providers of type ``"gitea"``."""

    url: str
    webhook_secret: str
    # Optional; without it ``_test_provider_connection`` skips the real
    # connection test and reports webhook-only mode.
    api_token: str | None = None
class PlankaProviderConfig(BaseModel):
    """Config schema for providers of type ``"planka"``."""

    url: str
    webhook_secret: str
    # Optional; without it ``_test_provider_connection`` skips the real
    # connection test and reports webhook-only mode.
    api_key: str | None = None
class SchedulerProviderConfig(BaseModel):
    """Scheduler is a virtual provider — no required fields."""

    # Intentionally empty: the model declares no fields of its own.
    pass
class NutProviderConfig(BaseModel):
    """Config schema for providers of type ``"nut"``."""

    host: str
    # Defaults to 3493 when omitted.
    port: int = 3493
    # Credentials are optional and may be omitted or null.
    username: str | None = None
    password: str | None = None
class GooglePhotosProviderConfig(BaseModel):
    """Config schema for providers of type ``"google_photos"``.

    All three fields are required strings.
    NOTE(review): presumably OAuth client credentials plus a long-lived
    refresh token — confirm against the Google Photos provider implementation.
    """

    client_id: str
    client_secret: str
    refresh_token: str
class PayloadMapping(BaseModel):
    """One entry of ``WebhookProviderConfig.payload_mappings``.

    NOTE(review): presumably maps a value found at ``jsonpath`` in an incoming
    webhook payload to the template variable ``variable``, using ``default``
    when the path is absent — confirm against the webhook handler.
    """

    variable: str
    jsonpath: str
    default: str | None = None
class WebhookProviderConfig(BaseModel):
    """Config schema for providers of type ``"webhook"``; every field is optional."""

    # Defaults to "none". The set of accepted modes is not restricted here.
    auth_mode: str = "none"
    webhook_secret: str | None = None
    payload_mappings: list[PayloadMapping] = []
    # NOTE(review): presumably paths into the incoming payload that locate the
    # event type and the collection — confirm against the webhook handler.
    event_type_path: str | None = None
    collection_path: str | None = None
# Maps a provider type string to the pydantic model used to validate its
# config. Types missing from this table are not schema-validated by
# _validate_provider_config (it returns without checking anything).
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
    "immich": ImmichProviderConfig,
    "gitea": GiteaProviderConfig,
    "planka": PlankaProviderConfig,
    "scheduler": SchedulerProviderConfig,
    "nut": NutProviderConfig,
    "google_photos": GooglePhotosProviderConfig,
    "webhook": WebhookProviderConfig,
}
def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> None:
    """Check ``config`` against the schema registered for ``provider_type``.

    Types without a registered schema are accepted unchecked.

    Raises:
        HTTPException: 400 when the config does not satisfy the schema.
    """
    schema = _PROVIDER_CONFIG_MODELS.get(provider_type)
    if schema is not None:
        try:
            schema.model_validate(config)
        except ValidationError as exc:
            problem = f"Invalid config for '{provider_type}' provider: {exc}"
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=problem,
            )
async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
"""Test provider connection and return the result dict.
For providers that lack optional credentials (gitea without api_token,
planka without api_key), returns a success stub.
"""
http_session = await get_http_session()
if provider.type == "immich":
immich = make_immich_provider(http_session, provider)
return await immich.test_connection()
if provider.type == "gitea":
if not provider.config.get("api_token"):
return {"ok": True, "message": "Gitea webhook-only mode (no API token for testing)"}
gitea = make_gitea_provider(http_session, provider)
return await gitea.test_connection()
if provider.type == "planka":
if not provider.config.get("api_key"):
return {"ok": True, "message": "Planka webhook-only mode (no API key for testing)"}
planka = make_planka_provider(http_session, provider)
return await planka.test_connection()
if provider.type == "nut":
nut = make_nut_provider(provider)
return await nut.test_connection()
if provider.type == "google_photos":
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
if provider.type in ("scheduler", "webhook"):
return {"ok": True, "message": "Virtual provider — always available"}
return {"ok": False, "message": f"Unknown provider type: {provider.type}"}
async def _validate_provider_connection(provider: ServiceProvider) -> dict[str, Any]:
    """Test the provider connection and turn any failure into an HTTP 400.

    Args:
        provider: Provider whose connection should be tested; it does not need
            to be persisted.

    Returns:
        The test result dict on success (callers may inspect extra fields such
        as ``external_domain``).

    Raises:
        HTTPException: 400 when the test raises a connection-level error
            (aiohttp client error, OS error, timeout) or reports ``ok`` falsy.
    """
    # Function-scope import: asyncio is only needed for the timeout class below.
    import asyncio

    try:
        test_result = await _test_provider_connection(provider)
    except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err:
        # asyncio.TimeoutError is only an OSError subclass on Python 3.11+, so
        # it is listed explicitly to keep timeouts from escaping as a 500.
        # Some exceptions (timeouts in particular) have an empty str(); fall
        # back to the class name so the client still sees a reason.
        reason = str(err) or type(err).__name__
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Connection error: {reason}",
        ) from err

    if not test_result.get("ok"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=test_result.get("message", f"Cannot connect to {provider.type} provider"),
        )
    return test_result
@router.get("")
async def list_providers(
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Return every service provider owned by the authenticated user."""
    owned_query = select(ServiceProvider).where(ServiceProvider.user_id == user.id)
    rows = await session.exec(owned_query)
    return [_provider_response(row) for row in rows.all()]
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_provider(
    body: ProviderCreate,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Create a service provider after validating its config and connection."""
    config = body.config
    _validate_provider_config(body.type, config)

    # A throwaway, never-persisted provider is enough for the connection test.
    probe = ServiceProvider(
        id=0, user_id=0, type=body.type, name=body.name, config=config,
    )
    outcome = await _validate_provider_connection(probe)

    # Keep the external domain reported by the connection test, when there is one.
    reported_domain = outcome.get("external_domain")
    if reported_domain:
        config["external_domain"] = reported_domain

    provider = ServiceProvider(
        user_id=user.id,
        type=body.type,
        name=body.name,
        icon=body.icon,
        config=config,
    )
    session.add(provider)
    await session.commit()
    await session.refresh(provider)
    return _provider_response(provider)
def _capabilities_payload(caps: Any) -> dict[str, Any]:
    """Serialise one capabilities object into the JSON shape used by the API.

    Shared by both capabilities endpoints so their payloads cannot drift apart.
    """
    return {
        "provider_type": caps.provider_type,
        "display_name": caps.display_name,
        "notification_slots": caps.notification_slots,
        "command_slots": caps.command_slots,
        "events": caps.events,
        "commands": caps.commands,
        "supported_filters": caps.supported_filters,
        "webhook_based": caps.webhook_based,
        "action_types": caps.action_types,
    }


@router.get("/capabilities")
async def list_provider_capabilities(
    user: User = Depends(get_current_user),
):
    """List capabilities for all registered provider types.

    Returns:
        A dict keyed by provider type; each value is that type's capabilities
        payload.
    """
    from notify_bridge_core.providers.capabilities import get_all_capabilities

    return {
        provider_type: _capabilities_payload(caps)
        for provider_type, caps in get_all_capabilities().items()
    }


@router.get("/capabilities/{provider_type}")
async def get_provider_capabilities(
    provider_type: str,
    user: User = Depends(get_current_user),
):
    """Get capabilities for a provider type (events, slots, commands).

    The payload has the same keys as an entry of the list endpoint, including
    ``action_types``, which this endpoint previously left out.

    Raises:
        HTTPException: 404 when ``provider_type`` is not registered.
    """
    from notify_bridge_core.providers.capabilities import get_capabilities

    caps = get_capabilities(provider_type)
    if not caps:
        raise HTTPException(status_code=404, detail=f"Unknown provider type: {provider_type}")
    return _capabilities_payload(caps)
@router.get("/{provider_id}")
async def get_provider(
    provider_id: int,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Return one provider owned by the current user, with secrets masked."""
    return _provider_response(
        await _get_user_provider(session, provider_id, user.id)
    )
@router.put("/{provider_id}")
async def update_provider(
    provider_id: int,
    body: ProviderUpdate,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Apply the given fields to a provider; re-test the connection if config changed."""
    provider = await _get_user_provider(session, provider_id, user.id)

    if body.name is not None:
        provider.name = body.name
    if body.icon is not None:
        provider.icon = body.icon

    needs_retest = False
    if body.config is not None:
        # Compare against the stored config before it is overwritten below.
        needs_retest = body.config != provider.config
        _validate_provider_config(provider.type, body.config)
        provider.config = body.config

    if needs_retest:
        outcome = await _validate_provider_connection(provider)
        reported_domain = outcome.get("external_domain")
        if reported_domain:
            # Assign a fresh dict rather than mutating the existing config in place.
            provider.config = dict(provider.config, external_domain=reported_domain)

    session.add(provider)
    await session.commit()
    await session.refresh(provider)
    return _provider_response(provider)
@router.delete("/{provider_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_provider(
    provider_id: int,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Delete a provider owned by the current user, after the usage check passes."""
    from .delete_protection import check_service_provider, raise_if_used

    provider = await _get_user_provider(session, provider_id, user.id)

    # The usage check runs first; raise_if_used decides whether deletion may proceed.
    usage = await check_service_provider(session, provider.id)
    raise_if_used(usage, provider.name)

    await session.delete(provider)
    await session.commit()
@router.post("/{provider_id}/test")
async def test_provider(
    provider_id: int,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Run the connection test for a stored provider and return its raw result."""
    target = await _get_user_provider(session, provider_id, user.id)
    outcome = await _test_provider_connection(target)
    return outcome
@router.get("/{provider_id}/people")
async def list_people(
    provider_id: int,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Return ``[{"id", "name"}, ...]`` for an Immich provider; ``[]`` for any other type."""
    provider = await _get_user_provider(session, provider_id, user.id)
    if provider.type != "immich":
        return []

    from notify_bridge_core.providers.immich.client import ImmichClient

    http_session = await get_http_session()
    base_url = provider.config.get("url", "")
    api_key = provider.config.get("api_key", "")
    client = ImmichClient(http_session, base_url, api_key)

    people = await client.get_people()
    return [
        {"id": person_id, "name": person_name}
        for person_id, person_name in people.items()
    ]
@router.get("/{provider_id}/collections")
async def list_collections(
    provider_id: int,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Return the collections reported by ``list_provider_collections`` for this provider."""
    owned = await _get_user_provider(session, provider_id, user.id)
    collections = await list_provider_collections(owned)
    return collections
@router.get("/{provider_id}/albums/{album_id}/shared-links")
async def get_album_shared_links(
    provider_id: int,
    album_id: str,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Describe the shared links of an album (Immich only; ``[]`` for other types)."""
    provider = await _get_user_provider(session, provider_id, user.id)
    if provider.type != "immich":
        return []

    http_session = await get_http_session()
    immich = make_immich_provider(http_session, provider)

    summaries = []
    for link in await immich.client.get_shared_links(album_id):
        summaries.append({
            "id": link.id,
            "key": link.key,
            "has_password": link.has_password,
            "is_expired": link.is_expired,
            "is_accessible": link.is_accessible,
        })
    return summaries
@router.post("/{provider_id}/albums/{album_id}/shared-links")
async def create_album_shared_link(
    provider_id: int,
    album_id: str,
    user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session),
):
    """Create a shared link for an album on an Immich provider.

    Raises:
        HTTPException: 400 when the provider is not Immich, or when the client
            reports that the link was not created.
    """
    provider = await _get_user_provider(session, provider_id, user.id)
    if provider.type != "immich":
        raise HTTPException(status_code=400, detail="Provider type does not support shared links")

    http_session = await get_http_session()
    immich = make_immich_provider(http_session, provider)
    created = await immich.client.create_shared_link(album_id)
    if not created:
        raise HTTPException(status_code=400, detail="Failed to create shared link")
    return {"success": True}
def _mask_secret(value: Any) -> Any:
    """Return a display-safe stand-in for one secret config value.

    ``None`` stays ``None`` so "not set" remains visible to the client.
    Strings longer than four characters keep only their last four; shorter
    strings and any non-string value are fully masked.
    """
    if value is None:
        return None
    if isinstance(value, str) and len(value) > 4:
        return f"***{value[-4:]}"
    return "***"


def _provider_response(p: ServiceProvider) -> dict:
    """Build a safe response dict for a provider.

    Works on a copy of ``p.config`` so the stored config is never modified,
    and masks every known secret field in that copy.

    Args:
        p: Provider record to serialise.

    Returns:
        Dict with ``id``, ``type``, ``name``, ``icon``, the masked ``config``
        and ``created_at`` as an ISO-8601 string.
    """
    # ``or {}`` tolerates a provider whose config is None.
    config = dict(p.config or {})
    # Mask sensitive fields. Optional secrets may legitimately be stored as
    # None (e.g. a NUT password); calling len() on them used to raise TypeError.
    for secret_field in ("api_key", "api_token", "webhook_secret", "password",
                         "client_secret", "refresh_token"):
        if secret_field in config:
            config[secret_field] = _mask_secret(config[secret_field])
    return {
        "id": p.id,
        "type": p.type,
        "name": p.name,
        "icon": p.icon,
        "config": config,
        "created_at": p.created_at.isoformat(),
    }
async def _get_user_provider(
    session: AsyncSession, provider_id: int, user_id: int
) -> ServiceProvider:
    """Fetch a provider through ``get_owned_entity``, using "Provider not found" as the not-found message."""
    provider = await get_owned_entity(
        session,
        ServiceProvider,
        provider_id,
        user_id,
        not_found_msg="Provider not found",
    )
    return provider