Files
notify-bridge/packages/server/src/notify_bridge_server/services/backup_service.py
T
alexei.dolgolyov 6b2211353d feat: person excludes for auto-organize rules, backup & restore system
Add person exclude criteria to Immich auto-organize — assets containing
excluded persons are filtered out after candidate gathering. Also adds
full backup/restore system with export, import, scheduled backups, and
retention management.
2026-04-02 14:13:42 +03:00

856 lines
32 KiB
Python

"""Configuration backup/restore service — export and import logic."""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..database.models import (
Action, ActionRule, AppSetting, CommandConfig, CommandTemplateConfig,
CommandTemplateSlot, CommandTracker, CommandTrackerListener, EmailBot,
MatrixBot, NotificationTarget, NotificationTracker,
NotificationTrackerTarget, ServiceProvider, TargetReceiver,
TemplateConfig, TemplateSlot, TelegramBot, TrackingConfig,
)
from .backup_schema import (
ALL_CATEGORIES, ActionData, ActionRuleData, AppSettingData, BackupCategory,
BackupData, BackupFile, CommandConfigData, CommandTemplateConfigData,
CommandTemplateSlotData, CommandTrackerData, CommandTrackerListenerData,
ConflictMode, EmailBotData, ImportResult, MatrixBotData,
NotificationTrackerData, PROVIDER_SECRET_FIELDS, ProviderData,
ReceiverData, SecretsMode, TargetData, TemplateConfigData,
TemplateSlotData, TelegramBotData, TrackerTargetData,
TrackingConfigData, ValidateResult,
)
_LOGGER = logging.getLogger(__name__)
# Fields to skip when serializing TrackingConfig into the generic `fields` dict
_TRACKING_SKIP = frozenset(("id", "user_id", "provider_type", "name", "icon", "created_at"))
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
def _mask_secret(value: str) -> str:
return f"***{value[-4:]}" if len(value) > 4 else "***"
def _apply_secrets_provider(config: dict[str, Any], mode: SecretsMode) -> dict[str, Any]:
    """Copy a provider config dict, redacting its secret fields per *mode*.

    Only keys listed in ``PROVIDER_SECRET_FIELDS`` that hold a truthy value
    are touched: EXCLUDE blanks them, MASKED swaps in ``_mask_secret``'s
    placeholder, and any other mode leaves them untouched. The input dict is
    never mutated.
    """
    redacted = dict(config)
    if mode not in (SecretsMode.EXCLUDE, SecretsMode.MASKED):
        # Nothing to redact in this mode; still hand back a copy.
        return redacted
    for secret_key in PROVIDER_SECRET_FIELDS:
        current = redacted.get(secret_key)
        if not current:
            continue
        redacted[secret_key] = "" if mode == SecretsMode.EXCLUDE else _mask_secret(current)
    return redacted
def _tracking_config_fields(tc: TrackingConfig) -> dict[str, Any]:
    """Collect every TrackingConfig column not listed in ``_TRACKING_SKIP``.

    The result maps field name to the instance's current value and becomes
    the generic ``fields`` bag of ``TrackingConfigData``.
    """
    # Read the field map from the class: Pydantic 2.11+ deprecates accessing
    # ``model_fields`` on an instance.
    return {
        field_name: getattr(tc, field_name)
        for field_name in type(tc).model_fields
        if field_name not in _TRACKING_SKIP
    }
def _export_secret(value: str | None, mode: SecretsMode) -> str | None:
    """Return *value* as it should be written to a backup under *mode*.

    EXCLUDE always yields ``""``. MASKED yields a ``_mask_secret`` placeholder
    for a non-empty value and ``""`` for an empty/missing one, so an absent
    secret is never exported as a fake placeholder. Any other mode returns
    *value* unchanged.
    """
    if mode == SecretsMode.EXCLUDE:
        return ""
    if mode == SecretsMode.MASKED:
        return _mask_secret(value) if value else ""
    return value


