"""Database seed functions — create/update system-owned defaults on startup.""" import json import logging from datetime import datetime, timezone from sqlalchemy import text from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from .engine import get_engine from .models import ( CommandConfig, CommandTemplateConfig, CommandTemplateSlot, TemplateConfig, TemplateSlot, TrackingConfig, ) _LOGGER = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- async def _seed_provider_template( session: AsyncSession, provider_type: str, label: str, ) -> None: """Seed templates for a single provider type across all locales.""" from notify_bridge_core.templates.defaults import load_default_templates result = await session.exec( select(TemplateConfig).where( TemplateConfig.user_id == 0, TemplateConfig.provider_type == provider_type, ) ) configs = result.all() existing_locales = { (c.locale if c.locale else ("ru" if "(RU)" in c.name else "en")): c for c in configs } for locale in ("en", "ru"): slots = load_default_templates(locale, provider_type=provider_type) if not slots: continue if locale not in existing_locales: now = datetime.now(timezone.utc).isoformat() name = f"Default {label} ({locale.upper()})" desc = f"Default {label} templates ({locale.upper()})" # Get column names to build INSERT with defaults for legacy cols col_info = (await session.execute( text("PRAGMA table_info(template_config)") )).fetchall() col_names = [c[1] for c in col_info if c[1] != "id"] values: dict[str, object] = {} for col in col_names: if col == "user_id": values[col] = 0 elif col == "provider_type": values[col] = provider_type elif col == "name": values[col] = name elif col == "description": values[col] = desc elif col == "created_at": values[col] = now elif col == "date_format": values[col] = "%d.%m.%Y, %H:%M UTC" elif col == 
"date_only_format": values[col] = "%d.%m.%Y" elif col == "locale": values[col] = locale else: values[col] = "" # empty string for legacy columns cols_str = ", ".join(values.keys()) placeholders = ", ".join(f":{k}" for k in values.keys()) await session.execute( text(f"INSERT INTO template_config ({cols_str}) VALUES ({placeholders})"), values, ) config_id = (await session.execute( text("SELECT last_insert_rowid()") )).scalar() for slot_name, template_text in slots.items(): session.add(TemplateSlot( config_id=config_id, slot_name=slot_name, template=template_text, )) else: config = existing_locales[locale] for slot_name, template_text in slots.items(): slot_result = await session.exec( select(TemplateSlot).where( TemplateSlot.config_id == config.id, TemplateSlot.slot_name == slot_name, ) ) existing = slot_result.first() if existing: existing.template = template_text session.add(existing) else: session.add(TemplateSlot( config_id=config.id, slot_name=slot_name, template=template_text, )) async def _seed_provider_command_template( session: AsyncSession, provider_type: str, name: str, description: str, ) -> None: """Seed command templates for a single provider type across all locales.""" from notify_bridge_core.templates.command_defaults import load_default_command_templates result = await session.exec( select(CommandTemplateConfig).where( CommandTemplateConfig.user_id == 0, CommandTemplateConfig.provider_type == provider_type, ) ) configs = result.all() if not configs: config = CommandTemplateConfig( user_id=0, provider_type=provider_type, name=name, description=description, ) session.add(config) await session.flush() else: config = configs[0] for locale in ("en", "ru"): slots = load_default_command_templates(locale, provider_type=provider_type) if not slots: continue for slot_name, template_text in slots.items(): slot_result = await session.exec( select(CommandTemplateSlot).where( CommandTemplateSlot.config_id == config.id, CommandTemplateSlot.slot_name == slot_name, 
CommandTemplateSlot.locale == locale, ) ) existing = slot_result.first() if existing: existing.template = template_text session.add(existing) else: session.add(CommandTemplateSlot( config_id=config.id, slot_name=slot_name, locale=locale, template=template_text, )) # --------------------------------------------------------------------------- # Top-level seed functions # --------------------------------------------------------------------------- async def _seed_default_templates() -> None: """Seed or update default (system-owned) templates on startup. Uses TemplateSlot child rows for template content. """ engine = get_engine() async with AsyncSession(engine) as session: await _seed_provider_template(session, "immich", "Immich") await _seed_provider_template(session, "gitea", "Gitea") await _seed_provider_template(session, "planka", "Planka") await _seed_provider_template(session, "scheduler", "Scheduler") await session.commit() async def _seed_default_command_templates() -> None: """Seed or update default command response templates on startup. Creates a single config per provider with locale-aware slots (each slot has an EN and RU version stored as separate rows). 
""" engine = get_engine() async with AsyncSession(engine) as session: await _seed_provider_command_template( session, "immich", "Default Commands", "Default Immich command templates", ) await _seed_provider_command_template( session, "gitea", "Default Gitea Commands", "Default Gitea command templates", ) await _seed_provider_command_template( session, "planka", "Default Planka Commands", "Default Planka command templates", ) await session.commit() async def _seed_default_tracking_configs() -> None: """Seed system-owned default tracking configs for each provider type.""" engine = get_engine() async with AsyncSession(engine) as session: result = await session.exec( select(TrackingConfig).where(TrackingConfig.user_id == 0) ) existing = {c.provider_type: c for c in result.all()} defaults = [ { "provider_type": "gitea", "name": "Default Gitea", "track_push": True, "track_issue_opened": True, "track_issue_closed": True, "track_issue_commented": False, "track_pr_opened": True, "track_pr_closed": True, "track_pr_merged": True, "track_pr_commented": False, "track_release_published": True, }, { "provider_type": "planka", "name": "Default Planka", "track_card_created": True, "track_card_updated": False, "track_card_moved": True, "track_card_deleted": False, "track_card_commented": True, "track_comment_updated": False, "track_board_created": True, "track_board_updated": False, "track_board_deleted": True, "track_list_created": False, "track_list_updated": False, "track_list_deleted": False, "track_attachment_created": True, "track_card_label_added": False, "track_task_completed": True, }, { "provider_type": "scheduler", "name": "Default Scheduler", "track_scheduled_message": True, }, ] for cfg in defaults: ptype = cfg["provider_type"] if ptype in existing: continue session.add(TrackingConfig(user_id=0, **cfg)) await session.commit() async def _seed_default_command_configs() -> None: """Seed system-owned default command configs for each provider type.""" engine = get_engine() 
async with AsyncSession(engine) as session: result = await session.exec( select(CommandConfig).where(CommandConfig.user_id == 0) ) existing = {c.provider_type: c for c in result.all()} # Find system command template configs to link tmpl_result = await session.exec( select(CommandTemplateConfig).where(CommandTemplateConfig.user_id == 0) ) tmpl_by_type = {t.provider_type: t.id for t in tmpl_result.all()} defaults = [ { "provider_type": "immich", "name": "Default Immich", "enabled_commands": [ "help", "status", "albums", "events", "latest", "random", "favorites", "summary", "memory", ], "response_mode": "media", "default_count": 5, "rate_limits": {"search": 30, "default": 10}, }, { "provider_type": "gitea", "name": "Default Gitea", "enabled_commands": [ "help", "status", "repos", "issues", "prs", "commits", ], "response_mode": "text", "default_count": 10, "rate_limits": {"api": 15, "default": 10}, }, { "provider_type": "planka", "name": "Default Planka", "enabled_commands": [ "help", "status", "boards", "cards", "lists", ], "response_mode": "text", "default_count": 10, "rate_limits": {"api": 15, "default": 10}, }, ] for cfg in defaults: ptype = cfg["provider_type"] if ptype in existing: continue cmd_tmpl_id = tmpl_by_type.get(ptype) await session.execute( text( "INSERT INTO command_config " "(user_id, provider_type, name, icon, enabled_commands, locale, " "response_mode, default_count, rate_limits, command_template_config_id, created_at) " "VALUES (:uid, :pt, :name, :icon, :cmds, :locale, :rm, :dc, :rl, :ctid, :ca)" ), { "uid": 0, "pt": ptype, "name": cfg["name"], "icon": "", "cmds": json.dumps(cfg["enabled_commands"]), "locale": "en", "rm": cfg["response_mode"], "dc": cfg["default_count"], "rl": json.dumps(cfg["rate_limits"]), "ctid": cmd_tmpl_id, "ca": datetime.now(timezone.utc).isoformat(), }, ) await session.commit() # --------------------------------------------------------------------------- # Public entry point # 
# ---------------------------------------------------------------------------


async def seed_all() -> None:
    """Run every startup seed step, strictly one after another.

    The order is significant. The command-config step links each config to
    the system command template config created by the command-template step
    before it.
    """
    steps = (
        _seed_default_templates,
        _seed_default_command_templates,
        _seed_default_tracking_configs,
        _seed_default_command_configs,
    )
    for step in steps:
        await step()