f0739ca949
- Add outbound URL validation (SSRF) for webhook/Discord/Slack/ntfy/Matrix dispatch - Template renderer: input/output caps and thread-based render timeout - Webhook log filter: strip Authorization/signature/token-like headers; atomic prune - Auth/JWT/backup/config tightening; misc frontend UX fixes
909 lines
35 KiB
Python
909 lines
35 KiB
Python
"""Configuration backup/restore service — export and import logic."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from sqlmodel import select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
|
|
from ..database.models import (
|
|
Action, ActionRule, AppSetting, CommandConfig, CommandTemplateConfig,
|
|
CommandTemplateSlot, CommandTracker, CommandTrackerListener, EmailBot,
|
|
MatrixBot, NotificationTarget, NotificationTracker,
|
|
NotificationTrackerTarget, ServiceProvider, TargetReceiver,
|
|
TemplateConfig, TemplateSlot, TelegramBot, TrackingConfig,
|
|
)
|
|
from .backup_schema import (
|
|
ALL_CATEGORIES, ActionData, ActionRuleData, AppSettingData, BackupCategory,
|
|
BackupData, BackupFile, CommandConfigData, CommandTemplateConfigData,
|
|
CommandTemplateSlotData, CommandTrackerData, CommandTrackerListenerData,
|
|
ConflictMode, EmailBotData, ImportResult, MatrixBotData,
|
|
NotificationTrackerData, PROVIDER_SECRET_FIELDS, ProviderData,
|
|
ReceiverData, SecretsMode, TargetData, TemplateConfigData,
|
|
TemplateSlotData, TelegramBotData, TrackerTargetData,
|
|
TrackingConfigData, ValidateResult,
|
|
)
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Fields to skip when serializing TrackingConfig into the generic `fields` dict.
# provider_type/name/icon are exported as dedicated TrackingConfigData attributes;
# id/user_id/created_at are row metadata that the export deliberately omits.
_TRACKING_SKIP = frozenset(("id", "user_id", "provider_type", "name", "icon", "created_at"))

# Import-time config hardening limits, enforced by _sanitize_config:
# maximum nesting depth of dicts/lists inside an imported config,
_MAX_CONFIG_DEPTH = 6
# maximum number of entries in any single imported dict or list,
_MAX_CONFIG_KEYS = 200
# and maximum length of an imported string value (longer strings are truncated).
_MAX_STRING_LEN = 8192
|
|
|
|
|
|
def _sanitize_config(value: Any, depth: int = 0) -> Any:
|
|
"""Clamp imported config values to safe shapes before persistence.
|
|
|
|
Rejects anything that is not a JSON-compatible primitive/container, truncates
|
|
over-long strings, and caps dict/list sizes. Returns a defensively-copied
|
|
structure; the caller should never see attacker-controlled references.
|
|
"""
|
|
if depth > _MAX_CONFIG_DEPTH:
|
|
raise ValueError("Config nesting exceeds maximum depth")
|
|
if value is None or isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, (int, float)):
|
|
return value
|
|
if isinstance(value, str):
|
|
return value[:_MAX_STRING_LEN]
|
|
if isinstance(value, list):
|
|
if len(value) > _MAX_CONFIG_KEYS:
|
|
raise ValueError("Config list exceeds maximum length")
|
|
return [_sanitize_config(v, depth + 1) for v in value]
|
|
if isinstance(value, dict):
|
|
if len(value) > _MAX_CONFIG_KEYS:
|
|
raise ValueError("Config dict exceeds maximum key count")
|
|
cleaned: dict[str, Any] = {}
|
|
for k, v in value.items():
|
|
if not isinstance(k, str):
|
|
raise ValueError("Config keys must be strings")
|
|
if len(k) > 128:
|
|
raise ValueError(f"Config key too long: {k[:40]}...")
|
|
cleaned[k] = _sanitize_config(v, depth + 1)
|
|
return cleaned
|
|
raise ValueError(f"Unsupported config value type: {type(value).__name__}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Export
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _mask_secret(value: str) -> str:
|
|
return f"***{value[-4:]}" if len(value) > 4 else "***"
|
|
|
|
|
|
def _apply_secrets_provider(config: dict[str, Any], mode: SecretsMode) -> dict[str, Any]:
|
|
"""Return a copy of provider config with secrets handled per mode."""
|
|
result = dict(config)
|
|
for key in PROVIDER_SECRET_FIELDS:
|
|
if key in result and result[key]:
|
|
if mode == SecretsMode.EXCLUDE:
|
|
result[key] = ""
|
|
elif mode == SecretsMode.MASKED:
|
|
result[key] = _mask_secret(result[key])
|
|
return result
|
|
|
|
|
|
def _tracking_config_fields(tc: TrackingConfig) -> dict[str, Any]:
|
|
"""Extract all tracking config fields (booleans, ints, strings) as a dict."""
|
|
data = {}
|
|
for field_name in tc.model_fields:
|
|
if field_name in _TRACKING_SKIP:
|
|
continue
|
|
data[field_name] = getattr(tc, field_name)
|
|
return data
|
|
|
|
|
|
def _redact_secret(value: str | None, mode: SecretsMode) -> str | None:
    """Apply the export *mode* to a single secret value.

    EXCLUDE always yields ``""``. MASKED yields ``_mask_secret(value)``, or
    ``""`` when the secret is empty/``None``; this avoids emitting a fake
    ``"***"`` placeholder or crashing on ``None``. Any other mode returns
    *value* unchanged.
    """
    if mode == SecretsMode.EXCLUDE:
        return ""
    if mode == SecretsMode.MASKED:
        return _mask_secret(value) if value else ""
    return value


async def export_backup(
    session: AsyncSession,
    user_id: int,
    categories: list[BackupCategory] | None = None,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> BackupFile:
    """Export a user's configuration as a BackupFile.

    Args:
        session: Async database session used for all reads.
        user_id: Owner whose entities are exported. App settings are not
            filtered by user: every ``AppSetting`` row is exported.
        categories: Categories to include. ``None`` or an empty list exports
            everything in ``ALL_CATEGORIES``.
        secrets_mode: How bot tokens, SMTP passwords, provider secret fields
            and the Telegram webhook secret are emitted (blanked, masked, or
            kept as-is).

    Returns:
        The assembled backup. Its ``categories`` list is sorted so repeated
        exports of the same data are byte-for-byte stable.

    NOTE(review): target, receiver and action configs are exported verbatim
    regardless of *secrets_mode*; confirm they never hold credentials such as
    webhook tokens.
    """
    cats = set(categories or ALL_CATEGORIES)
    data = BackupData()

    # -- Providers --
    if BackupCategory.PROVIDERS in cats:
        result = await session.exec(
            select(ServiceProvider).where(ServiceProvider.user_id == user_id)
        )
        for p in result.all():
            data.providers.append(ProviderData(
                id=p.id,
                type=p.type,
                name=p.name,
                icon=p.icon,
                config=_apply_secrets_provider(p.config or {}, secrets_mode),
            ))

    # -- Telegram Bots --
    if BackupCategory.TELEGRAM_BOTS in cats:
        result = await session.exec(
            select(TelegramBot).where(TelegramBot.user_id == user_id)
        )
        for b in result.all():
            data.telegram_bots.append(TelegramBotData(
                id=b.id, name=b.name,
                token=_redact_secret(b.token, secrets_mode),
                icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            ))

    # -- Matrix Bots --
    if BackupCategory.MATRIX_BOTS in cats:
        result = await session.exec(
            select(MatrixBot).where(MatrixBot.user_id == user_id)
        )
        for b in result.all():
            data.matrix_bots.append(MatrixBotData(
                id=b.id, name=b.name, icon=b.icon,
                homeserver_url=b.homeserver_url,
                access_token=_redact_secret(b.access_token, secrets_mode),
                display_name=b.display_name,
            ))

    # -- Email Bots --
    if BackupCategory.EMAIL_BOTS in cats:
        result = await session.exec(
            select(EmailBot).where(EmailBot.user_id == user_id)
        )
        for b in result.all():
            data.email_bots.append(EmailBotData(
                id=b.id, name=b.name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username,
                smtp_password=_redact_secret(b.smtp_password, secrets_mode),
                smtp_use_tls=b.smtp_use_tls,
            ))

    # -- Targets + Receivers --
    if BackupCategory.TARGETS in cats:
        result = await session.exec(
            select(NotificationTarget).where(NotificationTarget.user_id == user_id)
        )
        for tgt in result.all():
            recv_result = await session.exec(
                select(TargetReceiver).where(TargetReceiver.target_id == tgt.id)
            )
            receivers = [
                ReceiverData(
                    name=r.name, config=r.config, receiver_key=r.receiver_key,
                    locale=r.locale, enabled=r.enabled,
                )
                for r in recv_result.all()
            ]
            data.targets.append(TargetData(
                id=tgt.id, type=tgt.type, name=tgt.name, icon=tgt.icon,
                config=tgt.config, chat_action=tgt.chat_action,
                receivers=receivers,
            ))

    # -- Tracking Configs --
    if BackupCategory.TRACKING_CONFIGS in cats:
        result = await session.exec(
            select(TrackingConfig).where(TrackingConfig.user_id == user_id)
        )
        for tc in result.all():
            data.tracking_configs.append(TrackingConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                icon=tc.icon, fields=_tracking_config_fields(tc),
            ))

    # -- Template Configs + Slots (user-owned only) --
    if BackupCategory.TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(TemplateConfig).where(TemplateConfig.user_id == user_id)
        )
        for tc in result.all():
            slot_result = await session.exec(
                select(TemplateSlot).where(TemplateSlot.config_id == tc.id)
            )
            slots = [
                TemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.template_configs.append(TemplateConfigData(
                id=tc.id, provider_type=tc.provider_type, name=tc.name,
                description=tc.description, icon=tc.icon, locale=tc.locale,
                date_format=tc.date_format, date_only_format=tc.date_only_format,
                slots=slots,
            ))

    # -- Command Template Configs + Slots (user-owned only) --
    if BackupCategory.COMMAND_TEMPLATE_CONFIGS in cats:
        result = await session.exec(
            select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == user_id)
        )
        for ctc in result.all():
            slot_result = await session.exec(
                select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == ctc.id)
            )
            slots = [
                CommandTemplateSlotData(
                    slot_name=s.slot_name, locale=s.locale, template=s.template,
                )
                for s in slot_result.all()
            ]
            data.command_template_configs.append(CommandTemplateConfigData(
                id=ctc.id, provider_type=ctc.provider_type, name=ctc.name,
                description=ctc.description, icon=ctc.icon, locale=ctc.locale,
                slots=slots,
            ))

    # -- Command Configs --
    if BackupCategory.COMMAND_CONFIGS in cats:
        result = await session.exec(
            select(CommandConfig).where(CommandConfig.user_id == user_id)
        )
        for cc in result.all():
            data.command_configs.append(CommandConfigData(
                id=cc.id, provider_type=cc.provider_type, name=cc.name,
                icon=cc.icon, enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode, default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=cc.command_template_config_id,
            ))

    # -- Notification Trackers + Tracker-Targets --
    if BackupCategory.NOTIFICATION_TRACKERS in cats:
        result = await session.exec(
            select(NotificationTracker).where(NotificationTracker.user_id == user_id)
        )
        for nt in result.all():
            tt_result = await session.exec(
                select(NotificationTrackerTarget).where(
                    NotificationTrackerTarget.tracker_id == nt.id
                )
            )
            targets = [
                TrackerTargetData(
                    target_id=tt.target_id,
                    tracking_config_id=tt.tracking_config_id,
                    template_config_id=tt.template_config_id,
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                )
                for tt in tt_result.all()
            ]
            data.notification_trackers.append(NotificationTrackerData(
                id=nt.id, provider_id=nt.provider_id, name=nt.name,
                icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=nt.default_tracking_config_id,
                default_template_config_id=nt.default_template_config_id,
                enabled=nt.enabled, targets=targets,
            ))

    # -- Command Trackers + Listeners --
    if BackupCategory.COMMAND_TRACKERS in cats:
        result = await session.exec(
            select(CommandTracker).where(CommandTracker.user_id == user_id)
        )
        for ct in result.all():
            lis_result = await session.exec(
                select(CommandTrackerListener).where(
                    CommandTrackerListener.command_tracker_id == ct.id
                )
            )
            listeners = [
                CommandTrackerListenerData(
                    listener_type=lis.listener_type, listener_id=lis.listener_id,
                )
                for lis in lis_result.all()
            ]
            data.command_trackers.append(CommandTrackerData(
                id=ct.id, provider_id=ct.provider_id,
                command_config_id=ct.command_config_id, name=ct.name,
                icon=ct.icon, enabled=ct.enabled, listeners=listeners,
            ))

    # -- Actions + Rules --
    if BackupCategory.ACTIONS in cats:
        result = await session.exec(
            select(Action).where(Action.user_id == user_id)
        )
        for a in result.all():
            rule_result = await session.exec(
                select(ActionRule).where(ActionRule.action_id == a.id)
            )
            rules = [
                ActionRuleData(
                    name=r.name, rule_config=r.rule_config,
                    enabled=r.enabled, order=r.order,
                )
                for r in rule_result.all()
            ]
            data.actions.append(ActionData(
                id=a.id, provider_id=a.provider_id, name=a.name,
                icon=a.icon, action_type=a.action_type, config=a.config,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=a.enabled,
                rules=rules,
            ))

    # -- App Settings (global, not user-scoped) --
    if BackupCategory.APP_SETTINGS in cats:
        result = await session.exec(select(AppSetting))
        for s in result.all():
            value = s.value
            if s.key == "telegram_webhook_secret" and value:
                value = _redact_secret(value, secrets_mode)
            data.app_settings.append(AppSettingData(key=s.key, value=value))

    return BackupFile(
        format="notify-bridge-backup",
        version=1,
        created_at=datetime.now(timezone.utc).isoformat(),
        app_version="1.0.0",
        secrets_mode=secrets_mode,
        # Sorted: iterating the set directly gives a hash-randomized order per process.
        categories=sorted(c.value for c in cats),
        data=data,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Export to file (for scheduled backups)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def export_backup_to_file(
    session: AsyncSession,
    user_id: int,
    backup_dir: Path,
    secrets_mode: SecretsMode = SecretsMode.EXCLUDE,
) -> Path:
    """Export a backup for *user_id* and write it as JSON into *backup_dir*.

    The file is named ``backup-<UTC timestamp>.json``. It is first written to
    a private temporary file in the same directory (``tempfile.mkstemp``
    creates it owner-readable only on POSIX), flushed to disk, then
    atomically renamed into place. Readers and ``cleanup_old_backups``
    therefore never see a half-written backup, and backups that include
    secrets are not left readable by other users.

    Args:
        session: Async database session passed through to ``export_backup``.
        user_id: Owner whose configuration is exported.
        backup_dir: Target directory; created (with parents) if missing.
        secrets_mode: How secrets are written into the backup.

    Returns:
        Path of the written backup file.
    """
    import contextlib
    import os
    import tempfile

    backup_dir.mkdir(parents=True, exist_ok=True)
    backup = await export_backup(session, user_id, secrets_mode=secrets_mode)
    ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%S")
    filepath = backup_dir / f"backup-{ts}.json"
    # mode="json" converts enums and other rich values into JSON-native types.
    payload = json.dumps(backup.model_dump(mode="json"), indent=2, ensure_ascii=False)

    # Leading dot + .tmp suffix: never matched by the "backup-*.json" globs.
    fd, tmp_name = tempfile.mkstemp(prefix=".backup-", suffix=".tmp", dir=backup_dir)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_name, filepath)
    except BaseException:
        # Best-effort removal of the partial temp file, then re-raise.
        with contextlib.suppress(OSError):
            os.unlink(tmp_name)
        raise

    _LOGGER.info("Scheduled backup saved: %s", filepath)
    return filepath
|
|
|
|
|
|
def cleanup_old_backups(backup_dir: Path, keep: int = 5) -> list[str]:
    """Delete the oldest ``backup-*.json`` files beyond the newest *keep*.

    Backups are ordered by filename; the embedded UTC timestamp makes lexical
    order chronological. Only regular files are considered. A file that has
    already disappeared is ignored. A file that cannot be removed is logged
    and skipped, so one bad entry does not abort the rest of the cleanup.

    Args:
        backup_dir: Directory holding scheduled backups. A missing directory
            is treated as empty.
        keep: Number of newest backups to retain; must be >= 0.

    Returns:
        Names of the files this call deleted, newest first.

    Raises:
        ValueError: If *keep* is negative. A negative slice would otherwise
            silently delete an arbitrary number of the oldest backups.
    """
    if keep < 0:
        raise ValueError(f"keep must be >= 0, got {keep}")
    if not backup_dir.is_dir():
        return []
    files = sorted(
        (f for f in backup_dir.glob("backup-*.json") if f.is_file()),
        key=lambda f: f.name,
        reverse=True,
    )
    deleted: list[str] = []
    for old in files[keep:]:
        try:
            old.unlink()
        except FileNotFoundError:
            continue  # removed concurrently; nothing left to do
        except OSError as exc:
            _LOGGER.warning("Could not delete old backup %s: %s", old, exc)
            continue
        deleted.append(old.name)
    if deleted:
        _LOGGER.info("Cleaned up %d old backup(s): %s", len(deleted), deleted)
    return deleted
|
|
|
|
|
|
def list_backup_files(backup_dir: Path) -> list[dict[str, Any]]:
    """Describe the ``backup-*.json`` files in *backup_dir*, newest first.

    Each entry has ``filename``, ``size`` (bytes) and ``created_at`` (the
    file's modification time as an ISO-8601 UTC string). Non-regular files
    are ignored. Files that vanish between listing and ``stat`` (for example,
    removed by a concurrent cleanup) are skipped instead of raising.

    Args:
        backup_dir: Directory to scan. A missing directory yields ``[]``.

    Returns:
        One metadata dict per backup file.
    """
    if not backup_dir.is_dir():
        return []
    entries: list[dict[str, Any]] = []
    for f in sorted(backup_dir.glob("backup-*.json"), key=lambda p: p.name, reverse=True):
        try:
            if not f.is_file():
                continue
            st = f.stat()
        except FileNotFoundError:
            continue
        entries.append({
            "filename": f.name,
            "size": st.st_size,
            "created_at": datetime.fromtimestamp(st.st_mtime, tz=timezone.utc).isoformat(),
        })
    return entries
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Validate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def validate_backup(raw: dict[str, Any]) -> ValidateResult:
    """Validate an uploaded backup without importing it.

    Checks, in order: that *raw* is a JSON object, the ``format`` marker,
    that ``version`` is an integer no newer than 1, and finally the full
    ``BackupFile`` schema. *raw* comes straight from an uploaded file, so
    malformed values are reported in ``errors`` instead of raising.

    Args:
        raw: Parsed JSON content of the backup file.

    Returns:
        A ``ValidateResult``. On success it carries per-category entity counts,
        plus a warning when secrets were excluded or masked at export time.
    """
    warnings: list[str] = []
    errors: list[str] = []

    if not isinstance(raw, dict):
        errors.append("Backup must be a JSON object")
        return ValidateResult(valid=False, errors=errors)

    fmt = raw.get("format")
    if fmt != "notify-bridge-backup":
        errors.append(f"Unknown format: {fmt}")
        return ValidateResult(valid=False, errors=errors)

    version = raw.get("version", 0)
    # Reject non-integers (bool is an int subclass) before comparing:
    # a string or None would otherwise raise TypeError on "version > 1".
    if isinstance(version, bool) or not isinstance(version, int):
        errors.append(f"Invalid backup version: {version!r}")
        return ValidateResult(valid=False, errors=errors)
    if version > 1:
        errors.append(f"Unsupported backup version: {version} (max supported: 1)")
        return ValidateResult(valid=False, version=version, errors=errors)

    secrets_mode = raw.get("secrets_mode", "exclude")
    if secrets_mode in ("exclude", "masked"):
        warnings.append(
            f"Backup was exported with secrets_mode={secrets_mode}. "
            "Imported entities will have empty/placeholder secrets that need manual update."
        )

    try:
        backup = BackupFile.model_validate(raw)
    except Exception as e:
        errors.append(f"Schema validation failed: {e}")
        return ValidateResult(valid=False, version=version, errors=errors)

    counts: dict[str, int] = {}
    d = backup.data
    for cat in ("providers", "telegram_bots", "matrix_bots", "email_bots",
                "targets", "tracking_configs", "template_configs",
                "command_configs", "command_template_configs",
                "notification_trackers", "command_trackers", "actions",
                "app_settings"):
        items = getattr(d, cat, [])
        if items:
            counts[cat] = len(items)

    return ValidateResult(
        valid=True, version=version,
        entity_counts=counts, warnings=warnings, errors=errors,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Import
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _resolve_listener_id(
    session: AsyncSession,
    id_map: dict[str, dict[int, int]],
    user_id: int,
    listener_type: str,
    listener_id: int | None,
) -> int | None:
    """Translate an exported command-tracker listener ID for this database.

    Telegram-bot listeners are remapped through ``id_map``. If that bot was
    not imported (e.g. skipped as a name conflict), the exported ID is kept
    only when it names a ``TelegramBot`` owned by *user_id*. Otherwise
    ``None`` is returned so the caller drops the link rather than pointing it
    at an unrelated, possibly foreign, bot. Other listener types pass through
    unchanged.
    """
    if listener_type != "telegram_bot":
        # NOTE(review): non-Telegram listener IDs are not remapped; confirm
        # whether other listener types reference imported entities.
        return listener_id
    mapped = _map_id(id_map, "telegram_bots", listener_id)
    if mapped is not None:
        return mapped
    if listener_id is None:
        return None
    bot = await session.get(TelegramBot, listener_id)
    if bot is not None and bot.user_id == user_id:
        return listener_id
    return None


async def import_backup(
    session: AsyncSession,
    user_id: int,
    backup: BackupFile,
    conflict_mode: ConflictMode = ConflictMode.SKIP,
) -> ImportResult:
    """Import a backup file into the database for *user_id* in one transaction.

    Entities are created in dependency order: settings, bots, providers,
    configs and targets first, then the trackers and actions that reference
    them. Every exported ID is remapped to the new row's ID through
    ``id_map``, and name conflicts are resolved per *conflict_mode*.

    Untrusted free-form data (provider/target/receiver/action/rule configs
    and tracking-config fields) is checked with ``_sanitize_config``. For
    top-level entities this happens *before* conflict resolution, so a
    rejected entity can never cause an existing one to be deleted in
    OVERWRITE mode. Tracking-config fields may only set declared,
    non-protected model columns.

    If the backup was exported without real secrets, the stored
    ``telegram_webhook_secret`` setting is never replaced by its placeholder.

    On any exception the transaction is rolled back and the created /
    skipped / overwritten counters are reset to 0, since nothing was
    persisted. The error is appended to ``result.errors``; this function
    does not raise.

    NOTE(review): ``AppSetting`` rows are global, not user-scoped. Confirm
    callers restrict imports containing app settings to administrators.
    """
    result = ImportResult()
    # Maps: category -> {old_id: new_id}
    id_map: dict[str, dict[int, int]] = {}
    d = backup.data

    try:
        # Secrets blanked/masked at export time must never replace real ones.
        secrets_stripped = backup.secrets_mode in (SecretsMode.EXCLUDE, SecretsMode.MASKED)

        # 1. App Settings (simple upsert)
        for s in d.app_settings:
            if s.key == "telegram_webhook_secret" and secrets_stripped:
                result.warnings.append(
                    "Skipped app setting 'telegram_webhook_secret': "
                    "backup was exported without its real value"
                )
                result.skipped += 1
                continue
            existing = await session.get(AppSetting, s.key)
            if existing is None:
                session.add(AppSetting(key=s.key, value=s.value))
                result.created += 1
            elif conflict_mode == ConflictMode.OVERWRITE:
                existing.value = s.value
                session.add(existing)
                result.overwritten += 1
            else:  # SKIP, or RENAME (not applicable to settings)
                result.skipped += 1

        # 2. Telegram Bots
        id_map["telegram_bots"] = {}
        for b in d.telegram_bots:
            name = await _resolve_name(
                session, TelegramBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = TelegramBot(
                user_id=user_id, name=name, token=b.token, icon=b.icon,
                bot_username=b.bot_username, update_mode=b.update_mode,
            )
            session.add(new_bot)
            await session.flush()
            id_map["telegram_bots"][b.id] = new_bot.id

        # 3. Matrix Bots
        id_map["matrix_bots"] = {}
        for b in d.matrix_bots:
            name = await _resolve_name(
                session, MatrixBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = MatrixBot(
                user_id=user_id, name=name, icon=b.icon,
                homeserver_url=b.homeserver_url, access_token=b.access_token,
                display_name=b.display_name,
            )
            session.add(new_bot)
            await session.flush()
            id_map["matrix_bots"][b.id] = new_bot.id

        # 4. Email Bots
        id_map["email_bots"] = {}
        for b in d.email_bots:
            name = await _resolve_name(
                session, EmailBot, b.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_bot = EmailBot(
                user_id=user_id, name=name, icon=b.icon, email=b.email,
                smtp_host=b.smtp_host, smtp_port=b.smtp_port,
                smtp_username=b.smtp_username, smtp_password=b.smtp_password,
                smtp_use_tls=b.smtp_use_tls,
            )
            session.add(new_bot)
            await session.flush()
            id_map["email_bots"][b.id] = new_bot.id

        # 5. Providers (sanitize first: OVERWRITE deletes inside _resolve_name)
        id_map["providers"] = {}
        for p in d.providers:
            try:
                safe_cfg = _sanitize_config(p.config or {})
            except ValueError as exc:
                result.warnings.append(f"Skipped provider '{p.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, ServiceProvider, p.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_p = ServiceProvider(
                user_id=user_id, type=p.type, name=name,
                icon=p.icon, config=safe_cfg,
            )
            session.add(new_p)
            await session.flush()
            id_map["providers"][p.id] = new_p.id

        # 6. Tracking Configs
        id_map["tracking_configs"] = {}
        # Only declared settings columns may be set from the backup; identity and
        # ownership fields (id, user_id, ...) and arbitrary attributes are refused.
        allowed_tracking_fields = set(TrackingConfig.model_fields) - _TRACKING_SKIP
        for tc in d.tracking_configs:
            try:
                safe_fields = _sanitize_config(tc.fields or {})
            except ValueError as exc:
                result.warnings.append(f"Skipped tracking config '{tc.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, TrackingConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tc = TrackingConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, icon=tc.icon,
            )
            ignored: list[str] = []
            for field_name, value in safe_fields.items():
                if field_name in allowed_tracking_fields:
                    setattr(new_tc, field_name, value)
                else:
                    ignored.append(field_name)
            if ignored:
                result.warnings.append(
                    f"Tracking config '{tc.name}': ignored unknown or protected fields {sorted(ignored)}"
                )
            session.add(new_tc)
            await session.flush()
            id_map["tracking_configs"][tc.id] = new_tc.id

        # 7. Template Configs + Slots
        id_map["template_configs"] = {}
        for tc in d.template_configs:
            name = await _resolve_name_template(
                session, TemplateConfig, tc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tc = TemplateConfig(
                user_id=user_id, provider_type=tc.provider_type,
                name=name, description=tc.description, icon=tc.icon,
                locale=tc.locale, date_format=tc.date_format,
                date_only_format=tc.date_only_format,
            )
            session.add(new_tc)
            await session.flush()
            id_map["template_configs"][tc.id] = new_tc.id
            for s in tc.slots:
                session.add(TemplateSlot(
                    config_id=new_tc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(tc.slots)

        # 8. Command Template Configs + Slots
        id_map["command_template_configs"] = {}
        for ctc in d.command_template_configs:
            name = await _resolve_name_template(
                session, CommandTemplateConfig, ctc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_ctc = CommandTemplateConfig(
                user_id=user_id, provider_type=ctc.provider_type,
                name=name, description=ctc.description, icon=ctc.icon,
                locale=ctc.locale,
            )
            session.add(new_ctc)
            await session.flush()
            id_map["command_template_configs"][ctc.id] = new_ctc.id
            for s in ctc.slots:
                session.add(CommandTemplateSlot(
                    config_id=new_ctc.id, slot_name=s.slot_name,
                    locale=s.locale, template=s.template,
                ))
            result.created += len(ctc.slots)

        # 9. Command Configs
        id_map["command_configs"] = {}
        for cc in d.command_configs:
            name = await _resolve_name(
                session, CommandConfig, cc.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            ctc_id = _map_id(id_map, "command_template_configs", cc.command_template_config_id)
            new_cc = CommandConfig(
                user_id=user_id, provider_type=cc.provider_type,
                name=name, icon=cc.icon,
                enabled_commands=cc.enabled_commands,
                response_mode=cc.response_mode,
                default_count=cc.default_count,
                rate_limits=cc.rate_limits,
                command_template_config_id=ctc_id,
            )
            session.add(new_cc)
            await session.flush()
            id_map["command_configs"][cc.id] = new_cc.id

        # 10. Targets + Receivers (sanitize target config before conflict resolution)
        id_map["targets"] = {}
        for tgt in d.targets:
            try:
                safe_tgt_cfg = _sanitize_config(tgt.config or {})
            except ValueError as exc:
                result.warnings.append(f"Skipped target '{tgt.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, NotificationTarget, tgt.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_tgt = NotificationTarget(
                user_id=user_id, type=tgt.type, name=name,
                icon=tgt.icon, config=safe_tgt_cfg,
                chat_action=tgt.chat_action,
            )
            session.add(new_tgt)
            await session.flush()
            id_map["targets"][tgt.id] = new_tgt.id
            for r in tgt.receivers:
                try:
                    safe_r_cfg = _sanitize_config(r.config or {})
                except ValueError as exc:
                    result.warnings.append(f"Skipped receiver in '{tgt.name}': {exc}")
                    result.skipped += 1
                    continue
                session.add(TargetReceiver(
                    target_id=new_tgt.id, name=r.name, config=safe_r_cfg,
                    receiver_key=r.receiver_key, locale=r.locale,
                    enabled=r.enabled,
                ))
                result.created += 1  # count only receivers actually added

        # 11. Notification Trackers + Tracker-Targets
        for nt in d.notification_trackers:
            provider_id = _map_id(id_map, "providers", nt.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped tracker '{nt.name}': provider {nt.provider_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, NotificationTracker, nt.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_nt = NotificationTracker(
                user_id=user_id, provider_id=provider_id,
                name=name, icon=nt.icon, collection_ids=nt.collection_ids,
                filters=nt.filters, scan_interval=nt.scan_interval,
                batch_duration=nt.batch_duration,
                default_tracking_config_id=_map_id(id_map, "tracking_configs", nt.default_tracking_config_id),
                default_template_config_id=_map_id(id_map, "template_configs", nt.default_template_config_id),
                enabled=nt.enabled,
            )
            session.add(new_nt)
            await session.flush()
            for tt in nt.targets:
                target_id = _map_id(id_map, "targets", tt.target_id)
                if target_id is None:
                    result.warnings.append(
                        f"Skipped tracker-target link in '{nt.name}': target {tt.target_id} not found"
                    )
                    continue
                session.add(NotificationTrackerTarget(
                    tracker_id=new_nt.id,
                    target_id=target_id,
                    tracking_config_id=_map_id(id_map, "tracking_configs", tt.tracking_config_id),
                    template_config_id=_map_id(id_map, "template_configs", tt.template_config_id),
                    enabled=tt.enabled,
                    quiet_hours_start=tt.quiet_hours_start,
                    quiet_hours_end=tt.quiet_hours_end,
                ))
                result.created += 1

        # 12. Command Trackers + Listeners
        for ct in d.command_trackers:
            provider_id = _map_id(id_map, "providers", ct.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': provider {ct.provider_id} not found"
                )
                result.skipped += 1
                continue
            cc_id = _map_id(id_map, "command_configs", ct.command_config_id)
            if cc_id is None:
                result.warnings.append(
                    f"Skipped command tracker '{ct.name}': command config {ct.command_config_id} not found"
                )
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, CommandTracker, ct.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_ct = CommandTracker(
                user_id=user_id, provider_id=provider_id,
                command_config_id=cc_id, name=name, icon=ct.icon,
                enabled=ct.enabled,
            )
            session.add(new_ct)
            await session.flush()
            for lis in ct.listeners:
                mapped_listener_id = await _resolve_listener_id(
                    session, id_map, user_id, lis.listener_type, lis.listener_id,
                )
                if mapped_listener_id is None:
                    result.warnings.append(
                        f"Skipped listener in '{ct.name}': {lis.listener_type} "
                        f"{lis.listener_id} not found for this user"
                    )
                    continue
                session.add(CommandTrackerListener(
                    command_tracker_id=new_ct.id,
                    listener_type=lis.listener_type,
                    listener_id=mapped_listener_id,
                ))
                result.created += 1

        # 13. Actions + Rules
        for a in d.actions:
            provider_id = _map_id(id_map, "providers", a.provider_id)
            if provider_id is None:
                result.warnings.append(
                    f"Skipped action '{a.name}': provider {a.provider_id} not found"
                )
                result.skipped += 1
                continue
            try:
                safe_action_cfg = _sanitize_config(a.config)
            except ValueError as exc:
                result.warnings.append(f"Skipped action '{a.name}': {exc}")
                result.skipped += 1
                continue
            name = await _resolve_name(
                session, Action, a.name, user_id, conflict_mode, result,
            )
            if name is None:
                continue
            new_a = Action(
                user_id=user_id, provider_id=provider_id, name=name,
                icon=a.icon, action_type=a.action_type, config=safe_action_cfg,
                schedule_type=a.schedule_type,
                schedule_interval=a.schedule_interval,
                schedule_cron=a.schedule_cron, enabled=False,  # always import disabled
            )
            session.add(new_a)
            await session.flush()
            for r in a.rules:
                try:
                    safe_rule_cfg = _sanitize_config(r.rule_config)
                except ValueError as exc:
                    result.warnings.append(
                        f"Skipped rule '{r.name}' in action '{a.name}': {exc}"
                    )
                    result.skipped += 1
                    continue
                session.add(ActionRule(
                    action_id=new_a.id, name=r.name,
                    rule_config=safe_rule_cfg, enabled=r.enabled,
                    order=r.order,
                ))
                result.created += 1

        await session.commit()
    except Exception as e:
        await session.rollback()
        _LOGGER.exception("Backup import failed: %s", e)
        # Nothing was persisted, so don't report counts for rolled-back work.
        result.created = 0
        result.skipped = 0
        result.overwritten = 0
        result.errors.append(f"Import failed: {e}")

    return result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _map_id(
|
|
id_map: dict[str, dict[int, int]],
|
|
category: str,
|
|
old_id: int | None,
|
|
) -> int | None:
|
|
"""Resolve an old ID to a new ID via the id_map. Returns None if not found."""
|
|
if old_id is None:
|
|
return None
|
|
return id_map.get(category, {}).get(old_id)
|
|
|
|
|
|
async def _resolve_name(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Decide what name an imported entity gets, handling name conflicts.

    Looks for an existing *model* row owned by *user_id* with the same name:

    * no conflict: returns *name* and counts a creation;
    * ``SKIP``: returns ``None`` (the caller must not create) and counts a skip;
    * ``RENAME``: returns the first free name among ``"<name> (imported)"``,
      ``"<name> (imported 2)"``, ``"<name> (imported 3)"``, … and counts a
      creation, so repeated imports never produce duplicate names;
    * otherwise (``OVERWRITE``): deletes the existing row, flushes, returns
      *name* and counts an overwrite.
    """
    found = await _find_by_name(session, model, name, user_id)
    if found is None:
        result.created += 1
        return name

    if conflict_mode == ConflictMode.SKIP:
        result.skipped += 1
        return None
    if conflict_mode == ConflictMode.RENAME:
        candidate = f"{name} (imported)"
        suffix = 2
        # A previous import may already have used "(imported)"; keep counting.
        while await _find_by_name(session, model, candidate, user_id) is not None:
            candidate = f"{name} (imported {suffix})"
            suffix += 1
        result.created += 1
        return candidate

    # OVERWRITE — delete existing, create new
    await session.delete(found)
    await session.flush()
    result.overwritten += 1
    return name


async def _find_by_name(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
) -> Any:
    """Return the first *model* row owned by *user_id* named *name*, or ``None``."""
    rows = await session.exec(
        select(model).where(
            model.name == name,
            model.user_id == user_id,
        )
    )
    return rows.first()
|
|
|
|
|
|
async def _resolve_name_template(
    session: AsyncSession,
    model: type,
    name: str,
    user_id: int,
    conflict_mode: ConflictMode,
    result: ImportResult,
) -> str | None:
    """Resolve name conflicts for template-config models.

    Template tables may also hold system-owned rows (presumably
    ``user_id == 0``; confirm against the schema). Conflicts are looked up by
    exact ``user_id``, so those rows never clash with a user's import. The
    rules are therefore identical to :func:`_resolve_name`; this function
    delegates to it so the two implementations cannot drift apart.
    """
    return await _resolve_name(session, model, name, user_id, conflict_mode, result)
|