feat: comprehensive code review fixes — security, performance, quality

Backend security:
- Reject Gitea webhooks when webhook_secret is empty (was silently skipping)
- Add slowapi rate limiting on login (5/min) and setup (3/min) endpoints
- Add CORS middleware with configurable origins
- Mask telegram_webhook_secret in settings API response
- Protect system-owned command template configs from regular user modification
- Increase minimum password length to 8 characters

Backend performance:
- Batch queries in _resolve_command_context (3 queries instead of 3N)
- Concurrent album fetching with asyncio.gather in immich commands
- Singleton Jinja2 SandboxedEnvironment (reuse instead of per-render creation)
- TTLCache for rate limits (bounded memory, auto-eviction)
- Optional aiohttp session reuse in send_reply/send_media_group

Backend code quality:
- Extract dispatch_helpers.py (shared link_data loading + event filtering)
- Extract database/seeds.py from main.py (490 lines → dedicated module)
- Split immich_handler.py (415 lines) into commands/immich/ subpackage
- Replace bare except blocks with logged warnings
- Add per-provider config validation (Pydantic models)
- Truncate command input to 512 chars
- Expose usage_* and desc_* slots in capabilities and variables API

Frontend security:
- CSS.escape() for user-controlled querySelector in highlight.ts
- Client-side password min 8 chars validation on setup and password change

Frontend code quality:
- Replace any types with proper interfaces across top files
- Decompose targets/+page.svelte into TargetForm + ReceiverSection
- Fix $derived.by usage, $state mutation patterns
- Add console.warn to empty catch blocks

Frontend UX:
- Auth redirect via goto() with "Redirecting..." state
- Platform-aware Ctrl/Cmd K keyboard hint
- Remove stat-card hover transform

