Files
notify-bridge/packages/server/src/notify_bridge_server/services/backup_service.py
T
alexei.dolgolyov f0739ca949 feat: security hardening — SSRF guard, template sandbox timeout, webhook log prune, auth & backup polish
- Add outbound URL validation (SSRF) for webhook/Discord/Slack/ntfy/Matrix dispatch
- Template renderer: input/output caps and thread-based render timeout
- Webhook log filter: strip Authorization/signature/token-like headers; atomic prune
- Auth/JWT/backup/config tightening; misc frontend UX fixes
2026-04-16 03:21:45 +03:00

909 lines
35 KiB
Python

"""Configuration backup/restore service — export and import logic."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.models import (
Action, ActionRule, AppSetting, CommandConfig, CommandTemplateConfig,
CommandTemplateSlot, CommandTracker, CommandTrackerListener, EmailBot,
MatrixBot, NotificationTarget, NotificationTracker,
NotificationTrackerTarget, ServiceProvider, TargetReceiver,
TemplateConfig, TemplateSlot, TelegramBot, TrackingConfig,
)
from .backup_schema import (
ALL_CATEGORIES, ActionData, ActionRuleData, AppSettingData, BackupCategory,
BackupData, BackupFile, CommandConfigData, CommandTemplateConfigData,
CommandTemplateSlotData, CommandTrackerData, CommandTrackerListenerData,
ConflictMode, EmailBotData, ImportResult, MatrixBotData,
NotificationTrackerData, PROVIDER_SECRET_FIELDS, ProviderData,
ReceiverData, SecretsMode, TargetData, TemplateConfigData,
TemplateSlotData, TelegramBotData, TrackerTargetData,
TrackingConfigData, ValidateResult,
)
# Module-level logger named after this module's import path.
_LOGGER = logging.getLogger(__name__)
# Fields to skip when serializing TrackingConfig into the generic `fields` dict.
# `id`, `provider_type`, `name` and `icon` are exported as dedicated
# TrackingConfigData attributes; `user_id` and `created_at` are not exported.
_TRACKING_SKIP = frozenset(("id", "user_id", "provider_type", "name", "icon", "created_at"))
# Import-time config hardening limits
_MAX_CONFIG_DEPTH = 6
_MAX_CONFIG_KEYS = 200
_MAX_STRING_LEN = 8192
def _sanitize_config(value: Any, depth: int = 0) -> Any:
"""Clamp imported config values to safe shapes before persistence.
Rejects anything that is not a JSON-compatible primitive/container, truncates
over-long strings, and caps dict/list sizes. Returns a defensively-copied
structure; the caller should never see attacker-controlled references.
"""
if depth > _MAX_CONFIG_DEPTH:
raise ValueError("Config nesting exceeds maximum depth")
if value is None or isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value
if isinstance(value, str):
return value[:_MAX_STRING_LEN]
if isinstance(value, list):
if len(value) > _MAX_CONFIG_KEYS:
raise ValueError("Config list exceeds maximum length")
return [_sanitize_config(v, depth + 1) for v in value]
if isinstance(value, dict):
if len(value) > _MAX_CONFIG_KEYS:
raise ValueError("Config dict exceeds maximum key count")
cleaned: dict[str, Any] = {}
for k, v in value.items():
if not isinstance(k, str):
raise ValueError("Config keys must be strings")
if len(k) > 128:
raise ValueError(f"Config key too long: {k[:40]}...")
cleaned[k] = _sanitize_config(v, depth + 1)
return cleaned
raise ValueError(f"Unsupported config value type: {type(value).__name__}")
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
def _mask_secret(value: str) -> str:
return f"***{value[-4:]}" if len(value) > 4 else "***"
def _apply_secrets_provider(config: dict[str, Any], mode: SecretsMode) -> dict[str, Any]:
    """Build a shallow copy of a provider config with its secrets redacted.

    Only keys listed in ``PROVIDER_SECRET_FIELDS`` that hold a truthy value are
    touched: EXCLUDE blanks them, MASKED replaces them with a masked form, and
    any other mode leaves them as they are. The input dict is not modified.
    """
    redacted = dict(config)
    for field in PROVIDER_SECRET_FIELDS:
        # Absent or empty secrets are left alone in every mode.
        if not redacted.get(field):
            continue
        if mode == SecretsMode.EXCLUDE:
            redacted[field] = ""
        elif mode == SecretsMode.MASKED:
            redacted[field] = _mask_secret(redacted[field])
    return redacted
def _tracking_config_fields(tc: TrackingConfig) -> dict[str, Any]:
    """Extract all tracking config fields (booleans, ints, strings) as a dict.

    Every declared model field is included except those in ``_TRACKING_SKIP``.

    Args:
        tc: The tracking config row to serialize.

    Returns:
        Mapping of field name to the value currently set on ``tc``.
    """
    # Read model_fields from the class: instance access is deprecated in
    # recent Pydantic releases.
    return {
        field_name: getattr(tc, field_name)
        for field_name in type(tc).model_fields
        if field_name not in _TRACKING_SKIP
    }
def _redact_secret(value: str | None, mode: SecretsMode) -> str | None:
    """Apply ``mode`` to a single secret value for export.

    EXCLUDE always yields an empty string. MASKED yields the masked form, or an
    empty string when there is no secret to mask (so a missing secret is never
    turned into a placeholder and ``None`` never reaches ``_mask_secret``).
    Any other mode returns the value unchanged.
    """
    if mode == SecretsMode.EXCLUDE:
        return ""
    if mode == SecretsMode.MASKED:
        return _mask_secret(value) if value else ""
    return value


async def export_backup(
    session: AsyncSession,
    user_id: int,
    categories: list[BackupCategory] | None = None,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> BackupFile:
    """Export user configuration as a BackupFile.

    Args:
        session: Async database session used for all reads.
        user_id: Owner whose entities are exported.
        categories: Categories to include; ``None`` or an empty list exports
            everything in ``ALL_CATEGORIES``.
        secrets_mode: How tokens/passwords are emitted (excluded, masked, or
            as stored).

    Returns:
        A ``BackupFile`` holding the selected categories. Child rows (receivers,
        slots, tracker targets, listeners, rules) are nested under their parent.

    Note:
        App settings are selected without a ``user_id`` filter, i.e. all rows of
        ``AppSetting`` are exported.
    """
    cats = set(categories or ALL_CATEGORIES)
    data = BackupData()

    # -- Providers --
    if BackupCategory.PROVIDERS in cats:
        result = await session.exec(
            select(ServiceProvider).where(ServiceProvider.user_id == user_id)
        )
        for p in result.all():
            data.providers.append(ProviderData(
                id=p.id,
                type=p.type,
                name=p.name,
                icon=p.icon,
                config=_apply_secrets_provider(p.config, secrets_mode),
            ))

    # -- Telegram Bots --
    if BackupCategory.TELEGRAM_BOTS in cats:
        result = await session.exec(
            select(TelegramBot).where(TelegramBot.user_id == user_id)
        )
        for b in result.all():
            data.telegram_bots.append(TelegramBotData(
                id=b.id, name=b.name,
                token=_redact_secret(b.token, secrets_mode), icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            ))

    # -- Matrix Bots --
    if BackupCategory.MATRIX_BOTS in cats:
        result = await session.exec(
            select(MatrixBot).where(MatrixBot.user_id == user_id)
        )
        for b in result.all():
            data.matrix_bots.append(MatrixBotData(
                id=b.id, name=b.name, icon=b.icon,
                homeserver_url=b.homeserver_url,
                access_token=_redact_secret(b.access_token, secrets_mode),
                display_name=b.display_name,
            ))

    # -- Email Bots --
    if BackupCategory.EMAIL_BOTS in cats:
        result = await session.exec(
            select(EmailBot).where(EmailBot.user_id == user_id)
        )
        for b in result.all():
            data.email_bots.append(EmailBotData(
                id=b.id, name=b.name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username,
                smtp_password=_redact_secret(b.smtp_password, secrets_mode),
                smtp_use_tls=b.smtp_use_tls,
            ))

    # -- Targets + Receivers --
    if BackupCategory.TARGETS in cats:
        result = await session.exec(
            select(NotificationTarget).where(NotificationTarget.user_id == user_id)
        )
        for tgt in result.all():
            recv_result = await session.exec(
                select(TargetReceiver).where(TargetReceiver.target_id == tgt.id)
            )
            receivers = [
                ReceiverData(
                    name=r.name, config=r.config, receiver_key=r.receiver_key,
                    locale=r.locale, enabled=r.enabled,
                )
                for r in recv_result.all()
            ]
            data.targets.append(TargetData(
                id=tgt.id, type=tgt.type, name=tgt.name, icon=tgt.icon,
                config=tgt.config, chat_action=tgt.chat_action,
                receivers=receivers,
            ))

    # -- Tracking Configs --
    if BackupCategory.TRACKING_CONFIGS in cats:
        result = await session.exec(
            select(TrackingConfig).where(TrackingConfig.user_id == user_id)
        )
        for tc in result.all():
            data.tracking_configs.append(TrackingConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                icon=tc.icon, fields=_tracking_config_fields(tc),
            ))

    # -- Template Configs + Slots (user-owned only) --
    if BackupCategory.TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(TemplateConfig).where(TemplateConfig.user_id == user_id)
        )
        for tc in result.all():
            slot_result = await session.exec(
                select(TemplateSlot).where(TemplateSlot.config_id == tc.id)
            )
            slots = [
                TemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.template_configs.append(TemplateConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                description=tc.description, icon=tc.icon, locale=tc.locale,
                date_format=tc.date_format, date_only_format=tc.date_only_format,
                slots=slots,
            ))

    # -- Command Template Configs + Slots (user-owned only) --
    if BackupCategory.COMMAND_TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == user_id)
        )
        for ctc in result.all():
            slot_result = await session.exec(
                select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == ctc.id)
            )
            slots = [
                CommandTemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.command_template_configs.append(CommandTemplateConfigData(
                id=ctc.id, provider_type=ctc.provider_type, name=ctc.name,
                description=ctc.description, icon=ctc.icon, locale=ctc.locale,
                slots=slots,
            ))

    # -- Command Configs --
    if BackupCategory.COMMAND_CONFIGS in cats:
        result = await session.exec(
            select(CommandConfig).where(CommandConfig.user_id == user_id)
        )
        for cc in result.all():
            data.command_configs.append(CommandConfigData(
                id=cc.id, provider_type=cc.provider_type, name=cc.name,
                icon=cc.icon, enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode, default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=cc.command_template_config_id,
            ))

    # -- Notification Trackers + Tracker-Targets --
    if BackupCategory.NOTIFICATION_TRACKERS in cats:
        result = await session.exec(
            select(NotificationTracker).where(NotificationTracker.user_id == user_id)
        )
        for nt in result.all():
            tt_result = await session.exec(
                select(NotificationTrackerTarget).where(
                    NotificationTrackerTarget.tracker_id == nt.id
                )
            )
            targets = [
                TrackerTargetData(
                    target_id=tt.target_id,
                    tracking_config_id=tt.tracking_config_id,
                    template_config_id=tt.template_config_id,
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                )
                for tt in tt_result.all()
            ]
            data.notification_trackers.append(NotificationTrackerData(
                id=nt.id, provider_id=nt.provider_id, name=nt.name,
                icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=nt.default_tracking_config_id,
                default_template_config_id=nt.default_template_config_id,
                enabled=nt.enabled, targets=targets,
            ))

    # -- Command Trackers + Listeners --
    if BackupCategory.COMMAND_TRACKERS in cats:
        result = await session.exec(
            select(CommandTracker).where(CommandTracker.user_id == user_id)
        )
        for ct in result.all():
            lis_result = await session.exec(
                select(CommandTrackerListener).where(
                    CommandTrackerListener.command_tracker_id == ct.id
                )
            )
            listeners = [
                CommandTrackerListenerData(
                    listener_type=lis.listener_type, listener_id=lis.listener_id,
                )
                for lis in lis_result.all()
            ]
            data.command_trackers.append(CommandTrackerData(
                id=ct.id, provider_id=ct.provider_id,
                command_config_id=ct.command_config_id, name=ct.name,
                icon=ct.icon, enabled=ct.enabled, listeners=listeners,
            ))

    # -- Actions + Rules --
    if BackupCategory.ACTIONS in cats:
        result = await session.exec(
            select(Action).where(Action.user_id == user_id)
        )
        for a in result.all():
            rule_result = await session.exec(
                select(ActionRule).where(ActionRule.action_id == a.id)
            )
            rules = [
                ActionRuleData(
                    name=r.name, rule_config=r.rule_config,
                    enabled=r.enabled, order=r.order,
                )
                for r in rule_result.all()
            ]
            data.actions.append(ActionData(
                id=a.id, provider_id=a.provider_id, name=a.name,
                icon=a.icon, action_type=a.action_type, config=a.config,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=a.enabled,
                rules=rules,
            ))

    # -- App Settings --
    if BackupCategory.APP_SETTINGS in cats:
        result = await session.exec(select(AppSetting))
        for s in result.all():
            value = s.value
            # Only the webhook secret is treated as sensitive, and only when set.
            if s.key == "telegram_webhook_secret" and value:
                value = _redact_secret(value, secrets_mode)
            data.app_settings.append(AppSettingData(key=s.key, value=value))

    return BackupFile(
        format="notify-bridge-backup",
        version=1,
        created_at=datetime.now(timezone.utc).isoformat(),
        app_version="1.0.0",
        secrets_mode=secrets_mode,
        # `cats` is a set; sort so repeated exports list categories identically.
        categories=sorted((c.value for c in cats), key=str),
        data=data,
    )
# ---------------------------------------------------------------------------
# Export to file (for scheduled backups)
# ---------------------------------------------------------------------------
async def export_backup_to_file(
    session: AsyncSession,
    user_id: int,
    backup_dir: Path,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> Path:
    """Export backup and write to a file in backup_dir. Returns the file path.

    Args:
        session: Async database session passed through to ``export_backup``.
        user_id: Owner whose configuration is exported.
        backup_dir: Destination directory; created (with parents) if missing.
        secrets_mode: How secrets are emitted in the export.

    Returns:
        Path of the written ``backup-<UTC timestamp>.json`` file.
    """
    backup_dir.mkdir(parents=True, exist_ok=True)
    backup = await export_backup(session, user_id, secrets_mode=secrets_mode)
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    filepath = backup_dir / f"backup-{ts}.json"
    # mode="json" makes Pydantic emit only JSON-serializable values.
    payload = json.dumps(backup.model_dump(mode="json"), indent=2, ensure_ascii=False)
    # Write to a sibling temp file and rename, so the "backup-*.json" glob used
    # by the list/cleanup helpers never sees a half-written backup.
    tmp_path = filepath.with_name(filepath.name + ".tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(filepath)
    finally:
        # No-op after a successful rename; removes the leftover on failure.
        tmp_path.unlink(missing_ok=True)
    _LOGGER.info("Scheduled backup saved: %s", filepath)
    return filepath
def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
    """Delete oldest backup files exceeding `keep` count. Returns deleted filenames.

    Files are ordered by name, newest first (the timestamp embedded in the name
    sorts lexicographically), and everything after the first ``keep`` entries is
    removed.

    Args:
        backup_dir: Directory holding ``backup-*.json`` files.
        keep: Number of newest files to retain. A negative value retains
            everything (nothing is deleted).

    Returns:
        Names of the files that were removed; empty if the directory does not
        exist or nothing needed deleting.
    """
    # A negative slice start would select the *oldest* |keep| files for
    # deletion; treat it as "no limit" instead.
    if keep < 0 or not backup_dir.is_dir():
        return []
    files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
    deleted: list[str] = []
    for old in files[keep:]:
        try:
            # missing_ok: another cleanup may have removed the file already.
            old.unlink(missing_ok=True)
        except OSError as exc:
            # One undeletable file must not stop the rest of the cleanup.
            _LOGGER.warning("Could not delete old backup %s: %s", old, exc)
            continue
        deleted.append(old.name)
    if deleted:
        _LOGGER.info("Cleaned up %d old backup(s): %s", len(deleted), deleted)
    return deleted
def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
    """List backup files in the directory with metadata.

    Args:
        backup_dir: Directory holding ``backup-*.json`` files.

    Returns:
        One dict per file, ordered by filename descending, with keys
        ``filename``, ``size`` (bytes) and ``created_at`` (the file's
        modification time as a UTC ISO-8601 string). Empty if the directory
        does not exist.
    """
    if not backup_dir.is_dir():
        return []
    files = sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True)
    listing: list[dict[str, Any]] = []
    for f in files:
        try:
            stat = f.stat()
        except FileNotFoundError:
            # Removed between glob() and stat() (e.g. by a concurrent cleanup).
            continue
        listing.append({
            "filename": f.name,
            "size": stat.st_size,
            "created_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
        })
    return listing
# ---------------------------------------------------------------------------
# Validate
# ---------------------------------------------------------------------------
def validate_backup(raw: dict[str, Any]) -> ValidateResult:
    """Validate a backup file dict without importing. Returns summary.

    Checks, in order: the payload is a dict, ``format`` matches, ``version`` is
    an integer not newer than 1, and the payload parses as ``BackupFile``. The
    payload is untrusted, so malformed shapes are reported as validation errors
    rather than raised.

    Args:
        raw: Parsed JSON content of the uploaded backup.

    Returns:
        A ``ValidateResult``; when valid it carries per-category entity counts
        (non-empty categories only) and any warnings.
    """
    warnings: list[str] = []
    errors: list[str] = []

    if not isinstance(raw, dict):
        errors.append("Backup payload must be a JSON object")
        return ValidateResult(valid=False, errors=errors)

    fmt = raw.get("format")
    if fmt != "notify-bridge-backup":
        errors.append(f"Unknown format: {fmt}")
        return ValidateResult(valid=False, errors=errors)

    version = raw.get("version", 0)
    # Comparing a str/None/list with `> 1` would raise TypeError; bool is
    # excluded explicitly because it is a subclass of int.
    if isinstance(version, bool) or not isinstance(version, int):
        errors.append(f"Invalid backup version: {version!r}")
        return ValidateResult(valid=False, errors=errors)
    if version > 1:
        errors.append(f"Unsupported backup version: {version} (max supported: 1)")
        return ValidateResult(valid=False, version=version, errors=errors)

    secrets_mode = raw.get("secrets_mode", "exclude")
    if secrets_mode in ("exclude", "masked"):
        warnings.append(
            f"Backup was exported with secrets_mode={secrets_mode}. "
            "Imported entities will have empty/placeholder secrets that need manual update."
        )

    try:
        backup = BackupFile.model_validate(raw)
    except Exception as e:
        errors.append(f"Schema validation failed: {e}")
        return ValidateResult(valid=False, version=version, errors=errors)

    counts: dict[str, int] = {}
    d = backup.data
    for cat in ("providers", "telegram_bots", "matrix_bots", "email_bots",
                "targets", "tracking_configs", "template_configs",
                "command_configs", "command_template_configs",
                "notification_trackers", "command_trackers", "actions",
                "app_settings"):
        items = getattr(d, cat, [])
        if items:
            counts[cat] = len(items)

    return ValidateResult(
        valid=True, version=version,
        entity_counts=counts, warnings=warnings, errors=errors,
    )
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
async def import_backup(
    session: AsyncSession,
    user_id: int,
    backup: BackupFile,
    conflict_mode: ConflictMode = ConflictMode.SKIP,
) -> ImportResult:
    """Import a backup file into the database. Atomic — rolls back on error.

    Entities are created in dependency order and the IDs from the backup are
    remapped to the newly assigned IDs, so references between imported entities
    keep pointing at each other. Everything is created under ``user_id``.

    Args:
        session: Async database session; committed on success, rolled back on
            any exception.
        user_id: Owner assigned to every imported entity.
        backup: Parsed backup to import.
        conflict_mode: What to do when an entity with the same name already
            exists for this user (skip, rename, or overwrite).

    Returns:
        An ``ImportResult`` with created/skipped/overwritten counters, warnings
        for entities that were skipped, and an error entry if the import failed.
    """
    result = ImportResult()
    # Maps: category -> {old_id: new_id}
    id_map: dict[str, dict[int, int]] = {}
    d = backup.data
    try:
        # 1. App Settings (simple upsert)
        for s in d.app_settings:
            existing = await session.get(AppSetting, s.key)
            if existing:
                if conflict_mode == ConflictMode.OVERWRITE:
                    existing.value = s.value
                    session.add(existing)
                    result.overwritten += 1
                else:
                    # SKIP, and RENAME (not applicable for settings).
                    result.skipped += 1
            else:
                session.add(AppSetting(key=s.key, value=s.value))
                result.created += 1

        # 2. Telegram Bots
        id_map["telegram_bots"] = {}
        for b in d.telegram_bots:
            name = await _resolve_name(
                session, TelegramBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = TelegramBot(
                user_id=user_id, name=name, token=b.token, icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            )
            session.add(new_bot)
            await session.flush()  # assigns new_bot.id
            id_map["telegram_bots"][b.id] = new_bot.id

        # 3. Matrix Bots
        id_map["matrix_bots"] = {}
        for b in d.matrix_bots:
            name = await _resolve_name(
                session, MatrixBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = MatrixBot(
                user_id=user_id, name=name, icon=b.icon,
                homeserver_url=b.homeserver_url, access_token=b.access_token,
                display_name=b.display_name,
            )
            session.add(new_bot)
            await session.flush()
            id_map["matrix_bots"][b.id] = new_bot.id

        # 4. Email Bots
        id_map["email_bots"] = {}
        for b in d.email_bots:
            name = await _resolve_name(
                session, EmailBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = EmailBot(
                user_id=user_id, name=name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username, smtp_password=b.smtp_password,
                smtp_use_tls=b.smtp_use_tls,
            )
            session.add(new_bot)
            await session.flush()
            id_map["email_bots"][b.id] = new_bot.id

        # 5. Providers
        id_map["providers"] = {}
        for p in d.providers:
            # Sanitize BEFORE resolving the name: in OVERWRITE mode
            # _resolve_name deletes the existing row, which must not happen
            # when the replacement is going to be rejected.
            try:
                safe_cfg = _sanitize_config(p.config or {})
            except ValueError as exc:
                result.warnings.append(f"Skipped provider '{p.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, ServiceProvider, p.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_p = ServiceProvider(
                user_id=user_id, type=p.type, name=name,
                icon=p.icon, config=safe_cfg,
            )
            session.add(new_p)
            await session.flush()
            id_map["providers"][p.id] = new_p.id

        # 6. Tracking Configs
        id_map["tracking_configs"] = {}
        # Only declared model fields may be set from the backup, and never the
        # identity/ownership columns: the field names come from an untrusted
        # file and would otherwise be able to override `user_id` or `id`.
        settable_fields = set(TrackingConfig.model_fields) - _TRACKING_SKIP
        for tc in d.tracking_configs:
            name = await _resolve_name(
                session, TrackingConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tc = TrackingConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, icon=tc.icon,
            )
            # Apply all tracked fields
            for field_name, value in tc.fields.items():
                if field_name in settable_fields:
                    setattr(new_tc, field_name, value)
            session.add(new_tc)
            await session.flush()
            id_map["tracking_configs"][tc.id] = new_tc.id

        # 7. Template Configs + Slots
        id_map["template_configs"] = {}
        for tc in d.template_configs:
            name = await _resolve_name_template(
                session, TemplateConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tc = TemplateConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, description=tc.description, icon=tc.icon,
                locale=tc.locale, date_format=tc.date_format,
                date_only_format=tc.date_only_format,
            )
            session.add(new_tc)
            await session.flush()
            id_map["template_configs"][tc.id] = new_tc.id
            for s in tc.slots:
                session.add(TemplateSlot(
                    config_id=new_tc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(tc.slots)

        # 8. Command Template Configs + Slots
        id_map["command_template_configs"] = {}
        for ctc in d.command_template_configs:
            name = await _resolve_name_template(
                session, CommandTemplateConfig, ctc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_ctc = CommandTemplateConfig(
                user_id=user_id, provider_type=ctc.provider_type,
                name=name, description=ctc.description, icon=ctc.icon,
                locale=ctc.locale,
            )
            session.add(new_ctc)
            await session.flush()
            id_map["command_template_configs"][ctc.id] = new_ctc.id
            for s in ctc.slots:
                session.add(CommandTemplateSlot(
                    config_id=new_ctc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(ctc.slots)

        # 9. Command Configs
        id_map["command_configs"] = {}
        for cc in d.command_configs:
            name = await _resolve_name(
                session, CommandConfig, cc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            ctc_id = _map_id(id_map, "command_template_configs", cc.command_template_config_id)
            new_cc = CommandConfig(
                user_id=user_id, provider_type=cc.provider_type,
                name=name, icon=cc.icon,
                enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode,
                default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=ctc_id,
            )
            session.add(new_cc)
            await session.flush()
            id_map["command_configs"][cc.id] = new_cc.id

        # 10. Targets + Receivers
        id_map["targets"] = {}
        for tgt in d.targets:
            # Sanitize first, for the same reason as providers above.
            try:
                safe_tgt_cfg = _sanitize_config(tgt.config or {})
            except ValueError as exc:
                result.warnings.append(f"Skipped target '{tgt.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, NotificationTarget, tgt.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tgt = NotificationTarget(
                user_id=user_id, type=tgt.type, name=name,
                icon=tgt.icon, config=safe_tgt_cfg,
                chat_action=tgt.chat_action,
            )
            session.add(new_tgt)
            await session.flush()
            id_map["targets"][tgt.id] = new_tgt.id
            added_receivers = 0
            for r in tgt.receivers:
                try:
                    safe_r_cfg = _sanitize_config(r.config or {})
                except ValueError as exc:
                    result.warnings.append(f"Skipped receiver in '{tgt.name}': {exc}")
                    continue
                session.add(TargetReceiver(
                    target_id=new_tgt.id, name=r.name, config=safe_r_cfg,
                    receiver_key=r.receiver_key, locale=r.locale,
                    enabled=r.enabled,
                ))
                added_receivers += 1
            # Count only receivers that were actually added.
            result.created += added_receivers

        # 11. Notification Trackers + Tracker-Targets
        for nt in d.notification_trackers:
            provider_id = _map_id(id_map, "providers", nt.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped tracker '{nt.name}': provider {nt.provider_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, NotificationTracker, nt.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_nt = NotificationTracker(
                user_id=user_id, provider_id=provider_id,
                name=name, icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=_map_id(id_map, "tracking_configs", nt.default_tracking_config_id),
                default_template_config_id=_map_id(id_map, "template_configs", nt.default_template_config_id),
                enabled=nt.enabled,
            )
            session.add(new_nt)
            await session.flush()
            for tt in nt.targets:
                target_id = _map_id(id_map, "targets", tt.target_id)
                if target_id is None:
                    result.warnings.append(
                        f"Skipped tracker-target link in '{nt.name}': target {tt.target_id} not found"
                    )
                    continue
                session.add(NotificationTrackerTarget(
                    tracker_id=new_nt.id,
                    target_id=target_id,
                    tracking_config_id=_map_id(id_map, "tracking_configs", tt.tracking_config_id),
                    template_config_id=_map_id(id_map, "template_configs", tt.template_config_id),
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                ))
                result.created += 1

        # 12. Command Trackers + Listeners
        for ct in d.command_trackers:
            provider_id = _map_id(id_map, "providers", ct.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': provider {ct.provider_id} not found"
                )
                result.skipped += 1
                continue
            cc_id = _map_id(id_map, "command_configs", ct.command_config_id)
            if cc_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': command config {ct.command_config_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, CommandTracker, ct.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_ct = CommandTracker(
                user_id=user_id, provider_id=provider_id,
                command_config_id=cc_id, name=name, icon=ct.icon,
                enabled=ct.enabled,
            )
            session.add(new_ct)
            await session.flush()
            for lis in ct.listeners:
                # Map listener_id based on listener_type
                # NOTE(review): only "telegram_bot" listeners are remapped, and
                # an unmapped ID falls back to the ID stored in the backup —
                # confirm that this cannot reference a bot owned by another user.
                mapped_listener_id = lis.listener_id
                if lis.listener_type == "telegram_bot":
                    mapped_listener_id = _map_id(id_map, "telegram_bots", lis.listener_id) or lis.listener_id
                session.add(CommandTrackerListener(
                    command_tracker_id=new_ct.id,
                    listener_type=lis.listener_type,
                    listener_id=mapped_listener_id,
                ))
                result.created += 1

        # 13. Actions + Rules
        for a in d.actions:
            provider_id = _map_id(id_map, "providers", a.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped action '{a.name}': provider {a.provider_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, Action, a.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_a = Action(
                user_id=user_id, provider_id=provider_id, name=name,
                icon=a.icon, action_type=a.action_type, config=a.config,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=False,  # always import disabled
            )
            session.add(new_a)
            await session.flush()
            for r in a.rules:
                session.add(ActionRule(
                    action_id=new_a.id, name=r.name,
                    rule_config=r.rule_config, enabled=r.enabled,
                    order=r.order,
                ))
            result.created += len(a.rules)

        await session.commit()
    except Exception as e:
        await session.rollback()
        # .exception() keeps the traceback; the message text is unchanged.
        _LOGGER.exception("Backup import failed: %s", e)
        result.errors.append(f"Import failed: {e}")
    return result
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _map_id(
id_map: dict[str, dict[int, int]],
category: str,
old_id: int | None,
) -> int | None:
"""Resolve an old ID to a new ID via the id_map. Returns None if not found."""
if old_id is None:
return None
return id_map.get(category, {}).get(old_id)
async def _resolve_name(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Check for name conflict and return the resolved name, or None to skip.

    Also updates the counters on ``result`` (created / skipped / overwritten).

    Args:
        session: Async database session.
        model: Table model with ``name`` and ``user_id`` columns.
        name: Name the imported entity wants to use.
        user_id: Owner whose existing entities are checked for conflicts.
        conflict_mode: SKIP returns None; RENAME returns a free
            "<name> (imported)" variant; OVERWRITE deletes the existing row and
            returns ``name``.
        result: Import counters, mutated in place.

    Returns:
        The name to create the entity under, or None if it should be skipped.
    """
    async def _lookup(candidate: str) -> Any:
        rows = await session.exec(
            select(model).where(
                model.name == candidate,
                model.user_id == user_id,
            )
        )
        return rows.first()

    found = await _lookup(name)
    if found is None:
        result.created += 1
        return name
    if conflict_mode == ConflictMode.SKIP:
        result.skipped += 1
        return None
    if conflict_mode == ConflictMode.RENAME:
        # "<name> (imported)" may itself be taken (e.g. the same backup was
        # imported before), so keep numbering until a free name is found.
        candidate = f"{name} (imported)"
        attempt = 2
        while await _lookup(candidate) is not None:
            candidate = f"{name} (imported {attempt})"
            attempt += 1
        result.created += 1
        return candidate
    # OVERWRITE — delete existing, create new
    await session.delete(found)
    await session.flush()
    result.overwritten += 1
    return name
async def _resolve_name_template(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Resolve a name conflict for template models; same contract as _resolve_name.

    Kept as a separate entry point for the template import steps. The lookup
    only matches rows owned by ``user_id``; rows stored under any other owner
    are not considered conflicts.

    Returns:
        The name to create the template config under, or None to skip it.
    """
    # The logic was a line-for-line copy of _resolve_name; delegate so the two
    # cannot drift apart.
    return await _resolve_name(session, model, name, user_id, conflict_mode, result)