async def export_backup(
    session: AsyncSession,
    user_id: int,
    categories: list[BackupCategory] | None = None,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> BackupFile:
    """Export a user's configuration as a ``BackupFile``.

    Args:
        session: Async DB session; only read from.
        user_id: Owner whose rows are exported. NOTE: app settings are read
            with an unfiltered ``select(AppSetting)``, i.e. not per user.
        categories: Categories to include. ``None`` or an empty list exports
            ``ALL_CATEGORIES``.
        secrets_mode: How bot tokens, passwords, provider secrets and the
            telegram webhook secret are written (see ``_export_secret``).

    Returns:
        The assembled backup. Child rows (receivers, slots, tracker targets,
        listeners, rules) are nested under their parent, and cross references
        keep the source database ids so ``import_backup`` can remap them.
    """
    cats = set(categories or ALL_CATEGORIES)
    data = BackupData()

    # -- Providers --
    if BackupCategory.PROVIDERS in cats:
        result = await session.exec(
            select(ServiceProvider).where(ServiceProvider.user_id == user_id)
        )
        for p in result.all():
            data.providers.append(ProviderData(
                id=p.id,
                type=p.type,
                name=p.name,
                icon=p.icon,
                config=_apply_secrets_provider(p.config, secrets_mode),
            ))

    # -- Telegram Bots --
    if BackupCategory.TELEGRAM_BOTS in cats:
        result = await session.exec(
            select(TelegramBot).where(TelegramBot.user_id == user_id)
        )
        for b in result.all():
            data.telegram_bots.append(TelegramBotData(
                id=b.id, name=b.name,
                token=_export_secret(b.token, secrets_mode),
                icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            ))

    # -- Matrix Bots --
    if BackupCategory.MATRIX_BOTS in cats:
        result = await session.exec(
            select(MatrixBot).where(MatrixBot.user_id == user_id)
        )
        for b in result.all():
            data.matrix_bots.append(MatrixBotData(
                id=b.id, name=b.name, icon=b.icon,
                homeserver_url=b.homeserver_url,
                access_token=_export_secret(b.access_token, secrets_mode),
                display_name=b.display_name,
            ))

    # -- Email Bots --
    if BackupCategory.EMAIL_BOTS in cats:
        result = await session.exec(
            select(EmailBot).where(EmailBot.user_id == user_id)
        )
        for b in result.all():
            data.email_bots.append(EmailBotData(
                id=b.id, name=b.name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username,
                smtp_password=_export_secret(b.smtp_password, secrets_mode),
                smtp_use_tls=b.smtp_use_tls,
            ))

    # -- Targets + Receivers --
    if BackupCategory.TARGETS in cats:
        result = await session.exec(
            select(NotificationTarget).where(NotificationTarget.user_id == user_id)
        )
        for tgt in result.all():
            recv_result = await session.exec(
                select(TargetReceiver).where(TargetReceiver.target_id == tgt.id)
            )
            receivers = [
                ReceiverData(
                    name=r.name, config=r.config, receiver_key=r.receiver_key,
                    locale=r.locale, enabled=r.enabled,
                )
                for r in recv_result.all()
            ]
            data.targets.append(TargetData(
                id=tgt.id, type=tgt.type, name=tgt.name, icon=tgt.icon,
                config=tgt.config, chat_action=tgt.chat_action,
                receivers=receivers,
            ))

    # -- Tracking Configs --
    if BackupCategory.TRACKING_CONFIGS in cats:
        result = await session.exec(
            select(TrackingConfig).where(TrackingConfig.user_id == user_id)
        )
        for tc in result.all():
            data.tracking_configs.append(TrackingConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                icon=tc.icon, fields=_tracking_config_fields(tc),
            ))

    # -- Template Configs + Slots (user-owned only) --
    if BackupCategory.TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(TemplateConfig).where(TemplateConfig.user_id == user_id)
        )
        for tc in result.all():
            slot_result = await session.exec(
                select(TemplateSlot).where(TemplateSlot.config_id == tc.id)
            )
            slots = [
                TemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.template_configs.append(TemplateConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                description=tc.description, icon=tc.icon, locale=tc.locale,
                date_format=tc.date_format, date_only_format=tc.date_only_format,
                slots=slots,
            ))

    # -- Command Template Configs + Slots (user-owned only) --
    if BackupCategory.COMMAND_TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == user_id)
        )
        for ctc in result.all():
            slot_result = await session.exec(
                select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == ctc.id)
            )
            slots = [
                CommandTemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.command_template_configs.append(CommandTemplateConfigData(
                id=ctc.id, provider_type=ctc.provider_type, name=ctc.name,
                description=ctc.description, icon=ctc.icon, locale=ctc.locale,
                slots=slots,
            ))

    # -- Command Configs --
    if BackupCategory.COMMAND_CONFIGS in cats:
        result = await session.exec(
            select(CommandConfig).where(CommandConfig.user_id == user_id)
        )
        for cc in result.all():
            data.command_configs.append(CommandConfigData(
                id=cc.id, provider_type=cc.provider_type, name=cc.name,
                icon=cc.icon, enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode, default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=cc.command_template_config_id,
            ))

    # -- Notification Trackers + Tracker-Targets --
    if BackupCategory.NOTIFICATION_TRACKERS in cats:
        result = await session.exec(
            select(NotificationTracker).where(NotificationTracker.user_id == user_id)
        )
        for nt in result.all():
            tt_result = await session.exec(
                select(NotificationTrackerTarget).where(
                    NotificationTrackerTarget.tracker_id == nt.id
                )
            )
            targets = [
                TrackerTargetData(
                    target_id=tt.target_id,
                    tracking_config_id=tt.tracking_config_id,
                    template_config_id=tt.template_config_id,
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                )
                for tt in tt_result.all()
            ]
            data.notification_trackers.append(NotificationTrackerData(
                id=nt.id, provider_id=nt.provider_id, name=nt.name,
                icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=nt.default_tracking_config_id,
                default_template_config_id=nt.default_template_config_id,
                enabled=nt.enabled, targets=targets,
            ))

    # -- Command Trackers + Listeners --
    if BackupCategory.COMMAND_TRACKERS in cats:
        result = await session.exec(
            select(CommandTracker).where(CommandTracker.user_id == user_id)
        )
        for ct in result.all():
            lis_result = await session.exec(
                select(CommandTrackerListener).where(
                    CommandTrackerListener.command_tracker_id == ct.id
                )
            )
            listeners = [
                CommandTrackerListenerData(
                    listener_type=lis.listener_type, listener_id=lis.listener_id,
                )
                for lis in lis_result.all()
            ]
            data.command_trackers.append(CommandTrackerData(
                id=ct.id, provider_id=ct.provider_id,
                command_config_id=ct.command_config_id, name=ct.name,
                icon=ct.icon, enabled=ct.enabled, listeners=listeners,
            ))

    # -- Actions + Rules --
    if BackupCategory.ACTIONS in cats:
        result = await session.exec(
            select(Action).where(Action.user_id == user_id)
        )
        for a in result.all():
            rule_result = await session.exec(
                select(ActionRule).where(ActionRule.action_id == a.id)
            )
            rules = [
                ActionRuleData(
                    name=r.name, rule_config=r.rule_config,
                    enabled=r.enabled, order=r.order,
                )
                for r in rule_result.all()
            ]
            data.actions.append(ActionData(
                id=a.id, provider_id=a.provider_id, name=a.name,
                icon=a.icon, action_type=a.action_type, config=a.config,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=a.enabled,
                rules=rules,
            ))

    # -- App Settings (unfiltered: every AppSetting row is exported) --
    if BackupCategory.APP_SETTINGS in cats:
        result = await session.exec(select(AppSetting))
        for s in result.all():
            value = s.value
            # The webhook secret is the only setting treated as sensitive here.
            if s.key == "telegram_webhook_secret" and value:
                value = _export_secret(value, secrets_mode)
            data.app_settings.append(AppSettingData(key=s.key, value=value))

    return BackupFile(
        format="notify-bridge-backup",
        version=1,
        created_at=datetime.now(timezone.utc).isoformat(),
        # TODO(review): hard-coded; confirm whether this should come from
        # the package version.
        app_version="1.0.0",
        secrets_mode=secrets_mode,
        # Sorted so the file is reproducible; set iteration order is not.
        categories=sorted(c.value for c in cats),
        data=data,
    )
# ---------------------------------------------------------------------------
# Export to file (for scheduled backups)
# ---------------------------------------------------------------------------
async def export_backup_to_file(
    session: AsyncSession,
    user_id: int,
    backup_dir: Path,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> Path:
    """Export a full backup for *user_id* and write it into *backup_dir*.

    The file is named ``backup-<UTC %Y-%m-%dT%H-%M-%S>.json``, which matches
    the ``backup-*.json`` glob used by ``cleanup_old_backups`` and
    ``list_backup_files``. It is first written to a ``.tmp`` sibling (which
    that glob does not match) and then renamed into place. A crash or full
    disk therefore never leaves a truncated backup that would later be
    listed or counted towards retention.

    Returns:
        Path of the written backup file.

    Raises:
        OSError: If the directory cannot be created or the file written; any
            partial temp file is removed first.
    """
    backup_dir.mkdir(parents=True, exist_ok=True)
    backup = await export_backup(session, user_id, secrets_mode=secrets_mode)
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    filepath = backup_dir / f"backup-{ts}.json"
    # NOTE(review): two exports within the same second share a filename; the
    # later one replaces the earlier.
    tmp_path = filepath.with_name(filepath.name + ".tmp")
    # mode="json" converts enums and other rich values to JSON-native types,
    # so json.dumps never meets an unserializable object.
    payload = json.dumps(backup.model_dump(mode="json"), indent=2, ensure_ascii=False)
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        # Same directory -> same filesystem, so the rename is atomic (os.replace).
        tmp_path.replace(filepath)
    except OSError:
        tmp_path.unlink(missing_ok=True)
        raise
    _LOGGER.info("Scheduled backup saved: %s", filepath)
    return filepath
def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
    """Delete the oldest backup files so that at most *keep* remain.

    Regular files matching ``backup-*.json`` are ordered by name, newest
    first. The embedded ``%Y-%m-%dT%H-%M-%S`` timestamp makes name order
    chronological. Everything past the newest *keep* is removed. Entries that
    are not regular files are ignored. A file that cannot be removed is
    logged and skipped, so one bad entry does not stop the cleanup.

    Args:
        backup_dir: Directory holding the backups; if missing, nothing happens.
        keep: Number of newest backups to retain; must be >= 0.

    Returns:
        Names of the files that were deleted.

    Raises:
        ValueError: If *keep* is negative (a negative slice would otherwise
            delete an arbitrary subset of the oldest files).
    """
    if keep < 0:
        raise ValueError(f"keep must be >= 0, got {keep}")
    if not backup_dir.is_dir():
        return []
    backups = sorted(
        (f for f in backup_dir.glob("backup-*.json") if f.is_file()),
        key=lambda f: f.name,
        reverse=True,
    )
    deleted: list[str] = []
    for old in backups[keep:]:
        try:
            # missing_ok: a concurrent cleanup may already have removed it.
            old.unlink(missing_ok=True)
        except OSError as err:
            _LOGGER.warning("Could not delete old backup %s: %s", old, err)
            continue
        deleted.append(old.name)
    if deleted:
        _LOGGER.info("Cleaned up %d old backup(s): %s", len(deleted), deleted)
    return deleted
def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
    """List backup files in *backup_dir*, newest first by filename.

    Each regular file matching ``backup-*.json`` yields a dict with:
    ``filename``, ``size`` (bytes) and ``created_at`` (ISO-8601, UTC).
    ``created_at`` is taken from the file's modification time.
    A missing directory yields ``[]``.
    """
    if not backup_dir.is_dir():
        return []
    entries: list[dict[str, Any]] = []
    for path in sorted(backup_dir.glob("backup-*.json"), key=lambda f: f.name, reverse=True):
        try:
            if not path.is_file():
                continue
            st = path.stat()
        except OSError:
            # The file can disappear between glob and stat (for example, a
            # cleanup running concurrently); leave it out of the listing.
            continue
        entries.append({
            "filename": path.name,
            "size": st.st_size,
            "created_at": datetime.fromtimestamp(st.st_mtime, tz=timezone.utc).isoformat(),
        })
    return entries
# ---------------------------------------------------------------------------
# Validate
# ---------------------------------------------------------------------------
def validate_backup(raw: dict[str, Any]) -> ValidateResult:
    """Validate a backup payload without importing anything.

    Checks run in order and stop at the first hard failure:

    1. *raw* must be a dict (a JSON object).
    2. ``format`` must equal ``"notify-bridge-backup"``.
    3. ``version`` (default 0) must be an int (bool rejected) and at most 1.
    4. The payload must parse as a ``BackupFile``.

    A backup whose ``secrets_mode`` is "exclude" or "masked" is still valid
    but gets a warning. A missing key counts as "exclude".

    Returns:
        ``ValidateResult`` with ``valid``, the version (once it is known to be
        an int), counts for each non-empty category, and warnings/errors.
    """
    warnings: list[str] = []
    errors: list[str] = []

    # The payload is presumably user-supplied JSON, so don't assume its shape.
    if not isinstance(raw, dict):
        errors.append(f"Backup must be a JSON object, got {type(raw).__name__}")
        return ValidateResult(valid=False, errors=errors)

    fmt = raw.get("format")
    if fmt != "notify-bridge-backup":
        errors.append(f"Unknown format: {fmt}")
        return ValidateResult(valid=False, errors=errors)

    version = raw.get("version", 0)
    # A string/null/float version would make `version > 1` raise TypeError;
    # bool is an int subclass but is not a meaningful version.
    if isinstance(version, bool) or not isinstance(version, int):
        errors.append(f"Invalid backup version: {version!r}")
        return ValidateResult(valid=False, errors=errors)
    if version > 1:
        errors.append(f"Unsupported backup version: {version} (max supported: 1)")
        return ValidateResult(valid=False, version=version, errors=errors)

    secrets_mode = raw.get("secrets_mode", "exclude")
    if secrets_mode in ("exclude", "masked"):
        warnings.append(
            f"Backup was exported with secrets_mode={secrets_mode}. "
            "Imported entities will have empty/placeholder secrets that need manual update."
        )

    try:
        backup = BackupFile.model_validate(raw)
    except Exception as e:
        errors.append(f"Schema validation failed: {e}")
        return ValidateResult(valid=False, version=version, errors=errors)

    counts: dict[str, int] = {}
    d = backup.data
    for cat in ("providers", "telegram_bots", "matrix_bots", "email_bots",
                "targets", "tracking_configs", "template_configs",
                "command_configs", "command_template_configs",
                "notification_trackers", "command_trackers", "actions",
                "app_settings"):
        items = getattr(d, cat, [])
        if items:
            counts[cat] = len(items)

    return ValidateResult(
        valid=True, version=version,
        entity_counts=counts, warnings=warnings, errors=errors,
    )
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
async def _lookup_existing_id(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
) -> int | None:
    """Return the id of *user_id*'s *model* row named *name*, or ``None``."""
    rows = await session.exec(
        select(model).where(
            model.name == name,
            model.user_id == user_id,
        )
    )
    found = rows.first()
    return None if found is None else found.id


async def import_backup(
    session: AsyncSession,
    user_id: int,
    backup: BackupFile,
    conflict_mode: ConflictMode = ConflictMode.SKIP,
) -> ImportResult:
    """Import *backup* into *user_id*'s configuration as one transaction.

    Entities are created parents-first. Ids recorded in the backup are
    translated through ``id_map`` (``category -> {backup_id: new_id}``)
    before dependants reference them. Name clashes are resolved per
    *conflict_mode* by ``_resolve_name`` / ``_resolve_name_template``. When an
    entity is skipped because it already exists (SKIP mode), its backup id is
    mapped to the existing row, so dependants still link to it instead of
    being dropped. Imported actions are always created disabled.

    On any exception the session is rolled back and nothing is committed; the
    error is appended to ``result.errors`` rather than raised.

    Args:
        session: Async DB session; committed on success, rolled back on error.
        user_id: Owner assigned to every imported row.
        backup: Parsed backup to import.
        conflict_mode: SKIP, RENAME or OVERWRITE behaviour for name clashes.

    Returns:
        ``ImportResult`` with created/skipped/overwritten counters, warnings
        for dropped references, and errors. NOTE: after a rollback the
        counters still reflect the work attempted before the failure.
    """
    result = ImportResult()
    # category -> {id in the backup: id in this database}
    id_map: dict[str, dict[int, int]] = {}
    d = backup.data

    async def remember_skipped(
        category: str, model: type, old_id: int | None, name: str,
    ) -> None:
        # The resolvers return None only for a SKIP-mode conflict, i.e. the
        # user already owns a row with this name. Point the backup id at that
        # row so dependants resolve instead of failing as "not found".
        if old_id is None:
            return
        existing_id = await _lookup_existing_id(session, model, name, user_id)
        if existing_id is not None:
            id_map[category][old_id] = existing_id

    async def owned_by_user(model: type, row_id: int | None) -> bool:
        # Prevents linking to a row that is missing or belongs to another user.
        if row_id is None:
            return False
        row = await session.get(model, row_id)
        return row is not None and row.user_id == user_id

    # Listener types whose listener_id is a bot id that must be remapped.
    # "telegram_bot" is the type this importer has always remapped; the
    # matrix/email type strings are assumptions. TODO(review): confirm them
    # against CommandTrackerListener usage. Any other type passes through
    # unchanged. Matched with `==` (not dict lookup) so str-Enum values work.
    bot_listener_types: tuple[tuple[str, str, type], ...] = (
        ("telegram_bot", "telegram_bots", TelegramBot),
        ("matrix_bot", "matrix_bots", MatrixBot),
        ("email_bot", "email_bots", EmailBot),
    )

    try:
        # 1. App settings: keyed by `key`, so no ids to remap.
        for s in d.app_settings:
            existing = await session.get(AppSetting, s.key)
            if existing is None:
                session.add(AppSetting(key=s.key, value=s.value))
                result.created += 1
            elif conflict_mode == ConflictMode.OVERWRITE:
                existing.value = s.value
                session.add(existing)
                result.overwritten += 1
            else:
                # SKIP, and RENAME (a setting key cannot be renamed).
                result.skipped += 1

        # 2. Telegram Bots
        id_map["telegram_bots"] = {}
        for b in d.telegram_bots:
            name = await _resolve_name(
                session, TelegramBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("telegram_bots", TelegramBot, b.id, b.name)
                continue
            new_bot = TelegramBot(
                user_id=user_id, name=name, token=b.token, icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            )
            session.add(new_bot)
            await session.flush()  # assigns new_bot.id
            id_map["telegram_bots"][b.id] = new_bot.id

        # 3. Matrix Bots
        id_map["matrix_bots"] = {}
        for b in d.matrix_bots:
            name = await _resolve_name(
                session, MatrixBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("matrix_bots", MatrixBot, b.id, b.name)
                continue
            new_bot = MatrixBot(
                user_id=user_id, name=name, icon=b.icon,
                homeserver_url=b.homeserver_url, access_token=b.access_token,
                display_name=b.display_name,
            )
            session.add(new_bot)
            await session.flush()
            id_map["matrix_bots"][b.id] = new_bot.id

        # 4. Email Bots
        id_map["email_bots"] = {}
        for b in d.email_bots:
            name = await _resolve_name(
                session, EmailBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("email_bots", EmailBot, b.id, b.name)
                continue
            new_bot = EmailBot(
                user_id=user_id, name=name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username, smtp_password=b.smtp_password,
                smtp_use_tls=b.smtp_use_tls,
            )
            session.add(new_bot)
            await session.flush()
            id_map["email_bots"][b.id] = new_bot.id

        # 5. Providers
        id_map["providers"] = {}
        for p in d.providers:
            name = await _resolve_name(
                session, ServiceProvider, p.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("providers", ServiceProvider, p.id, p.name)
                continue
            new_p = ServiceProvider(
                user_id=user_id, type=p.type, name=name,
                icon=p.icon, config=p.config,
            )
            session.add(new_p)
            await session.flush()
            id_map["providers"][p.id] = new_p.id

        # 6. Tracking Configs
        id_map["tracking_configs"] = {}
        for tc in d.tracking_configs:
            name = await _resolve_name(
                session, TrackingConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("tracking_configs", TrackingConfig, tc.id, tc.name)
                continue
            new_tc = TrackingConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, icon=tc.icon,
            )
            # Copy the generic field bag. Keys in _TRACKING_SKIP (id, user_id,
            # name, ...) are never taken from it: a crafted backup could
            # otherwise reassign ownership or override the resolved name.
            # Keys that are not TrackingConfig attributes are ignored.
            for field_name, value in tc.fields.items():
                if field_name in _TRACKING_SKIP or not hasattr(new_tc, field_name):
                    continue
                setattr(new_tc, field_name, value)
            session.add(new_tc)
            await session.flush()
            id_map["tracking_configs"][tc.id] = new_tc.id

        # 7. Template Configs + Slots
        id_map["template_configs"] = {}
        for tc in d.template_configs:
            name = await _resolve_name_template(
                session, TemplateConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("template_configs", TemplateConfig, tc.id, tc.name)
                continue
            new_tc = TemplateConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, description=tc.description, icon=tc.icon,
                locale=tc.locale, date_format=tc.date_format,
                date_only_format=tc.date_only_format,
            )
            session.add(new_tc)
            await session.flush()
            id_map["template_configs"][tc.id] = new_tc.id
            for s in tc.slots:
                session.add(TemplateSlot(
                    config_id=new_tc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(tc.slots)

        # 8. Command Template Configs + Slots
        id_map["command_template_configs"] = {}
        for ctc in d.command_template_configs:
            name = await _resolve_name_template(
                session, CommandTemplateConfig, ctc.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped(
                    "command_template_configs", CommandTemplateConfig, ctc.id, ctc.name,
                )
                continue
            new_ctc = CommandTemplateConfig(
                user_id=user_id, provider_type=ctc.provider_type,
                name=name, description=ctc.description, icon=ctc.icon,
                locale=ctc.locale,
            )
            session.add(new_ctc)
            await session.flush()
            id_map["command_template_configs"][ctc.id] = new_ctc.id
            for s in ctc.slots:
                session.add(CommandTemplateSlot(
                    config_id=new_ctc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(ctc.slots)

        # 9. Command Configs (may reference a command template config)
        id_map["command_configs"] = {}
        for cc in d.command_configs:
            name = await _resolve_name(
                session, CommandConfig, cc.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("command_configs", CommandConfig, cc.id, cc.name)
                continue
            ctc_id = _map_id(id_map, "command_template_configs", cc.command_template_config_id)
            new_cc = CommandConfig(
                user_id=user_id, provider_type=cc.provider_type,
                name=name, icon=cc.icon,
                enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode,
                default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=ctc_id,
            )
            session.add(new_cc)
            await session.flush()
            id_map["command_configs"][cc.id] = new_cc.id

        # 10. Targets + Receivers
        id_map["targets"] = {}
        for tgt in d.targets:
            name = await _resolve_name(
                session, NotificationTarget, tgt.name, user_id, conflict_mode, result,
            )
            if name is None:
                await remember_skipped("targets", NotificationTarget, tgt.id, tgt.name)
                continue
            new_tgt = NotificationTarget(
                user_id=user_id, type=tgt.type, name=name,
                icon=tgt.icon, config=tgt.config,
                chat_action=tgt.chat_action,
            )
            session.add(new_tgt)
            await session.flush()
            id_map["targets"][tgt.id] = new_tgt.id
            for r in tgt.receivers:
                session.add(TargetReceiver(
                    target_id=new_tgt.id, name=r.name, config=r.config,
                    receiver_key=r.receiver_key, locale=r.locale,
                    enabled=r.enabled,
                ))
            result.created += len(tgt.receivers)

        # 11. Notification Trackers + Tracker-Targets (a provider is required)
        for nt in d.notification_trackers:
            provider_id = _map_id(id_map, "providers", nt.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped tracker '{nt.name}': provider {nt.provider_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, NotificationTracker, nt.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_nt = NotificationTracker(
                user_id=user_id, provider_id=provider_id,
                name=name, icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=_map_id(id_map, "tracking_configs", nt.default_tracking_config_id),
                default_template_config_id=_map_id(id_map, "template_configs", nt.default_template_config_id),
                enabled=nt.enabled,
            )
            session.add(new_nt)
            await session.flush()
            for tt in nt.targets:
                target_id = _map_id(id_map, "targets", tt.target_id)
                if target_id is None:
                    result.warnings.append(
                        f"Skipped tracker-target link in '{nt.name}': target {tt.target_id} not found"
                    )
                    continue
                session.add(NotificationTrackerTarget(
                    tracker_id=new_nt.id,
                    target_id=target_id,
                    tracking_config_id=_map_id(id_map, "tracking_configs", tt.tracking_config_id),
                    template_config_id=_map_id(id_map, "template_configs", tt.template_config_id),
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                ))
                result.created += 1

        # 12. Command Trackers + Listeners (provider and command config required)
        for ct in d.command_trackers:
            provider_id = _map_id(id_map, "providers", ct.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': provider {ct.provider_id} not found"
                )
                result.skipped += 1
                continue
            cc_id = _map_id(id_map, "command_configs", ct.command_config_id)
            if cc_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': command config {ct.command_config_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, CommandTracker, ct.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_ct = CommandTracker(
                user_id=user_id, provider_id=provider_id,
                command_config_id=cc_id, name=name, icon=ct.icon,
                enabled=ct.enabled,
            )
            session.add(new_ct)
            await session.flush()
            for lis in ct.listeners:
                listener_id = lis.listener_id
                bot_ref = next(
                    (
                        (category, bot_model)
                        for ltype, category, bot_model in bot_listener_types
                        if lis.listener_type == ltype
                    ),
                    None,
                )
                if bot_ref is not None:
                    category, bot_model = bot_ref
                    mapped_id = _map_id(id_map, category, lis.listener_id)
                    if mapped_id is not None:
                        listener_id = mapped_id
                    elif not await owned_by_user(bot_model, lis.listener_id):
                        # Keeping the raw backup id here could point at a
                        # missing bot or at another user's bot.
                        result.warnings.append(
                            f"Skipped listener in command tracker '{ct.name}': "
                            f"{lis.listener_type} {lis.listener_id} not found"
                        )
                        continue
                session.add(CommandTrackerListener(
                    command_tracker_id=new_ct.id,
                    listener_type=lis.listener_type,
                    listener_id=listener_id,
                ))
            result.created += 1

        # 13. Actions + Rules (a provider is required)
        for a in d.actions:
            provider_id = _map_id(id_map, "providers", a.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped action '{a.name}': provider {a.provider_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, Action, a.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_a = Action(
                user_id=user_id, provider_id=provider_id, name=name,
                icon=a.icon, action_type=a.action_type, config=a.config,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=False,  # always import disabled
            )
            session.add(new_a)
            await session.flush()
            for r in a.rules:
                session.add(ActionRule(
                    action_id=new_a.id, name=r.name,
                    rule_config=r.rule_config, enabled=r.enabled,
                    order=r.order,
                ))
            result.created += len(a.rules)

        await session.commit()
    except Exception as e:
        await session.rollback()
        # .exception() keeps the traceback for diagnosis.
        _LOGGER.exception("Backup import failed: %s", e)
        result.errors.append(f"Import failed: {e}")
    return result
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _map_id(
id_map: dict[str, dict[int, int]],
category: str,
old_id: int | None,
) -> int | None:
"""Resolve an old ID to a new ID via the id_map. Returns None if not found."""
if old_id is None:
return None
return id_map.get(category, {}).get(old_id)
async def _find_by_name(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
) -> Any:
    """Return the first *model* row owned by *user_id* named *name*, or ``None``."""
    rows = await session.exec(
        select(model).where(
            model.name == name,
            model.user_id == user_id,
        )
    )
    return rows.first()


async def _resolve_name(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Decide the name an imported *model* row should get, or skip it.

    Looks for an existing *model* row owned by *user_id* with the same name
    and updates the *result* counters:

    * no conflict -> ``created``; returns *name*.
    * SKIP -> ``skipped``; returns ``None`` (the caller must not create a row).
    * RENAME -> ``created``; returns ``"<name> (imported)"``, or
      ``"<name> (imported N)"`` with the smallest N >= 2 still free. Repeated
      imports therefore do not produce duplicate names.
    * OVERWRITE (any other mode) -> ``overwritten``; deletes the existing
      row, flushes, and returns *name*.
    """
    found = await _find_by_name(session, model, name, user_id)
    if found is None:
        result.created += 1
        return name
    if conflict_mode == ConflictMode.SKIP:
        result.skipped += 1
        return None
    if conflict_mode == ConflictMode.RENAME:
        candidate = f"{name} (imported)"
        suffix = 2
        while await _find_by_name(session, model, candidate, user_id) is not None:
            candidate = f"{name} (imported {suffix})"
            suffix += 1
        result.created += 1
        return candidate
    # OVERWRITE: delete the existing row; the caller creates a fresh one.
    # NOTE(review): rows that referenced the deleted one are not repointed
    # here; the outcome depends on the DB's FK/cascade settings.
    await session.delete(found)
    await session.flush()
    result.overwritten += 1
    return name
async def _resolve_name_template(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Resolve a name conflict for a user-owned template config row.

    Template tables may also hold system-owned rows (reportedly stored with
    ``user_id`` 0). Conflict detection is scoped to *user_id*, so a system
    template with the same name never counts as a conflict and is never
    deleted by OVERWRITE. The logic is exactly that of ``_resolve_name``,
    so this delegates to it instead of keeping a second copy.
    """
    return await _resolve_name(session, model, name, user_id, conflict_mode, result)