Frontend accessibility:
- Modal: role=dialog, aria-modal, focus trap, restore focus
- EntitySelect/IconGridSelect: listbox/option roles, aria-selected/disabled
This commit is contained in:
2026-03-23 01:59:51 +03:00
parent 31584c5d31
commit e0bae394ee
78 changed files with 2855 additions and 1658 deletions
@@ -57,7 +57,11 @@ async def get_settings(
"""Return all app settings."""
result = {}
for key in _SETTING_KEYS:
result[key] = await get_setting(session, key)
value = await get_setting(session, key)
if key == "telegram_webhook_secret" and value:
result[key] = f"***{value[-4:]}" if len(value) > 4 else "***"
else:
result[key] = value
return result
@@ -124,6 +124,7 @@ async def get_command_variables():
command_fields = {
"name": "Command name (e.g. status, albums)",
"description": "Command description text",
"usage": "Usage example (e.g. /search sunset) — only for commands that take arguments",
}
event_fields = {
"type": "Event type (assets_added, assets_removed, etc.)",
@@ -197,6 +198,31 @@ async def get_command_variables():
"description": "Empty results fallback",
"variables": {**common_vars, "command": "Command name", "query": "Search query (empty for non-search commands)"},
},
# --- Description slots (shown in /help listing) ---
"desc_help": {"description": "Description for /help command", "variables": common_vars},
"desc_status": {"description": "Description for /status command", "variables": common_vars},
"desc_albums": {"description": "Description for /albums command", "variables": common_vars},
"desc_events": {"description": "Description for /events command", "variables": common_vars},
"desc_summary": {"description": "Description for /summary command", "variables": common_vars},
"desc_latest": {"description": "Description for /latest command", "variables": common_vars},
"desc_memory": {"description": "Description for /memory command", "variables": common_vars},
"desc_random": {"description": "Description for /random command", "variables": common_vars},
"desc_search": {"description": "Description for /search command", "variables": common_vars},
"desc_find": {"description": "Description for /find command", "variables": common_vars},
"desc_person": {"description": "Description for /person command", "variables": common_vars},
"desc_place": {"description": "Description for /place command", "variables": common_vars},
"desc_favorites": {"description": "Description for /favorites command", "variables": common_vars},
"desc_people": {"description": "Description for /people command", "variables": common_vars},
# --- Usage example slots (shown in /help listing) ---
"usage_search": {"description": "Usage example for /search (e.g. '/search sunset')", "variables": common_vars},
"usage_find": {"description": "Usage example for /find", "variables": common_vars},
"usage_person": {"description": "Usage example for /person", "variables": common_vars},
"usage_place": {"description": "Usage example for /place", "variables": common_vars},
"usage_latest": {"description": "Usage example for /latest", "variables": common_vars},
"usage_random": {"description": "Usage example for /random", "variables": common_vars},
"usage_favorites": {"description": "Usage example for /favorites", "variables": common_vars},
"usage_events": {"description": "Usage example for /events", "variables": common_vars},
"usage_memory": {"description": "Usage example for /memory", "variables": common_vars},
}
@@ -256,6 +282,8 @@ async def update_config(
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
if config.user_id == 0 and user.role != "admin":
raise HTTPException(status_code=403, detail="Cannot modify system default configs")
for field, value in body.model_dump(exclude_unset=True, exclude={"slots"}).items():
if value is not None:
setattr(config, field, value)
@@ -275,6 +303,8 @@ async def delete_config(
):
from .delete_protection import check_command_template_config, raise_if_used
config = await _get(session, config_id, user.id)
if config.user_id == 0 and user.role != "admin":
raise HTTPException(status_code=403, detail="Cannot delete system default configs")
raise_if_used(await check_command_template_config(session, config.id), config.name)
slot_result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config.id)
@@ -306,9 +336,10 @@ async def preview_raw(
"last_event": "2026-03-19 14:30",
# /help
"commands": [
{"name": "status", "description": "Show tracker status"},
{"name": "albums", "description": "List tracked albums"},
{"name": "latest", "description": "Show latest photos"},
{"name": "status", "description": "Show tracker status", "usage": ""},
{"name": "albums", "description": "List tracked albums", "usage": ""},
{"name": "latest", "description": "Show latest photos", "usage": "/latest 10"},
{"name": "search", "description": "Smart search (AI)", "usage": "/search sunset at the beach"},
],
# /albums, /summary
"albums": [
@@ -3,7 +3,7 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any
@@ -42,6 +42,48 @@ class ProviderResponse(BaseModel):
created_at: str
# -- Per-provider config validation models --
class ImmichProviderConfig(BaseModel):
url: str
api_key: str
external_domain: str | None = None
class GiteaProviderConfig(BaseModel):
url: str
webhook_secret: str
api_token: str | None = None
class SchedulerProviderConfig(BaseModel):
"""Scheduler is a virtual provider — no required fields."""
pass
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"immich": ImmichProviderConfig,
"gitea": GiteaProviderConfig,
"scheduler": SchedulerProviderConfig,
}
def _validate_provider_config(provider_type: str, config: dict[str, Any]) -> None:
"""Validate provider config against the schema for the given type."""
config_model = _PROVIDER_CONFIG_MODELS.get(provider_type)
if config_model is None:
return
try:
config_model.model_validate(config)
except ValidationError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid config for '{provider_type}' provider: {exc}",
)
@router.get("")
async def list_providers(
user: User = Depends(get_current_user),
@@ -62,6 +104,8 @@ async def create_provider(
session: AsyncSession = Depends(get_session),
):
"""Add a new service provider (validates connection for known types)."""
_validate_provider_config(body.type, body.config)
# Validate connection for known provider types
if body.type == "immich":
from notify_bridge_core.providers.immich import ImmichServiceProvider
@@ -177,6 +221,7 @@ async def update_provider(
config_changed = body.config is not None and body.config != provider.config
if body.config is not None:
_validate_provider_config(provider.type, body.config)
provider.config = body.config
# Re-validate connection when config changes for known provider types
@@ -17,18 +17,11 @@ from notify_bridge_core.providers.gitea.event_parser import parse_webhook as par
from ..database.engine import get_engine
from ..database.models import (
EmailBot,
EventLog,
MatrixBot,
NotificationTarget,
NotificationTracker,
NotificationTrackerTarget,
ServiceProvider,
TargetReceiver,
TemplateConfig,
TemplateSlot,
TrackingConfig,
)
from ..services.dispatch_helpers import event_allowed_by_config, load_link_data
_LOGGER = logging.getLogger(__name__)
@@ -93,10 +86,15 @@ async def gitea_webhook(provider_id: int, request: Request):
# Read raw body for HMAC check
raw_body = await request.body()
if webhook_secret:
signature = request.headers.get("X-Gitea-Signature", "")
if not signature or not _verify_gitea_signature(webhook_secret, raw_body, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
if not webhook_secret:
raise HTTPException(
status_code=403,
detail="Webhook secret not configured on this provider",
)
signature = request.headers.get("X-Gitea-Signature", "")
if not signature or not _verify_gitea_signature(webhook_secret, raw_body, signature):
raise HTTPException(status_code=403, detail="Invalid signature")
# Parse event header + payload
event_header = request.headers.get("X-Gitea-Event", "")
@@ -133,7 +131,7 @@ async def gitea_webhook(provider_id: int, request: Request):
continue
# Load tracker-target links
link_data = await _load_link_data(session, tracker.id)
link_data = await load_link_data(session, tracker.id)
if not link_data:
continue
@@ -176,122 +174,6 @@ async def gitea_webhook(provider_id: int, request: Request):
return {"ok": True, "dispatched": dispatched}
# ---------------------------------------------------------------------------
# Shared dispatch helpers (extracted from watcher pattern)
# ---------------------------------------------------------------------------
async def _load_link_data(
session: AsyncSession,
tracker_id: int,
) -> list[dict[str, Any]]:
"""Load tracker-target link data for dispatch (same pattern as watcher)."""
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.tracker_id == tracker_id
)
)
tracker_targets = tt_result.all()
link_data: list[dict[str, Any]] = []
for tt in tracker_targets:
if not tt.enabled:
continue
target = await session.get(NotificationTarget, tt.target_id)
if not target:
continue
# Load receivers
recv_result = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.enabled == True,
)
)
receivers = [dict(r.config) for r in recv_result.all()]
tracking_config = None
if tt.tracking_config_id:
tracking_config = await session.get(TrackingConfig, tt.tracking_config_id)
template_config = None
template_slots: dict[str, str] | None = None
if tt.template_config_id:
template_config = await session.get(TemplateConfig, tt.template_config_id)
if template_config:
slot_result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == template_config.id)
)
raw_slots = {s.slot_name: s.template for s in slot_result.all()}
template_slots = {}
for slot_name, tmpl_text in raw_slots.items():
event_key = slot_name.removeprefix("message_") if slot_name.startswith("message_") else slot_name
template_slots[event_key] = tmpl_text
target_config = dict(target.config)
# Inject chat_action for Telegram targets
if hasattr(target, 'chat_action') and target.chat_action:
target_config["chat_action"] = target.chat_action
# Inject bot credentials
if target.type == "email":
email_bot_id = target.config.get("email_bot_id")
if email_bot_id:
email_bot = await session.get(EmailBot, email_bot_id)
if email_bot:
target_config["smtp"] = {
"host": email_bot.smtp_host,
"port": email_bot.smtp_port,
"username": email_bot.smtp_username,
"password": email_bot.smtp_password,
"from_address": email_bot.email,
"from_name": email_bot.name,
"use_tls": email_bot.smtp_use_tls,
}
elif target.type == "matrix":
matrix_bot_id = target.config.get("matrix_bot_id")
if matrix_bot_id:
matrix_bot = await session.get(MatrixBot, matrix_bot_id)
if matrix_bot:
target_config["homeserver_url"] = matrix_bot.homeserver_url
target_config["access_token"] = matrix_bot.access_token
link_data.append({
"target_type": target.type,
"target_config": target_config,
"receivers": receivers,
"tracking_config": tracking_config,
"template_config": template_config,
"template_slots": template_slots,
})
return link_data
def _event_allowed_by_tracking_config(event: ServiceEvent, tc: TrackingConfig) -> bool:
"""Check if an event type is allowed by tracking config flags."""
event_type = event.event_type.value
flag_map = {
"push": tc.track_push,
"issue_opened": tc.track_issue_opened,
"issue_closed": tc.track_issue_closed,
"issue_commented": tc.track_issue_commented,
"pr_opened": tc.track_pr_opened,
"pr_closed": tc.track_pr_closed,
"pr_merged": tc.track_pr_merged,
"pr_commented": tc.track_pr_commented,
"release_published": tc.track_release_published,
# Scheduler events
"scheduled_message": tc.track_scheduled_message,
# Immich events
"assets_added": tc.track_assets_added,
"assets_removed": tc.track_assets_removed,
"collection_renamed": tc.track_collection_renamed,
"collection_deleted": tc.track_collection_deleted,
"sharing_changed": tc.track_sharing_changed,
}
return flag_map.get(event_type, True)
def _build_target_configs(
event: ServiceEvent,
link_data: list[dict[str, Any]],
@@ -301,7 +183,7 @@ def _build_target_configs(
target_configs: list[TargetConfig] = []
for ld in link_data:
tc = ld["tracking_config"]
if tc and not _event_allowed_by_tracking_config(event, tc):
if tc and not event_allowed_by_config(event, tc):
continue
tmpl = ld["template_config"]
@@ -1,7 +1,9 @@
"""Authentication API routes."""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from slowapi import Limiter
from slowapi.util import get_remote_address
from sqlmodel import func, select
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -14,6 +16,8 @@ from .jwt import create_access_token, create_refresh_token, decode_token
router = APIRouter(prefix="/api/auth", tags=["auth"])
limiter = Limiter(key_func=get_remote_address)
class SetupRequest(BaseModel):
username: str
@@ -50,14 +54,15 @@ def _verify_password(password: str, hashed: str) -> bool:
@router.post("/setup", response_model=TokenResponse)
async def setup(body: SetupRequest, session: AsyncSession = Depends(get_session)):
@limiter.limit("3/minute")
async def setup(request: Request, body: SetupRequest, session: AsyncSession = Depends(get_session)):
result = await session.exec(select(func.count()).select_from(User))
count = result.one()
if count > 0:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Setup already completed.")
if len(body.password) < 6:
raise HTTPException(status_code=400, detail="Password must be at least 6 characters")
if len(body.password) < 8:
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
user = User(username=body.username, hashed_password=_hash_password(body.password), role="admin")
session.add(user)
await session.commit()
@@ -70,7 +75,8 @@ async def setup(body: SetupRequest, session: AsyncSession = Depends(get_session)
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, session: AsyncSession = Depends(get_session)):
@limiter.limit("5/minute")
async def login(request: Request, body: LoginRequest, session: AsyncSession = Depends(get_session)):
result = await session.exec(select(User).where(User.username == body.username))
user = result.first()
if not user or not _verify_password(body.password, user.hashed_password):
@@ -121,8 +127,8 @@ async def change_password(
):
if not _verify_password(body.current_password, user.hashed_password):
raise HTTPException(status_code=400, detail="Current password is incorrect")
if len(body.new_password) < 6:
raise HTTPException(status_code=400, detail="New password must be at least 6 characters")
if len(body.new_password) < 8:
raise HTTPException(status_code=400, detail="New password must be at least 8 characters")
user.hashed_password = _hash_password(body.new_password)
session.add(user)
await session.commit()
@@ -29,7 +29,7 @@ def get_all_handlers() -> dict[str, ProviderCommandHandler]:
def _auto_register() -> None:
"""Auto-register all built-in handlers."""
from .immich_handler import ImmichCommandHandler
from .immich import ImmichCommandHandler
from .gitea_handler import GiteaCommandHandler
register_handler(ImmichCommandHandler())
@@ -7,10 +7,12 @@ import time
from typing import Any
import aiohttp
from cachetools import TTLCache
from jinja2.sandbox import SandboxedEnvironment
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.notifications.telegram.media import TELEGRAM_API_BASE_URL
from notify_bridge_core.notifications.telegram.client import TelegramClient
from ..database.engine import get_engine
from ..database.models import (
CommandConfig,
@@ -28,8 +30,11 @@ from .registry import get_rate_category
_LOGGER = logging.getLogger(__name__)
# Rate limit state: { (bot_id, chat_id, category): last_used_timestamp }
_rate_limits: dict[tuple[int, str, str], float] = {}
# Singleton Jinja2 environment for template rendering (Phase 4d)
_JINJA_ENV = SandboxedEnvironment(autoescape=False)
# Rate limit state with automatic TTL expiry (Phase 4e)
_rate_limits: TTLCache = TTLCache(maxsize=10000, ttl=3600)
def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int]) -> int | None:
@@ -65,9 +70,7 @@ def _render_cmd_template(
_LOGGER.warning("No command template found for slot '%s' locale '%s'", slot_name, locale)
return f"[No template: {slot_name}]"
try:
from jinja2.sandbox import SandboxedEnvironment
env = SandboxedEnvironment(autoescape=False)
tmpl = env.from_string(template_str)
tmpl = _JINJA_ENV.from_string(template_str)
return tmpl.render(**context)
except Exception as e:
_LOGGER.warning("Failed to render command template '%s': %s", slot_name, e)
@@ -95,15 +98,46 @@ async def _resolve_command_context(
if not listeners:
return [], {}
# Batch-fetch all referenced entities in 3 queries instead of N*3
tracker_ids = list({l.command_tracker_id for l in listeners})
tracker_result = await session.exec(
select(CommandTracker).where(CommandTracker.id.in_(tracker_ids))
)
trackers_by_id = {t.id: t for t in tracker_result.all()}
config_ids = list({
t.command_config_id for t in trackers_by_id.values()
if t.enabled and t.command_config_id
})
if config_ids:
config_result = await session.exec(
select(CommandConfig).where(CommandConfig.id.in_(config_ids))
)
configs_by_id = {c.id: c for c in config_result.all()}
else:
configs_by_id = {}
provider_ids = list({
t.provider_id for t in trackers_by_id.values()
if t.enabled and t.provider_id
})
if provider_ids:
provider_result = await session.exec(
select(ServiceProvider).where(ServiceProvider.id.in_(provider_ids))
)
providers_by_id = {p.id: p for p in provider_result.all()}
else:
providers_by_id = {}
tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]] = []
for listener in listeners:
tracker = await session.get(CommandTracker, listener.command_tracker_id)
tracker = trackers_by_id.get(listener.command_tracker_id)
if not tracker or not tracker.enabled:
continue
config = await session.get(CommandConfig, tracker.command_config_id)
config = configs_by_id.get(tracker.command_config_id)
if not config:
continue
provider = await session.get(ServiceProvider, tracker.provider_id)
provider = providers_by_id.get(tracker.provider_id)
if not provider:
continue
tuples.append((tracker, config, provider))
@@ -220,7 +254,11 @@ def _cmd_help(
commands = []
for cmd in enabled:
desc_text = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"name": cmd, "description": desc_text})
entry: dict[str, str] = {"name": cmd, "description": desc_text}
usage_text = _resolve_template(templates, f"usage_{cmd}", locale)
if usage_text:
entry["usage"] = usage_text
commands.append(entry)
return {"commands": commands}
@@ -240,128 +278,93 @@ async def _get_notification_trackers_for_providers(
return list(result.all())
async def send_reply(bot_token: str, chat_id: str, text: str) -> None:
"""Send a text reply via Telegram Bot API, retrying without HTML on parse failure."""
async with aiohttp.ClientSession() as http:
url = f"{TELEGRAM_API_BASE_URL}{bot_token}/sendMessage"
payload: dict[str, Any] = {"chat_id": chat_id, "text": text, "parse_mode": "HTML"}
try:
async with http.post(url, json=payload) as resp:
if resp.status != 200:
result = await resp.json()
_LOGGER.debug("Telegram reply failed: %s", result.get("description"))
if "parse" in str(result.get("description", "")).lower():
payload.pop("parse_mode", None)
async with http.post(url, json=payload) as retry_resp:
if retry_resp.status != 200:
_LOGGER.warning("Telegram reply failed on retry")
except aiohttp.ClientError as err:
_LOGGER.error("Failed to send Telegram reply: %s", err)
async def send_reply(
bot_token: str, chat_id: str, text: str, reply_to_message_id: int | None = None,
session: aiohttp.ClientSession | None = None,
) -> None:
"""Send a text reply via TelegramClient."""
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_message(chat_id, text, reply_to_message_id=reply_to_message_id)
if not result.get("success"):
_LOGGER.warning("Telegram reply failed: %s", result.get("error"))
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
async def send_media_group(
bot_token: str, chat_id: str, media_items: list[dict[str, Any]],
reply_to_message_id: int | None = None,
session: aiohttp.ClientSession | None = None,
) -> None:
"""Send media items as a Telegram media group (album)."""
"""Send media items via TelegramClient.send_notification."""
if not media_items:
return
async with aiohttp.ClientSession() as http:
downloaded: list[tuple[bytes, str, str]] = []
for item in media_items:
asset_id = item.get("asset_id", "")
caption = item.get("caption", "")
thumb_url = item.get("thumbnail_url", "")
api_key = item.get("api_key", "")
try:
async with http.get(thumb_url, headers={"x-api-key": api_key}) as resp:
if resp.status != 200:
_LOGGER.warning("Failed to download thumbnail for %s: HTTP %d", asset_id, resp.status)
continue
photo_bytes = await resp.read()
downloaded.append((photo_bytes, asset_id, caption))
except aiohttp.ClientError:
continue
# Convert command handler media format to TelegramClient asset format
assets = []
for item in media_items:
assets.append({
"type": "photo",
"url": item.get("thumbnail_url", ""),
"cache_key": item.get("asset_id", ""),
"headers": {"x-api-key": item.get("api_key", "")},
})
if not downloaded:
return
# Build caption from first item
captions = [item.get("caption", "") for item in media_items if item.get("caption")]
caption = "\n".join(captions) if captions else None
for i in range(0, len(downloaded), 10):
chunk = downloaded[i:i + 10]
async def _send(http: aiohttp.ClientSession) -> None:
client = TelegramClient(http, bot_token)
result = await client.send_notification(
chat_id, assets=assets, caption=caption,
reply_to_message_id=reply_to_message_id,
chat_action=None,
)
if not result.get("success"):
_LOGGER.warning("Telegram media group failed: %s", result.get("error"))
if len(chunk) == 1:
photo_bytes, asset_id, caption = chunk[0]
data = aiohttp.FormData()
data.add_field("chat_id", chat_id)
data.add_field("photo", photo_bytes, filename=f"{asset_id}.jpg", content_type="image/jpeg")
if caption:
data.add_field("caption", caption)
try:
async with http.post(f"{TELEGRAM_API_BASE_URL}{bot_token}/sendPhoto", data=data) as resp:
if resp.status != 200:
result = await resp.json()
_LOGGER.warning("Failed to send photo: %s", result.get("description"))
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to send photo: %s", err)
else:
import json as _json
data = aiohttp.FormData()
data.add_field("chat_id", chat_id)
media_array = []
for idx, (photo_bytes, asset_id, caption) in enumerate(chunk):
attach_key = f"photo_{idx}"
media_obj: dict[str, Any] = {"type": "photo", "media": f"attach://{attach_key}"}
if caption:
media_obj["caption"] = caption
media_array.append(media_obj)
data.add_field(attach_key, photo_bytes, filename=f"{asset_id}.jpg", content_type="image/jpeg")
data.add_field("media", _json.dumps(media_array))
try:
async with http.post(f"{TELEGRAM_API_BASE_URL}{bot_token}/sendMediaGroup", data=data) as resp:
if resp.status != 200:
result = await resp.json()
_LOGGER.warning("Failed to send media group: %s", result.get("description"))
except aiohttp.ClientError as err:
_LOGGER.warning("Failed to send media group: %s", err)
if session is not None:
await _send(session)
else:
async with aiohttp.ClientSession() as http:
await _send(http)
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
"""Register enabled commands with Telegram BotFather API."""
"""Register enabled commands with Telegram BotFather API via TelegramClient."""
ctx_tuples, templates = await _resolve_command_context(bot)
enabled, _, _, _ = _merge_command_context(ctx_tuples)
async with aiohttp.ClientSession() as http:
url = f"{TELEGRAM_API_BASE_URL}{bot.token}/setMyCommands"
client = TelegramClient(http, bot.token)
success = False
# Register per-locale commands
for locale in ("en", "ru"):
commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", locale) or cmd
commands.append({"command": cmd, "description": desc})
result = await client.set_my_commands(commands, language_code=locale)
if result.get("success"):
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("error"))
payload: dict[str, Any] = {"commands": commands, "language_code": locale}
try:
async with http.post(url, json=payload) as resp:
result = await resp.json()
if result.get("ok"):
success = True
else:
_LOGGER.warning("Failed to register commands for locale '%s': %s", locale, result.get("description"))
except aiohttp.ClientError as err:
_LOGGER.error("Failed to register commands for locale '%s': %s", locale, err)
# Register default (no language_code) with EN descriptions
en_commands = []
for cmd in enabled:
desc = _resolve_template(templates, f"desc_{cmd}", "en") or cmd
en_commands.append({"command": cmd, "description": desc})
try:
async with http.post(url, json={"commands": en_commands}) as resp:
result = await resp.json()
if result.get("ok"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
except aiohttp.ClientError as err:
_LOGGER.error("Failed to register default commands: %s", err)
result = await client.set_my_commands(en_commands)
if result.get("success"):
_LOGGER.info("Registered %d commands for bot @%s (all locales)", len(en_commands), bot.bot_username)
success = True
return success
@@ -0,0 +1,5 @@
"""Immich command handler subpackage."""
from .handler import ImmichCommandHandler
__all__ = ["ImmichCommandHandler"]
@@ -0,0 +1,113 @@
"""Album-related Immich bot commands: albums, favorites, summary."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
import aiohttp
from ...database.models import ServiceProvider, TelegramBot
from ...services import make_immich_provider
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets
_LOGGER = logging.getLogger(__name__)
async def _cmd_albums(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
if not trackers:
return {"albums": []}
albums_data: list[dict] = []
async with aiohttp.ClientSession() as http:
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
album_ids = tracker.collection_ids or []
if not album_ids:
continue
results = await asyncio.gather(
*[immich.client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
for album_id, result in zip(album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
})
elif result:
albums_data.append({
"name": result.name, "asset_count": result.asset_count, "id": album_id,
})
return {"albums": albums_data}
async def cmd_favorites(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
all_album_ids: list[str], count: int, locale: str,
response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /favorites command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "favorites", "", locale, response_mode, client, cmd_templates)
results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
fav_assets: list[dict[str, Any]] = []
for album_id, result in zip(album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
for aid, asset in list(result.assets.items())[:50]:
if asset.is_favorite and len(fav_assets) < count:
fav_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type,
})
if len(fav_assets) >= count:
break
return _format_assets(fav_assets, "favorites", "", locale, response_mode, client, cmd_templates)
async def cmd_summary(
client: Any, all_album_ids: list[str], locale: str,
cmd_templates: dict[str, dict[str, str]],
) -> str:
"""Handle /summary command with concurrent album fetching."""
if not all_album_ids:
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": []})
results = await asyncio.gather(
*[client.get_album(aid) for aid in all_album_ids],
return_exceptions=True,
)
albums_data: list[dict] = []
for album_id, result in zip(all_album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
albums_data.append({
"name": result.name, "asset_count": result.asset_count, "id": album_id,
})
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
@@ -0,0 +1,49 @@
"""Shared helpers, imports, and constants for Immich command handlers."""
from __future__ import annotations
import logging
from typing import Any
from ...services import make_immich_provider
from ..handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_IMMICH_COMMANDS = {
"status", "albums", "events", "people",
"search", "find", "person", "place",
"latest", "random", "favorites", "summary", "memory",
}
def _format_assets(
assets: list[dict[str, Any]], cmd: str, query: str,
locale: str, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Format asset results as text or media payload."""
if not assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
if response_mode == "media":
media_items = []
for asset in assets:
asset_id = asset.get("id", "")
filename = asset.get("originalFileName", "")
year = asset.get("year", "")
caption = f"{filename} ({year})" if year else filename
media_items.append({
"type": "photo",
"asset_id": asset_id,
"caption": caption,
"thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview",
"api_key": client.api_key,
})
return media_items
slot_map = {"find": "search", "person": "search", "place": "search"}
slot_name = slot_map.get(cmd, cmd)
return _render_cmd_template(cmd_templates, slot_name, locale, {
"assets": assets, "query": query, "command": cmd, "count": len(assets),
})
@@ -0,0 +1,199 @@
"""Event-related Immich bot commands: events, latest, memory, random."""
from __future__ import annotations
import asyncio
import logging
import random as rng
from datetime import datetime, timezone
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
EventLog, NotificationTarget, NotificationTrackerTarget,
ServiceProvider, TelegramBot, TrackingConfig,
)
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .common import _format_assets
_LOGGER = logging.getLogger(__name__)
async def _cmd_events(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
count: int, locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return {"events": []}
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc())
.limit(count)
)
events = result.all()
events_data = [
{"type": e.event_type, "album": e.collection_name,
"count": e.assets_count, "date": e.created_at.strftime("%m/%d %H:%M")}
for e in events
]
return {"events": events_data}
async def cmd_latest(
client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /latest command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "latest", "", locale, response_mode, client, cmd_templates)
results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
latest_assets: list[dict[str, Any]] = []
for album_id, result in zip(album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
for aid, asset in list(result.assets.items())[:count]:
latest_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type, "createdAt": asset.created_at,
})
latest_assets.sort(key=lambda a: a.get("createdAt", ""), reverse=True)
return _format_assets(latest_assets[:count], "latest", "", locale, response_mode, client, cmd_templates)
async def cmd_random(
client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /random command with concurrent album fetching."""
album_ids = all_album_ids[:10]
if not album_ids:
return _format_assets([], "random", "", locale, response_mode, client, cmd_templates)
results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
random_assets: list[dict[str, Any]] = []
for album_id, result in zip(album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
asset_list = list(result.assets.values())
sampled = rng.sample(asset_list, min(count, len(asset_list)))
for asset in sampled:
random_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type,
})
rng.shuffle(random_assets)
return _format_assets(random_assets[:count], "random", "", locale, response_mode, client, cmd_templates)
async def _check_native_memory(bot: TelegramBot) -> bool:
"""Check if any tracker-target linked to this bot uses native memory source."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(NotificationTarget).where(
NotificationTarget.type == "telegram",
NotificationTarget.user_id == bot.user_id,
)
)
targets = result.all()
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
if not bot_target_ids:
return False
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.target_id.in_(bot_target_ids)
)
)
for tt in tt_result.all():
if tt.tracking_config_id:
tc = await session.get(TrackingConfig, tt.tracking_config_id)
if tc and tc.memory_source == "native":
return True
return False
async def cmd_memory(
bot: TelegramBot, client: Any, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /memory command with concurrent album fetching."""
use_native = await _check_native_memory(bot)
today = datetime.now(timezone.utc)
memory_assets: list[dict[str, Any]] = []
if use_native:
memories = await client.get_memories()
tracked_ids = set(all_album_ids) if all_album_ids else None
for mem in memories:
year = mem.get("data", {}).get("year")
for raw_asset in mem.get("assets", []):
if tracked_ids:
asset_albums = raw_asset.get("albums", [])
if not any(a.get("id") in tracked_ids for a in asset_albums):
continue
memory_assets.append({
"id": raw_asset.get("id", ""),
"originalFileName": raw_asset.get("originalFileName", ""),
"type": raw_asset.get("type", "IMAGE"),
"createdAt": raw_asset.get("fileCreatedAt", raw_asset.get("createdAt", "")),
"year": year,
})
else:
album_ids = all_album_ids[:10]
if album_ids:
results = await asyncio.gather(
*[client.get_album(aid) for aid in album_ids],
return_exceptions=True,
)
month_day = (today.month, today.day)
for album_id, result in zip(album_ids, results):
if isinstance(result, Exception):
_LOGGER.warning("Failed to fetch album %s: %s", album_id, result)
continue
if result:
for aid, asset in result.assets.items():
try:
dt = datetime.fromisoformat(asset.created_at.replace("Z", "+00:00"))
if (dt.month, dt.day) == month_day and dt.year != today.year:
memory_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type, "createdAt": asset.created_at,
"year": dt.year,
})
except (ValueError, AttributeError):
pass
memory_assets = memory_assets[:count]
if not memory_assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "memory", "query": ""})
return _format_assets(memory_assets, "memory", "", locale, response_mode, client, cmd_templates)
@@ -0,0 +1,168 @@
"""Immich-specific bot command handler — main dispatch class."""
from __future__ import annotations
import logging
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ...database.engine import get_engine
from ...database.models import (
CommandConfig, CommandTracker, EventLog,
ServiceProvider, TelegramBot,
)
from ...services import make_immich_provider
from ..base import ProviderCommandHandler
from ..handler import _get_notification_trackers_for_providers, _render_cmd_template
from .albums import _cmd_albums, cmd_favorites, cmd_summary
from .common import _IMMICH_COMMANDS
from .events import _cmd_events, cmd_latest, cmd_memory, cmd_random
from .search import cmd_find, cmd_person, cmd_place, cmd_search
_LOGGER = logging.getLogger(__name__)
async def _cmd_status(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
active = sum(1 for t in trackers if t.enabled)
total = len(trackers)
total_albums = sum(len(t.collection_ids or []) for t in trackers)
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
return {
"trackers_active": active, "trackers_total": total,
"total_albums": total_albums, "last_event": last_str,
}
async def _cmd_people(
providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
all_people: dict[str, str] = {}
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
names = sorted(all_people.values())
return {"people": names}
class ImmichCommandHandler(ProviderCommandHandler):
"""Handles all Immich-specific bot commands."""
provider_type = "immich"
def get_provider_commands(self) -> set[str]:
return _IMMICH_COMMANDS
def get_rate_categories(self) -> dict[str, str]:
return {
"search": "search", "find": "search", "person": "search",
"place": "search", "favorites": "search", "people": "search",
}
async def handle(
self,
cmd: str,
args: str,
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "albums":
ctx = await _cmd_albums(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
if cmd == "events":
ctx = await _cmd_events(bot, providers_map, count, locale)
return _render_cmd_template(cmd_templates, "events", locale, ctx)
if cmd == "people":
ctx = await _cmd_people(providers_map, locale)
return _render_cmd_template(cmd_templates, "people", locale, ctx)
if cmd in ("search", "find", "person", "place", "latest",
"random", "favorites", "summary", "memory"):
return await _cmd_immich(
bot, cmd, args, count, locale, response_mode,
providers_map, cmd_templates,
)
return None
async def _cmd_immich(
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
response_mode: str, providers_map: dict[int, ServiceProvider],
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle commands that need Immich API access and may return media."""
if not providers_map:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
provider_ids = set(providers_map.keys())
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
all_album_ids: list[str] = []
for t in notification_trackers:
all_album_ids.extend(t.collection_ids or [])
provider: ServiceProvider | None = None
for p in providers_map.values():
if p.type == "immich":
provider = p
break
if not provider:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
async with aiohttp.ClientSession() as http:
immich = make_immich_provider(http, provider)
client = immich.client
if cmd == "search":
return await cmd_search(client, args, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "find":
return await cmd_find(client, args, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "person":
return await cmd_person(client, args, count, locale, response_mode, cmd_templates)
if cmd == "place":
return await cmd_place(client, args, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "favorites":
return await cmd_favorites(bot, providers_map, all_album_ids, count, locale, response_mode, client, cmd_templates)
if cmd == "latest":
return await cmd_latest(client, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "random":
return await cmd_random(client, all_album_ids, count, locale, response_mode, cmd_templates)
if cmd == "summary":
return await cmd_summary(client, all_album_ids, locale, cmd_templates)
if cmd == "memory":
return await cmd_memory(bot, client, all_album_ids, count, locale, response_mode, cmd_templates)
return None
@@ -0,0 +1,66 @@
"""Search-related Immich bot commands: search, find, person, place."""
from __future__ import annotations
from typing import Any
from ..handler import _render_cmd_template
from .common import _format_assets
async def cmd_search(
client: Any, args: str, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /search command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "search", "query": ""})
assets = await client.search_smart(args, album_ids=all_album_ids, limit=count)
return _format_assets(assets, "search", args, locale, response_mode, client, cmd_templates)
async def cmd_find(
client: Any, args: str, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /find command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "find", "query": ""})
assets = await client.search_metadata(args, album_ids=all_album_ids, limit=count)
return _format_assets(assets, "find", args, locale, response_mode, client, cmd_templates)
async def cmd_person(
client: Any, args: str, count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /person command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
people = await client.get_people()
person_id = None
for pid, pname in people.items():
if args.lower() in pname.lower():
person_id = pid
break
if not person_id:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": args})
assets = await client.search_by_person(person_id, limit=count)
return _format_assets(assets, "person", args, locale, response_mode, client, cmd_templates)
async def cmd_place(
client: Any, args: str, all_album_ids: list[str], count: int,
locale: str, response_mode: str,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle /place command."""
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
assets = await client.search_smart(
f"photos taken in {args}", album_ids=all_album_ids, limit=count
)
return _format_assets(assets, "place", args, locale, response_mode, client, cmd_templates)
@@ -1,414 +0,0 @@
"""Immich-specific bot command handler."""
from __future__ import annotations
import logging
import random as rng
from datetime import datetime, timezone
from typing import Any
import aiohttp
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.engine import get_engine
from ..database.models import (
CommandConfig, CommandTracker, NotificationTarget,
NotificationTracker, NotificationTrackerTarget,
ServiceProvider, TelegramBot, TrackingConfig,
)
from ..services import make_immich_provider
from .base import ProviderCommandHandler
from .handler import _render_cmd_template, _get_notification_trackers_for_providers
_LOGGER = logging.getLogger(__name__)
_IMMICH_COMMANDS = {
"status", "albums", "events", "people",
"search", "find", "person", "place",
"latest", "random", "favorites", "summary", "memory",
}
class ImmichCommandHandler(ProviderCommandHandler):
"""Handles all Immich-specific bot commands."""
provider_type = "immich"
def get_provider_commands(self) -> set[str]:
return _IMMICH_COMMANDS
def get_rate_categories(self) -> dict[str, str]:
return {
"search": "search", "find": "search", "person": "search",
"place": "search", "favorites": "search", "people": "search",
}
async def handle(
self,
cmd: str,
args: str,
count: int,
locale: str,
response_mode: str,
providers_map: dict[int, ServiceProvider],
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot,
ctx_tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
) -> str | list[dict[str, Any]] | None:
if cmd == "status":
ctx = await _cmd_status(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "status", locale, ctx)
if cmd == "albums":
ctx = await _cmd_albums(bot, providers_map, locale)
return _render_cmd_template(cmd_templates, "albums", locale, ctx)
if cmd == "events":
ctx = await _cmd_events(bot, providers_map, count, locale)
return _render_cmd_template(cmd_templates, "events", locale, ctx)
if cmd == "people":
ctx = await _cmd_people(providers_map, locale)
return _render_cmd_template(cmd_templates, "people", locale, ctx)
if cmd in ("search", "find", "person", "place", "latest",
"random", "favorites", "summary", "memory"):
return await _cmd_immich(
bot, cmd, args, count, locale, response_mode,
providers_map, cmd_templates,
)
return None
# --- Immich command implementations (moved from handler.py) ---
async def _cmd_status(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
from ..database.models import EventLog
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
active = sum(1 for t in trackers if t.enabled)
total = len(trackers)
total_albums = sum(len(t.collection_ids or []) for t in trackers)
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog).order_by(EventLog.created_at.desc()).limit(1)
)
last_event = result.first()
last_str = last_event.created_at.strftime("%Y-%m-%d %H:%M") if last_event else "-"
return {
"trackers_active": active, "trackers_total": total,
"total_albums": total_albums, "last_event": last_str,
}
async def _cmd_albums(
bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
if not trackers:
return {"albums": []}
albums_data: list[dict] = []
async with aiohttp.ClientSession() as http:
for tracker in trackers:
provider = providers_map.get(tracker.provider_id)
if not provider or provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
for album_id in (tracker.collection_ids or []):
try:
album = await immich.client.get_album(album_id)
if album:
albums_data.append({
"name": album.name, "asset_count": album.asset_count, "id": album_id,
})
except Exception:
albums_data.append({
"name": f"{album_id[:8]}...", "asset_count": "?", "id": album_id,
})
return {"albums": albums_data}
async def _cmd_events(
bot: TelegramBot, providers_map: dict[int, ServiceProvider],
count: int, locale: str,
) -> dict[str, Any]:
from ..database.models import EventLog
provider_ids = set(providers_map.keys())
trackers = await _get_notification_trackers_for_providers(provider_ids)
tracker_ids = [t.id for t in trackers]
if not tracker_ids:
return {"events": []}
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(EventLog)
.where(EventLog.tracker_id.in_(tracker_ids))
.order_by(EventLog.created_at.desc())
.limit(count)
)
events = result.all()
events_data = [
{"type": e.event_type, "album": e.collection_name,
"count": e.assets_count, "date": e.created_at.strftime("%m/%d %H:%M")}
for e in events
]
return {"events": events_data}
async def _cmd_people(
providers_map: dict[int, ServiceProvider], locale: str,
) -> dict[str, Any]:
all_people: dict[str, str] = {}
async with aiohttp.ClientSession() as http:
for provider in providers_map.values():
if provider.type != "immich":
continue
immich = make_immich_provider(http, provider)
people = await immich.client.get_people()
all_people.update(people)
names = sorted(all_people.values())
return {"people": names}
async def _check_native_memory(bot: TelegramBot) -> bool:
"""Check if any tracker-target linked to this bot uses native memory source."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(NotificationTarget).where(
NotificationTarget.type == "telegram",
NotificationTarget.user_id == bot.user_id,
)
)
targets = result.all()
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
if not bot_target_ids:
return False
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.target_id.in_(bot_target_ids)
)
)
for tt in tt_result.all():
if tt.tracking_config_id:
tc = await session.get(TrackingConfig, tt.tracking_config_id)
if tc and tc.memory_source == "native":
return True
return False
async def _cmd_immich(
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
response_mode: str, providers_map: dict[int, ServiceProvider],
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Handle commands that need Immich API access and may return media."""
if not providers_map:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
provider_ids = set(providers_map.keys())
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
all_album_ids: list[str] = []
for t in notification_trackers:
all_album_ids.extend(t.collection_ids or [])
provider: ServiceProvider | None = None
for p in providers_map.values():
if p.type == "immich":
provider = p
break
if not provider:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": args})
async with aiohttp.ClientSession() as http:
immich = make_immich_provider(http, provider)
client = immich.client
if cmd == "search":
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
assets = await client.search_smart(args, album_ids=all_album_ids, limit=count)
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
if cmd == "find":
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": ""})
assets = await client.search_metadata(args, album_ids=all_album_ids, limit=count)
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
if cmd == "person":
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": ""})
people = await client.get_people()
person_id = None
for pid, pname in people.items():
if args.lower() in pname.lower():
person_id = pid
break
if not person_id:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "person", "query": args})
assets = await client.search_by_person(person_id, limit=count)
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
if cmd == "place":
if not args:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "place", "query": ""})
assets = await client.search_smart(
f"photos taken in {args}", album_ids=all_album_ids, limit=count
)
return _format_assets(assets, cmd, args, locale, response_mode, client, cmd_templates)
if cmd == "favorites":
fav_assets: list[dict[str, Any]] = []
for album_id in all_album_ids[:10]:
try:
album = await client.get_album(album_id)
if album:
for aid, asset in list(album.assets.items())[:50]:
if asset.is_favorite and len(fav_assets) < count:
fav_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type,
})
except Exception:
pass
if len(fav_assets) >= count:
break
return _format_assets(fav_assets, cmd, "", locale, response_mode, client, cmd_templates)
if cmd == "latest":
latest_assets: list[dict[str, Any]] = []
for album_id in all_album_ids[:10]:
try:
album = await client.get_album(album_id)
if album:
for aid, asset in list(album.assets.items())[:count]:
latest_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type, "createdAt": asset.created_at,
})
except Exception:
pass
latest_assets.sort(key=lambda a: a.get("createdAt", ""), reverse=True)
return _format_assets(latest_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
if cmd == "random":
random_assets: list[dict[str, Any]] = []
for album_id in all_album_ids[:10]:
try:
album = await client.get_album(album_id)
if album:
asset_list = list(album.assets.values())
sampled = rng.sample(asset_list, min(count, len(asset_list)))
for asset in sampled:
random_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type,
})
except Exception:
pass
rng.shuffle(random_assets)
return _format_assets(random_assets[:count], cmd, "", locale, response_mode, client, cmd_templates)
if cmd == "summary":
albums_data: list[dict] = []
for album_id in all_album_ids:
try:
album = await client.get_album(album_id)
if album:
albums_data.append({
"name": album.name, "asset_count": album.asset_count, "id": album_id,
})
except Exception:
pass
return _render_cmd_template(cmd_templates, "summary", locale, {"albums": albums_data})
if cmd == "memory":
use_native = await _check_native_memory(bot)
today = datetime.now(timezone.utc)
memory_assets: list[dict[str, Any]] = []
if use_native:
memories = await client.get_memories()
tracked_ids = set(all_album_ids) if all_album_ids else None
for mem in memories:
year = mem.get("data", {}).get("year")
for raw_asset in mem.get("assets", []):
if tracked_ids:
asset_albums = raw_asset.get("albums", [])
if not any(a.get("id") in tracked_ids for a in asset_albums):
continue
memory_assets.append({
"id": raw_asset.get("id", ""),
"originalFileName": raw_asset.get("originalFileName", ""),
"type": raw_asset.get("type", "IMAGE"),
"createdAt": raw_asset.get("fileCreatedAt", raw_asset.get("createdAt", "")),
"year": year,
})
else:
month_day = (today.month, today.day)
for album_id in all_album_ids[:10]:
try:
album = await client.get_album(album_id)
if album:
for aid, asset in album.assets.items():
try:
dt = datetime.fromisoformat(asset.created_at.replace("Z", "+00:00"))
if (dt.month, dt.day) == month_day and dt.year != today.year:
memory_assets.append({
"id": asset.id, "originalFileName": asset.filename,
"type": asset.type, "createdAt": asset.created_at,
"year": dt.year,
})
except (ValueError, AttributeError):
pass
except Exception:
pass
memory_assets = memory_assets[:count]
if not memory_assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": "memory", "query": ""})
return _format_assets(memory_assets, cmd, "", locale, response_mode, client, cmd_templates)
return None
def _format_assets(
assets: list[dict[str, Any]], cmd: str, query: str,
locale: str, response_mode: str, client: Any,
cmd_templates: dict[str, dict[str, str]],
) -> str | list[dict[str, Any]]:
"""Format asset results as text or media payload."""
if not assets:
return _render_cmd_template(cmd_templates, "no_results", locale, {"command": cmd, "query": query})
if response_mode == "media":
media_items = []
for asset in assets:
asset_id = asset.get("id", "")
filename = asset.get("originalFileName", "")
year = asset.get("year", "")
caption = f"{filename} ({year})" if year else filename
media_items.append({
"type": "photo",
"asset_id": asset_id,
"caption": caption,
"thumbnail_url": f"{client.url}/api/assets/{asset_id}/thumbnail?size=preview",
"api_key": client.api_key,
})
return media_items
slot_map = {"find": "search", "person": "search", "place": "search"}
slot_name = slot_map.get(cmd, cmd)
return _render_cmd_template(cmd_templates, slot_name, locale, {
"assets": assets, "query": query, "command": cmd, "count": len(assets),
})
@@ -12,7 +12,7 @@ def parse_command(text: str) -> tuple[str, str, int | None]:
"/events 10" -> ("events", "", 10)
"/help@mybot" -> ("help", "", None)
"""
text = text.strip()
text = text[:512].strip()
if not text.startswith("/"):
return ("", text, None)
@@ -33,6 +33,9 @@ class Settings(BaseSettings):
telegram_webhook_secret: str = ""
cors_allowed_origins: str = "*"
"""Comma-separated allowed origins for CORS (e.g. 'http://localhost:5173,https://myapp.com'). Use '*' for dev."""
model_config = {"env_prefix": "NOTIFY_BRIDGE_"}
@property
@@ -207,7 +207,12 @@ class TemplateSlot(SQLModel, table=True):
)
id: int | None = Field(default=None, primary_key=True)
config_id: int = Field(foreign_key="template_config.id", index=True)
config_id: int = Field(
foreign_key="template_config.id",
index=True,
)
slot_name: str
template: str = Field(default="", sa_column=Column(Text, default=""))
@@ -245,7 +250,12 @@ class TargetReceiver(SQLModel, table=True):
)
id: int | None = Field(default=None, primary_key=True)
target_id: int = Field(foreign_key="notification_target.id", index=True)
target_id: int = Field(
foreign_key="notification_target.id",
index=True,
)
name: str = Field(default="")
config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
receiver_key: str = Field(default="") # dedup key (e.g. chat_id, url, email)
@@ -283,7 +293,12 @@ class NotificationTrackerTarget(SQLModel, table=True):
index=True,
sa_column_kwargs={"name": "notification_tracker_id"},
)
target_id: int = Field(foreign_key="notification_target.id", index=True)
target_id: int = Field(
foreign_key="notification_target.id",
index=True,
)
tracking_config_id: int | None = Field(
default=None, foreign_key="tracking_config.id"
)
@@ -366,7 +381,12 @@ class CommandTemplateSlot(SQLModel, table=True):
)
id: int | None = Field(default=None, primary_key=True)
config_id: int = Field(foreign_key="command_template_config.id", index=True)
config_id: int = Field(
foreign_key="command_template_config.id",
index=True,
)
slot_name: str
locale: str = Field(default="en")
template: str = Field(default="", sa_column=Column(Text, default=""))
@@ -399,7 +419,11 @@ class CommandTrackerListener(SQLModel, table=True):
)
id: int | None = Field(default=None, primary_key=True)
command_tracker_id: int = Field(foreign_key="command_tracker.id")
command_tracker_id: int = Field(
foreign_key="command_tracker.id",
)
listener_type: str # e.g. "telegram_bot"
listener_id: int
created_at: datetime = Field(default_factory=_utcnow)
@@ -0,0 +1,324 @@
"""Database seed functions — create/update system-owned defaults on startup."""
import json
import logging
from datetime import datetime, timezone
from sqlalchemy import text
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from .engine import get_engine
from .models import (
CommandConfig,
CommandTemplateConfig,
CommandTemplateSlot,
TemplateConfig,
TemplateSlot,
TrackingConfig,
)
_LOGGER = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _seed_provider_template(
session: AsyncSession,
provider_type: str,
label: str,
) -> None:
"""Seed templates for a single provider type across all locales."""
from notify_bridge_core.templates.defaults import load_default_templates
result = await session.exec(
select(TemplateConfig).where(
TemplateConfig.user_id == 0,
TemplateConfig.provider_type == provider_type,
)
)
configs = result.all()
existing_locales = {
(c.locale if c.locale else ("ru" if "(RU)" in c.name else "en")): c
for c in configs
}
for locale in ("en", "ru"):
slots = load_default_templates(locale, provider_type=provider_type)
if not slots:
continue
if locale not in existing_locales:
now = datetime.now(timezone.utc).isoformat()
name = f"Default {label} ({locale.upper()})"
desc = f"Default {label} templates ({locale.upper()})"
# Get column names to build INSERT with defaults for legacy cols
col_info = (await session.execute(
text("PRAGMA table_info(template_config)")
)).fetchall()
col_names = [c[1] for c in col_info if c[1] != "id"]
values: dict[str, object] = {}
for col in col_names:
if col == "user_id":
values[col] = 0
elif col == "provider_type":
values[col] = provider_type
elif col == "name":
values[col] = name
elif col == "description":
values[col] = desc
elif col == "created_at":
values[col] = now
elif col == "date_format":
values[col] = "%d.%m.%Y, %H:%M UTC"
elif col == "date_only_format":
values[col] = "%d.%m.%Y"
elif col == "locale":
values[col] = locale
else:
values[col] = "" # empty string for legacy columns
cols_str = ", ".join(values.keys())
placeholders = ", ".join(f":{k}" for k in values.keys())
await session.execute(
text(f"INSERT INTO template_config ({cols_str}) VALUES ({placeholders})"),
values,
)
config_id = (await session.execute(
text("SELECT last_insert_rowid()")
)).scalar()
for slot_name, template_text in slots.items():
session.add(TemplateSlot(
config_id=config_id,
slot_name=slot_name,
template=template_text,
))
else:
config = existing_locales[locale]
for slot_name, template_text in slots.items():
slot_result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config.id,
TemplateSlot.slot_name == slot_name,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config.id,
slot_name=slot_name,
template=template_text,
))
async def _seed_provider_command_template(
session: AsyncSession,
provider_type: str,
name: str,
description: str,
) -> None:
"""Seed command templates for a single provider type across all locales."""
from notify_bridge_core.templates.command_defaults import load_default_command_templates
result = await session.exec(
select(CommandTemplateConfig).where(
CommandTemplateConfig.user_id == 0,
CommandTemplateConfig.provider_type == provider_type,
)
)
configs = result.all()
if not configs:
config = CommandTemplateConfig(
user_id=0,
provider_type=provider_type,
name=name,
description=description,
)
session.add(config)
await session.flush()
else:
config = configs[0]
for locale in ("en", "ru"):
slots = load_default_command_templates(locale, provider_type=provider_type)
if not slots:
continue
for slot_name, template_text in slots.items():
slot_result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config.id,
CommandTemplateSlot.slot_name == slot_name,
CommandTemplateSlot.locale == locale,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=config.id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
# ---------------------------------------------------------------------------
# Top-level seed functions
# ---------------------------------------------------------------------------
async def _seed_default_templates() -> None:
"""Seed or update default (system-owned) templates on startup.
Uses TemplateSlot child rows for template content.
"""
engine = get_engine()
async with AsyncSession(engine) as session:
await _seed_provider_template(session, "immich", "Immich")
await _seed_provider_template(session, "gitea", "Gitea")
await _seed_provider_template(session, "scheduler", "Scheduler")
await session.commit()
async def _seed_default_command_templates() -> None:
"""Seed or update default command response templates on startup.
Creates a single config per provider with locale-aware slots
(each slot has an EN and RU version stored as separate rows).
"""
engine = get_engine()
async with AsyncSession(engine) as session:
await _seed_provider_command_template(
session, "immich", "Default Commands", "Default Immich command templates",
)
await _seed_provider_command_template(
session, "gitea", "Default Gitea Commands", "Default Gitea command templates",
)
await session.commit()
async def _seed_default_tracking_configs() -> None:
"""Seed system-owned default tracking configs for each provider type."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(TrackingConfig).where(TrackingConfig.user_id == 0)
)
existing = {c.provider_type: c for c in result.all()}
defaults = [
{
"provider_type": "gitea",
"name": "Default Gitea",
"track_push": True,
"track_issue_opened": True,
"track_issue_closed": True,
"track_issue_commented": False,
"track_pr_opened": True,
"track_pr_closed": True,
"track_pr_merged": True,
"track_pr_commented": False,
"track_release_published": True,
},
{
"provider_type": "scheduler",
"name": "Default Scheduler",
"track_scheduled_message": True,
},
]
for cfg in defaults:
ptype = cfg["provider_type"]
if ptype in existing:
continue
session.add(TrackingConfig(user_id=0, **cfg))
await session.commit()
async def _seed_default_command_configs() -> None:
"""Seed system-owned default command configs for each provider type."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(CommandConfig).where(CommandConfig.user_id == 0)
)
existing = {c.provider_type: c for c in result.all()}
# Find system command template configs to link
tmpl_result = await session.exec(
select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == 0)
)
tmpl_by_type = {t.provider_type: t.id for t in tmpl_result.all()}
defaults = [
{
"provider_type": "immich",
"name": "Default Immich",
"enabled_commands": [
"help", "status", "albums", "events", "latest",
"random", "favorites", "summary", "memory",
],
"response_mode": "media",
"default_count": 5,
"rate_limits": {"search": 30, "default": 10},
},
{
"provider_type": "gitea",
"name": "Default Gitea",
"enabled_commands": [
"help", "status", "repos", "issues", "prs", "commits",
],
"response_mode": "text",
"default_count": 10,
"rate_limits": {"api": 15, "default": 10},
},
]
for cfg in defaults:
ptype = cfg["provider_type"]
if ptype in existing:
continue
cmd_tmpl_id = tmpl_by_type.get(ptype)
await session.execute(
text(
"INSERT INTO command_config "
"(user_id, provider_type, name, icon, enabled_commands, locale, "
"response_mode, default_count, rate_limits, command_template_config_id, created_at) "
"VALUES (:uid, :pt, :name, :icon, :cmds, :locale, :rm, :dc, :rl, :ctid, :ca)"
),
{
"uid": 0,
"pt": ptype,
"name": cfg["name"],
"icon": "",
"cmds": json.dumps(cfg["enabled_commands"]),
"locale": "en",
"rm": cfg["response_mode"],
"dc": cfg["default_count"],
"rl": json.dumps(cfg["rate_limits"]),
"ctid": cmd_tmpl_id,
"ca": datetime.now(timezone.utc).isoformat(),
},
)
await session.commit()
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
async def seed_all() -> None:
"""Run all seed functions in order."""
await _seed_default_templates()
await _seed_default_command_templates()
await _seed_default_tracking_configs()
await _seed_default_command_configs()
+23 -494
View File
@@ -4,6 +4,10 @@ import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
# Ensure app-level loggers are visible
logging.basicConfig(level=logging.INFO)
@@ -50,10 +54,8 @@ async def lifespan(app: FastAPI):
await migrate_template_locale(engine)
await migrate_receivers_from_config(engine)
await migrate_command_slot_locale(engine)
await _seed_default_templates()
await _seed_default_command_templates()
await _seed_default_tracking_configs()
await _seed_default_command_configs()
from .database.seeds import seed_all
await seed_all()
# Configure webhook secret from DB setting (falls back to env var)
from sqlmodel.ext.asyncio.session import AsyncSession as _AS
from .api.app_settings import get_setting as _get_setting
@@ -71,6 +73,23 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
# --- Rate limiting ---
from .auth.routes import limiter
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
# --- CORS ---
from .config import settings as _cfg
_origins = [o.strip() for o in _cfg.cors_allowed_origins.split(",") if o.strip()]
app.add_middleware(
CORSMiddleware,
allow_origins=_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Register routes — static paths before parameterized
app.include_router(auth_router)
app.include_router(template_vars_router)
@@ -99,496 +118,6 @@ async def health():
return {"status": "ok"}
async def _seed_default_templates():
"""Seed or update default (system-owned) templates on startup.
Uses TemplateSlot child rows for template content.
"""
from sqlalchemy import text
from sqlmodel import func, select
from sqlmodel.ext.asyncio.session import AsyncSession
from .database.engine import get_engine
from .database.models import TemplateConfig, TemplateSlot
from notify_bridge_core.templates.defaults import load_default_templates
engine = get_engine()
async with AsyncSession(engine) as session:
# Find existing system-owned templates
result = await session.exec(
select(TemplateConfig).where(TemplateConfig.user_id == 0)
)
system_configs = result.all()
existing_locales = {
(c.locale if c.locale else ("ru" if "(RU)" in c.name else "en")): c
for c in system_configs
}
for locale in ("en", "ru"):
slots = load_default_templates(locale, provider_type="immich")
if not slots:
continue
if locale not in existing_locales:
# Create missing system template via raw SQL
# (legacy NOT NULL columns may still exist in the DB)
name = f"Default ({locale.upper()})"
desc = f"Default Immich templates ({locale.upper()})"
# Get column names to build INSERT with defaults for legacy cols
col_info = (await session.execute(
text("PRAGMA table_info(template_config)")
)).fetchall()
col_names = [c[1] for c in col_info if c[1] != "id"]
values = {}
from datetime import datetime, timezone
now = datetime.now(timezone.utc).isoformat()
for col in col_names:
if col == "user_id":
values[col] = 0
elif col == "provider_type":
values[col] = "immich"
elif col == "name":
values[col] = name
elif col == "description":
values[col] = desc
elif col == "created_at":
values[col] = now
elif col == "date_format":
values[col] = "%d.%m.%Y, %H:%M UTC"
elif col == "date_only_format":
values[col] = "%d.%m.%Y"
elif col == "locale":
values[col] = locale
else:
values[col] = "" # empty string for legacy columns
cols_str = ", ".join(values.keys())
placeholders = ", ".join(f":{k}" for k in values.keys())
await session.execute(
text(f"INSERT INTO template_config ({cols_str}) VALUES ({placeholders})"),
values,
)
# Get the inserted ID
row = (await session.execute(text("SELECT last_insert_rowid()"))).scalar()
config_id = row
for slot_name, template_text in slots.items():
session.add(TemplateSlot(
config_id=config_id,
slot_name=slot_name,
template=template_text,
))
else:
# Update existing system template slots
config = existing_locales[locale]
for slot_name, template_text in slots.items():
slot_result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config.id,
TemplateSlot.slot_name == slot_name,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config.id,
slot_name=slot_name,
template=template_text,
))
# --- Seed Gitea default templates ---
gitea_result = await session.exec(
select(TemplateConfig).where(
TemplateConfig.user_id == 0,
TemplateConfig.provider_type == "gitea",
)
)
gitea_configs = gitea_result.all()
gitea_existing_locales = {
(c.locale if c.locale else "en"): c for c in gitea_configs
}
for locale in ("en", "ru"):
gitea_slots = load_default_templates(locale, provider_type="gitea")
if not gitea_slots:
continue
if locale not in gitea_existing_locales:
from datetime import datetime as _dt, timezone as _tz
now = _dt.now(_tz.utc).isoformat()
name = f"Default Gitea ({locale.upper()})"
desc = f"Default Gitea templates ({locale.upper()})"
col_info = (await session.execute(
text("PRAGMA table_info(template_config)")
)).fetchall()
col_names = [c[1] for c in col_info if c[1] != "id"]
values = {}
for col in col_names:
if col == "user_id":
values[col] = 0
elif col == "provider_type":
values[col] = "gitea"
elif col == "name":
values[col] = name
elif col == "description":
values[col] = desc
elif col == "created_at":
values[col] = now
elif col == "date_format":
values[col] = "%d.%m.%Y, %H:%M UTC"
elif col == "date_only_format":
values[col] = "%d.%m.%Y"
elif col == "locale":
values[col] = locale
else:
values[col] = ""
cols_str = ", ".join(values.keys())
placeholders = ", ".join(f":{k}" for k in values.keys())
await session.execute(
text(f"INSERT INTO template_config ({cols_str}) VALUES ({placeholders})"),
values,
)
row = (await session.execute(text("SELECT last_insert_rowid()"))).scalar()
gitea_config_id = row
for slot_name, template_text in gitea_slots.items():
session.add(TemplateSlot(
config_id=gitea_config_id,
slot_name=slot_name,
template=template_text,
))
else:
config = gitea_existing_locales[locale]
for slot_name, template_text in gitea_slots.items():
slot_result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config.id,
TemplateSlot.slot_name == slot_name,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config.id,
slot_name=slot_name,
template=template_text,
))
# --- Seed Scheduler default templates ---
sched_result = await session.exec(
select(TemplateConfig).where(
TemplateConfig.user_id == 0,
TemplateConfig.provider_type == "scheduler",
)
)
sched_configs = sched_result.all()
sched_existing_locales = {
(c.locale if c.locale else "en"): c for c in sched_configs
}
for locale in ("en", "ru"):
sched_slots = load_default_templates(locale, provider_type="scheduler")
if not sched_slots:
continue
if locale not in sched_existing_locales:
from datetime import datetime as _dt2, timezone as _tz2
now2 = _dt2.now(_tz2.utc).isoformat()
name2 = f"Default Scheduler ({locale.upper()})"
desc2 = f"Default Scheduler templates ({locale.upper()})"
col_info2 = (await session.execute(
text("PRAGMA table_info(template_config)")
)).fetchall()
col_names2 = [c[1] for c in col_info2 if c[1] != "id"]
values2 = {}
for col in col_names2:
if col == "user_id":
values2[col] = 0
elif col == "provider_type":
values2[col] = "scheduler"
elif col == "name":
values2[col] = name2
elif col == "description":
values2[col] = desc2
elif col == "created_at":
values2[col] = now2
elif col == "date_format":
values2[col] = "%d.%m.%Y, %H:%M UTC"
elif col == "date_only_format":
values2[col] = "%d.%m.%Y"
elif col == "locale":
values2[col] = locale
else:
values2[col] = ""
cols_str2 = ", ".join(values2.keys())
placeholders2 = ", ".join(f":{k}" for k in values2.keys())
await session.execute(
text(f"INSERT INTO template_config ({cols_str2}) VALUES ({placeholders2})"),
values2,
)
row2 = (await session.execute(text("SELECT last_insert_rowid()"))).scalar()
for slot_name, template_text in sched_slots.items():
session.add(TemplateSlot(
config_id=row2,
slot_name=slot_name,
template=template_text,
))
else:
config = sched_existing_locales[locale]
for slot_name, template_text in sched_slots.items():
slot_result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config.id,
TemplateSlot.slot_name == slot_name,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config.id,
slot_name=slot_name,
template=template_text,
))
await session.commit()
async def _seed_default_command_templates():
"""Seed or update default command response templates on startup.
Creates a single 'Default Commands' config with locale-aware slots
(each slot has an EN and RU version stored as separate rows).
"""
from sqlmodel import func, select
from sqlmodel.ext.asyncio.session import AsyncSession
from .database.engine import get_engine
from .database.models import CommandTemplateConfig, CommandTemplateSlot
from notify_bridge_core.templates.command_defaults import load_default_command_templates
engine = get_engine()
async with AsyncSession(engine) as session:
# Find or create the system-owned config
result = await session.exec(
select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == 0)
)
system_configs = result.all()
if not system_configs:
# First startup — create single merged config
config = CommandTemplateConfig(
user_id=0,
provider_type="immich",
name="Default Commands",
description="Default Immich command templates",
)
session.add(config)
await session.flush()
else:
config = system_configs[0]
# Upsert slots for each locale
for locale in ("en", "ru"):
slots = load_default_command_templates(locale, provider_type="immich")
if not slots:
continue
for slot_name, template_text in slots.items():
slot_result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config.id,
CommandTemplateSlot.slot_name == slot_name,
CommandTemplateSlot.locale == locale,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=config.id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
# --- Seed Gitea default command templates ---
gitea_cmd_result = await session.exec(
select(CommandTemplateConfig).where(
CommandTemplateConfig.user_id == 0,
CommandTemplateConfig.provider_type == "gitea",
)
)
gitea_cmd_configs = gitea_cmd_result.all()
if not gitea_cmd_configs:
gitea_cmd_config = CommandTemplateConfig(
user_id=0,
provider_type="gitea",
name="Default Gitea Commands",
description="Default Gitea command templates",
)
session.add(gitea_cmd_config)
await session.flush()
else:
gitea_cmd_config = gitea_cmd_configs[0]
for locale in ("en", "ru"):
gitea_cmd_slots = load_default_command_templates(locale, provider_type="gitea")
if not gitea_cmd_slots:
continue
for slot_name, template_text in gitea_cmd_slots.items():
slot_result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == gitea_cmd_config.id,
CommandTemplateSlot.slot_name == slot_name,
CommandTemplateSlot.locale == locale,
)
)
existing = slot_result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=gitea_cmd_config.id,
slot_name=slot_name,
locale=locale,
template=template_text,
))
await session.commit()
async def _seed_default_tracking_configs():
"""Seed system-owned default tracking configs for each provider type."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from .database.engine import get_engine
from .database.models import TrackingConfig
engine = get_engine()
async with AsyncSession(engine) as session:
# Find existing system-owned tracking configs
result = await session.exec(
select(TrackingConfig).where(TrackingConfig.user_id == 0)
)
existing = {c.provider_type: c for c in result.all()}
defaults = [
{
"provider_type": "gitea",
"name": "Default Gitea",
"track_push": True,
"track_issue_opened": True,
"track_issue_closed": True,
"track_issue_commented": False,
"track_pr_opened": True,
"track_pr_closed": True,
"track_pr_merged": True,
"track_pr_commented": False,
"track_release_published": True,
},
{
"provider_type": "scheduler",
"name": "Default Scheduler",
"track_scheduled_message": True,
},
]
for cfg in defaults:
ptype = cfg["provider_type"]
if ptype in existing:
continue
session.add(TrackingConfig(user_id=0, **cfg))
await session.commit()
async def _seed_default_command_configs():
"""Seed system-owned default command configs for each provider type."""
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from .database.engine import get_engine
from .database.models import CommandConfig, CommandTemplateConfig
engine = get_engine()
async with AsyncSession(engine) as session:
# Find existing system-owned command configs
result = await session.exec(
select(CommandConfig).where(CommandConfig.user_id == 0)
)
existing = {c.provider_type: c for c in result.all()}
# Find system command template configs to link
tmpl_result = await session.exec(
select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == 0)
)
tmpl_by_type = {t.provider_type: t.id for t in tmpl_result.all()}
defaults = [
{
"provider_type": "immich",
"name": "Default Immich",
"enabled_commands": [
"help", "status", "albums", "events", "latest",
"random", "favorites", "summary", "memory",
],
"response_mode": "media",
"default_count": 5,
"rate_limits": {"search": 30, "default": 10},
},
{
"provider_type": "gitea",
"name": "Default Gitea",
"enabled_commands": [
"help", "status", "repos", "issues", "prs", "commits",
],
"response_mode": "text",
"default_count": 10,
"rate_limits": {"api": 15, "default": 10},
},
]
for cfg in defaults:
ptype = cfg["provider_type"]
if ptype in existing:
continue
cmd_tmpl_id = tmpl_by_type.get(ptype)
# Use raw SQL to handle legacy NOT NULL columns
import json as _json2
from sqlalchemy import text as _text2
from datetime import datetime as _dt3, timezone as _tz3
await session.execute(
_text2(
"INSERT INTO command_config "
"(user_id, provider_type, name, icon, enabled_commands, locale, "
"response_mode, default_count, rate_limits, command_template_config_id, created_at) "
"VALUES (:uid, :pt, :name, :icon, :cmds, :locale, :rm, :dc, :rl, :ctid, :ca)"
),
{
"uid": 0,
"pt": ptype,
"name": cfg["name"],
"icon": "",
"cmds": _json2.dumps(cfg["enabled_commands"]),
"locale": "en",
"rm": cfg["response_mode"],
"dc": cfg["default_count"],
"rl": _json2.dumps(cfg["rate_limits"]),
"ctid": cmd_tmpl_id,
"ca": _dt3.now(_tz3.utc).isoformat(),
},
)
await session.commit()
def run():
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8420)
@@ -0,0 +1,165 @@
"""Shared dispatch helpers used by both watcher and webhook handlers."""
from __future__ import annotations
import logging
from datetime import datetime, time, timezone
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from ..database.models import (
EmailBot,
MatrixBot,
NotificationTarget,
NotificationTrackerTarget,
TargetReceiver,
TemplateConfig,
TemplateSlot,
TrackingConfig,
)
_LOGGER = logging.getLogger(__name__)
def in_quiet_hours(start: str | None, end: str | None) -> bool:
"""Check if the current UTC time is within the quiet hours window."""
if not start or not end:
return False
try:
now = datetime.now(timezone.utc).time()
t_start = time.fromisoformat(start)
t_end = time.fromisoformat(end)
if t_start <= t_end:
return t_start <= now <= t_end
else:
# Overnight window (e.g., 22:00 - 06:00)
return now >= t_start or now <= t_end
except (ValueError, TypeError):
return False
def event_allowed_by_config(event: ServiceEvent, tc: TrackingConfig) -> bool:
"""Check if an event type is allowed by the tracking config's flags."""
event_type = event.event_type.value
flag_map = {
# Immich events
"assets_added": tc.track_assets_added,
"assets_removed": tc.track_assets_removed,
"collection_renamed": tc.track_collection_renamed,
"collection_deleted": tc.track_collection_deleted,
"sharing_changed": tc.track_sharing_changed,
# Gitea events
"push": tc.track_push,
"issue_opened": tc.track_issue_opened,
"issue_closed": tc.track_issue_closed,
"issue_commented": tc.track_issue_commented,
"pr_opened": tc.track_pr_opened,
"pr_closed": tc.track_pr_closed,
"pr_merged": tc.track_pr_merged,
"pr_commented": tc.track_pr_commented,
"release_published": tc.track_release_published,
# Scheduler events
"scheduled_message": tc.track_scheduled_message,
}
return flag_map.get(event_type, True)
async def load_link_data(
session: AsyncSession,
tracker_id: int,
*,
check_quiet_hours: bool = False,
) -> list[dict[str, Any]]:
"""Load tracker-target link data for dispatch.
Args:
session: Active database session.
tracker_id: ID of the tracker whose links to load.
check_quiet_hours: If True, skip links currently in quiet hours.
"""
tt_result = await session.exec(
select(NotificationTrackerTarget).where(
NotificationTrackerTarget.tracker_id == tracker_id
)
)
tracker_targets = tt_result.all()
link_data: list[dict[str, Any]] = []
for tt in tracker_targets:
if not tt.enabled:
continue
if check_quiet_hours and in_quiet_hours(tt.quiet_hours_start, tt.quiet_hours_end):
continue
target = await session.get(NotificationTarget, tt.target_id)
if not target:
continue
# Load receivers
recv_result = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.enabled == True,
)
)
receivers = [dict(r.config) for r in recv_result.all()]
tracking_config = None
if tt.tracking_config_id:
tracking_config = await session.get(TrackingConfig, tt.tracking_config_id)
template_config = None
template_slots: dict[str, str] | None = None
if tt.template_config_id:
template_config = await session.get(TemplateConfig, tt.template_config_id)
if template_config:
slot_result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == template_config.id)
)
raw_slots = {s.slot_name: s.template for s in slot_result.all()}
template_slots = {}
for slot_name, tmpl_text in raw_slots.items():
event_key = slot_name.removeprefix("message_") if slot_name.startswith("message_") else slot_name
template_slots[event_key] = tmpl_text
target_config = dict(target.config)
# Inject chat_action for Telegram targets
if hasattr(target, 'chat_action') and target.chat_action:
target_config["chat_action"] = target.chat_action
# Inject bot credentials for bot-backed target types
if target.type == "email":
email_bot_id = target.config.get("email_bot_id")
if email_bot_id:
email_bot = await session.get(EmailBot, email_bot_id)
if email_bot:
target_config["smtp"] = {
"host": email_bot.smtp_host,
"port": email_bot.smtp_port,
"username": email_bot.smtp_username,
"password": email_bot.smtp_password,
"from_address": email_bot.email,
"from_name": email_bot.name,
"use_tls": email_bot.smtp_use_tls,
}
elif target.type == "matrix":
matrix_bot_id = target.config.get("matrix_bot_id")
if matrix_bot_id:
matrix_bot = await session.get(MatrixBot, matrix_bot_id)
if matrix_bot:
target_config["homeserver_url"] = matrix_bot.homeserver_url
target_config["access_token"] = matrix_bot.access_token
link_data.append({
"target_type": target.type,
"target_config": target_config,
"receivers": receivers,
"tracking_config": tracking_config,
"template_config": template_config,
"template_slots": template_slots,
})
return link_data
@@ -3,7 +3,6 @@
from __future__ import annotations
import logging
from datetime import datetime, time, timezone
from typing import Any
import aiohttp
@@ -17,19 +16,12 @@ from notify_bridge_core.storage import JsonFileBackend
from ..database.engine import get_engine
from ..database.models import (
EmailBot,
EventLog,
MatrixBot,
NotificationTarget,
NotificationTracker,
NotificationTrackerState,
NotificationTrackerTarget,
ServiceProvider,
TargetReceiver,
TemplateConfig,
TemplateSlot,
TrackingConfig,
)
from .dispatch_helpers import event_allowed_by_config, load_link_data
_LOGGER = logging.getLogger(__name__)
@@ -57,49 +49,6 @@ async def _get_telegram_caches() -> tuple[TelegramFileCache | None, TelegramFile
return _url_cache, _asset_cache
def _in_quiet_hours(start: str | None, end: str | None) -> bool:
"""Check if the current UTC time is within the quiet hours window."""
if not start or not end:
return False
try:
now = datetime.now(timezone.utc).time()
t_start = time.fromisoformat(start)
t_end = time.fromisoformat(end)
if t_start <= t_end:
return t_start <= now <= t_end
else:
# Overnight window (e.g., 22:00 - 06:00)
return now >= t_start or now <= t_end
except (ValueError, TypeError):
return False
def _event_allowed_by_config(event: ServiceEvent, tc: TrackingConfig) -> bool:
"""Check if an event type is allowed by the tracking config's flags."""
event_type = event.event_type.value
flag_map = {
# Immich events
"assets_added": tc.track_assets_added,
"assets_removed": tc.track_assets_removed,
"collection_renamed": tc.track_collection_renamed,
"collection_deleted": tc.track_collection_deleted,
"sharing_changed": tc.track_sharing_changed,
# Gitea events
"push": tc.track_push,
"issue_opened": tc.track_issue_opened,
"issue_closed": tc.track_issue_closed,
"issue_commented": tc.track_issue_commented,
"pr_opened": tc.track_pr_opened,
"pr_closed": tc.track_pr_closed,
"pr_merged": tc.track_pr_merged,
"pr_commented": tc.track_pr_commented,
"release_published": tc.track_release_published,
# Scheduler events
"scheduled_message": tc.track_scheduled_message,
}
return flag_map.get(event_type, True)
async def check_tracker(tracker_id: int) -> dict[str, Any]:
"""Poll a tracker's provider for changes and dispatch notifications."""
engine = get_engine()
@@ -128,88 +77,8 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
"shared": bool(s.shared),
}
# Load tracker-target links (replaces old target_ids JSON array)
tt_result = await session.exec(
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
)
tracker_targets = tt_result.all()
# For each link, load target + tracking config + template config
link_data: list[dict[str, Any]] = []
for tt in tracker_targets:
if not tt.enabled:
continue
if _in_quiet_hours(tt.quiet_hours_start, tt.quiet_hours_end):
continue
target = await session.get(NotificationTarget, tt.target_id)
if not target:
continue
# Load receivers for this target
recv_result = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.enabled == True,
)
)
receivers = [dict(r.config) for r in recv_result.all()]
tracking_config = None
if tt.tracking_config_id:
tracking_config = await session.get(TrackingConfig, tt.tracking_config_id)
template_config = None
template_slots: dict[str, str] | None = None
if tt.template_config_id:
template_config = await session.get(TemplateConfig, tt.template_config_id)
if template_config:
slot_result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == template_config.id)
)
raw_slots = {s.slot_name: s.template for s in slot_result.all()}
# Map slot names to event_type values for dispatcher lookup
template_slots = {}
for slot_name, tmpl_text in raw_slots.items():
# Strip "message_" prefix for event-type slots
event_key = slot_name.removeprefix("message_") if slot_name.startswith("message_") else slot_name
template_slots[event_key] = tmpl_text
target_config = dict(target.config)
# Inject chat_action for Telegram targets
if hasattr(target, 'chat_action') and target.chat_action:
target_config["chat_action"] = target.chat_action
# Inject bot credentials for bot-backed target types
if target.type == "email":
email_bot_id = target.config.get("email_bot_id")
if email_bot_id:
email_bot = await session.get(EmailBot, email_bot_id)
if email_bot:
target_config["smtp"] = {
"host": email_bot.smtp_host,
"port": email_bot.smtp_port,
"username": email_bot.smtp_username,
"password": email_bot.smtp_password,
"from_address": email_bot.email,
"from_name": email_bot.name,
"use_tls": email_bot.smtp_use_tls,
}
elif target.type == "matrix":
matrix_bot_id = target.config.get("matrix_bot_id")
if matrix_bot_id:
matrix_bot = await session.get(MatrixBot, matrix_bot_id)
if matrix_bot:
target_config["homeserver_url"] = matrix_bot.homeserver_url
target_config["access_token"] = matrix_bot.access_token
link_data.append({
"target_type": target.type,
"target_config": target_config,
"receivers": receivers,
"tracking_config": tracking_config,
"template_config": template_config,
"template_slots": template_slots,
})
# Load tracker-target links
link_data = await load_link_data(session, tracker_id, check_quiet_hours=True)
# Snapshot the data we need
provider_type = provider.type
@@ -327,7 +196,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
for ld in link_data:
# Apply per-link event filtering from tracking config
tc = ld["tracking_config"]
if tc and not _event_allowed_by_config(event, tc):
if tc and not event_allowed_by_config(event, tc):
_LOGGER.info(" Skipped by tracking config filter")
continue