feat: provider-strict configs, slot-based templates, broadcast targets, email bots, command templates

Major architectural improvements:
- Provider-type enforcement: configs validated against provider type at assignment
- TemplateConfig migrated to slot-based pattern (TemplateSlot child table)
- Broadcast targets: TargetReceiver child table for multi-receiver dispatch
- EmailBot: first-class email sender entity with SMTP config, test connection
- CommandTemplateConfig: generic slot-based command response templates
- Provider capability registry: dynamic slot/event/command definitions per provider
- CommandTracker play/pause button matches NotificationTracker style
This commit is contained in:
2026-03-21 16:33:24 +03:00
parent 371ea70756
commit 846d480d38
27 changed files with 2355 additions and 205 deletions
@@ -0,0 +1,230 @@
"""Command template configuration CRUD API routes.
Template content is stored in CommandTemplateSlot child rows (one per slot_name).
Slot names correspond to command names (e.g. 'status', 'help', 'albums').
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import CommandTemplateConfig, CommandTemplateSlot, User
_LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/command-template-configs", tags=["command-template-configs"])
class CommandTemplateConfigCreate(BaseModel):
provider_type: str
name: str
description: str | None = None
icon: str | None = None
slots: dict[str, str] = {} # slot_name -> template text
class CommandTemplateConfigUpdate(BaseModel):
name: str | None = None
description: str | None = None
icon: str | None = None
slots: dict[str, str] | None = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, str]:
result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config_id)
)
return {s.slot_name: s.template for s in result.all()}
async def _save_slots(session: AsyncSession, config_id: int, slots: dict[str, str]) -> None:
for slot_name, template_text in slots.items():
result = await session.exec(
select(CommandTemplateSlot).where(
CommandTemplateSlot.config_id == config_id,
CommandTemplateSlot.slot_name == slot_name,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(CommandTemplateSlot(
config_id=config_id,
slot_name=slot_name,
template=template_text,
))
async def _response(session: AsyncSession, c: CommandTemplateConfig) -> dict[str, Any]:
slots = await _load_slots(session, c.id)
return {
"id": c.id,
"user_id": c.user_id,
"provider_type": c.provider_type,
"name": c.name,
"description": c.description,
"icon": c.icon,
"slots": slots,
"created_at": c.created_at.isoformat(),
}
async def _get(session: AsyncSession, config_id: int, user_id: int) -> CommandTemplateConfig:
config = await session.get(CommandTemplateConfig, config_id)
if not config or (config.user_id != user_id and config.user_id != 0):
raise HTTPException(status_code=404, detail="Command template config not found")
return config
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@router.get("")
async def list_configs(
provider_type: str | None = None,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
from sqlalchemy import or_
query = select(CommandTemplateConfig).where(
or_(CommandTemplateConfig.user_id == user.id, CommandTemplateConfig.user_id == 0)
)
if provider_type:
query = query.where(CommandTemplateConfig.provider_type == provider_type)
result = await session.exec(query)
return [await _response(session, c) for c in result.all()]
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_config(
body: CommandTemplateConfigCreate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
config = CommandTemplateConfig(
user_id=user.id,
provider_type=body.provider_type,
name=body.name,
description=body.description or "",
icon=body.icon or "",
)
session.add(config)
await session.flush()
if body.slots:
await _save_slots(session, config.id, body.slots)
await session.commit()
await session.refresh(config)
return await _response(session, config)
@router.get("/{config_id}")
async def get_config(
config_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
return await _response(session, config)
@router.put("/{config_id}")
async def update_config(
config_id: int,
body: CommandTemplateConfigUpdate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
for field, value in body.model_dump(exclude_unset=True, exclude={"slots"}).items():
if value is not None:
setattr(config, field, value)
session.add(config)
if body.slots is not None:
await _save_slots(session, config.id, body.slots)
await session.commit()
await session.refresh(config)
return await _response(session, config)
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_config(
config_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
slot_result = await session.exec(
select(CommandTemplateSlot).where(CommandTemplateSlot.config_id == config.id)
)
for slot in slot_result.all():
await session.delete(slot)
await session.delete(config)
await session.commit()
class PreviewRequest(BaseModel):
template: str
@router.post("/preview-raw")
async def preview_raw(
body: PreviewRequest,
user: User = Depends(get_current_user),
):
"""Render arbitrary Jinja2 template text with sample command context."""
sample_ctx = {
"trackers_active": 2,
"trackers_total": 3,
"total_albums": 5,
"last_event": "2026-03-19 14:30",
"albums": [
{"name": "Family Photos", "asset_count": 142, "url": "https://example.com/albums/1"},
{"name": "Vacation 2025", "asset_count": 87, "url": "https://example.com/albums/2"},
],
"events": [
{"type": "assets_added", "album": "Family Photos", "count": 3, "date": "2026-03-19 14:30"},
{"type": "assets_removed", "album": "Vacation 2025", "count": 1, "date": "2026-03-19 12:00"},
],
"people": ["Alice", "Bob", "Charlie"],
"assets": [
{"filename": "IMG_001.jpg", "type": "IMAGE", "created_at": "2026-03-19T14:30:00"},
{"filename": "VID_002.mp4", "type": "VIDEO", "created_at": "2026-03-19T15:00:00"},
],
"search_query": "sunset",
"search_results_count": 5,
"command": "status",
"bot_name": "NotifyBridgeBot",
"locale": "en",
}
try:
env = SandboxedEnvironment(autoescape=False)
env.from_string(body.template)
except TemplateSyntaxError as e:
return {"rendered": None, "error": e.message, "error_line": e.lineno}
try:
strict_env = SandboxedEnvironment(autoescape=False, undefined=StrictUndefined)
tmpl = strict_env.from_string(body.template)
rendered = tmpl.render(**sample_ctx)
return {"rendered": rendered}
except UndefinedError as e:
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
@@ -0,0 +1,148 @@
"""Email bot management API routes."""
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import EmailBot, User
_LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/email-bots", tags=["email-bots"])
class EmailBotCreate(BaseModel):
name: str
icon: str = ""
email: str
smtp_host: str
smtp_port: int = 587
smtp_username: str = ""
smtp_password: str = ""
smtp_use_tls: bool = True
class EmailBotUpdate(BaseModel):
name: str | None = None
icon: str | None = None
email: str | None = None
smtp_host: str | None = None
smtp_port: int | None = None
smtp_username: str | None = None
smtp_password: str | None = None
smtp_use_tls: bool | None = None
@router.get("")
async def list_email_bots(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
result = await session.exec(
select(EmailBot).where(EmailBot.user_id == user.id)
)
return [_response(b) for b in result.all()]
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_email_bot(
body: EmailBotCreate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
bot = EmailBot(user_id=user.id, **body.model_dump())
session.add(bot)
await session.commit()
await session.refresh(bot)
return _response(bot)
@router.get("/{bot_id}")
async def get_email_bot(
bot_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
bot = await _get_user_bot(session, bot_id, user.id)
return _response(bot)
@router.put("/{bot_id}")
async def update_email_bot(
bot_id: int,
body: EmailBotUpdate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
bot = await _get_user_bot(session, bot_id, user.id)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(bot, field, value)
session.add(bot)
await session.commit()
await session.refresh(bot)
return _response(bot)
@router.delete("/{bot_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_email_bot(
bot_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
bot = await _get_user_bot(session, bot_id, user.id)
await session.delete(bot)
await session.commit()
@router.post("/{bot_id}/test")
async def test_email_bot(
bot_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
"""Send a test email to the bot's own address to verify SMTP connection."""
bot = await _get_user_bot(session, bot_id, user.id)
from notify_bridge_core.notifications.email.client import EmailClient, SmtpConfig
client = EmailClient(SmtpConfig(
host=bot.smtp_host,
port=bot.smtp_port,
username=bot.smtp_username,
password=bot.smtp_password,
from_address=bot.email,
from_name=bot.name,
use_tls=bot.smtp_use_tls,
))
result = await client.send(
to_email=bot.email,
subject="Notify Bridge — Test Connection",
body_text="This is a test email from Notify Bridge. Your SMTP settings are working correctly.",
)
return result
def _response(bot: EmailBot) -> dict:
return {
"id": bot.id,
"name": bot.name,
"icon": bot.icon,
"email": bot.email,
"smtp_host": bot.smtp_host,
"smtp_port": bot.smtp_port,
"smtp_username": bot.smtp_username,
"smtp_password": "***" if bot.smtp_password else "",
"smtp_use_tls": bot.smtp_use_tls,
"created_at": bot.created_at.isoformat(),
}
async def _get_user_bot(session: AsyncSession, bot_id: int, user_id: int) -> EmailBot:
bot = await session.get(EmailBot, bot_id)
if not bot or bot.user_id != user_id:
raise HTTPException(status_code=404, detail="Email bot not found")
return bot
@@ -16,6 +16,7 @@ from ..database.models import (
NotificationTrackerTarget,
ServiceProvider,
TemplateConfig,
TemplateSlot,
TrackingConfig,
User,
)
@@ -65,7 +66,7 @@ async def create_notification_tracker_target(
session: AsyncSession = Depends(get_session),
):
"""Link a target to a notification tracker with per-link configuration."""
await _get_user_tracker(session, tracker_id, user.id)
tracker = await _get_user_tracker(session, tracker_id, user.id)
# Validate target exists and belongs to user
target = await session.get(NotificationTarget, body.target_id)
@@ -85,15 +86,30 @@ async def create_notification_tracker_target(
detail="Target is already linked to this tracker",
)
# Validate config ownership
# Resolve tracker's provider type for config validation
provider = await session.get(ServiceProvider, tracker.provider_id)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# Validate config ownership + provider type match
if body.tracking_config_id:
tc = await session.get(TrackingConfig, body.tracking_config_id)
if not tc or tc.user_id != user.id:
raise HTTPException(status_code=404, detail="Tracking config not found")
if tc.provider_type != provider.type:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Tracking config provider type '{tc.provider_type}' does not match tracker provider '{provider.type}'",
)
if body.template_config_id:
tpc = await session.get(TemplateConfig, body.template_config_id)
if not tpc or (tpc.user_id != user.id and tpc.user_id != 0):
raise HTTPException(status_code=404, detail="Template config not found")
if tpc.provider_type != provider.type:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Template config provider type '{tpc.provider_type}' does not match tracker provider '{provider.type}'",
)
tt = NotificationTrackerTarget(tracker_id=tracker_id, **body.model_dump())
session.add(tt)
@@ -111,21 +127,34 @@ async def update_notification_tracker_target(
session: AsyncSession = Depends(get_session),
):
"""Update a notification tracker-target link's configuration."""
await _get_user_tracker(session, tracker_id, user.id)
tracker = await _get_user_tracker(session, tracker_id, user.id)
tt = await session.get(NotificationTrackerTarget, tracker_target_id)
if not tt or tt.tracker_id != tracker_id:
raise HTTPException(status_code=404, detail="Tracker-target link not found")
provider = await session.get(ServiceProvider, tracker.provider_id)
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
updates = body.model_dump(exclude_unset=True)
# Validate config ownership if being changed
# Validate config ownership + provider type match if being changed
if "tracking_config_id" in updates and updates["tracking_config_id"]:
tc = await session.get(TrackingConfig, updates["tracking_config_id"])
if not tc or tc.user_id != user.id:
raise HTTPException(status_code=404, detail="Tracking config not found")
if tc.provider_type != provider.type:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Tracking config provider type '{tc.provider_type}' does not match tracker provider '{provider.type}'",
)
if "template_config_id" in updates and updates["template_config_id"]:
tpc = await session.get(TemplateConfig, updates["template_config_id"])
if not tpc or (tpc.user_id != user.id and tpc.user_id != 0):
raise HTTPException(status_code=404, detail="Template config not found")
if tpc.provider_type != provider.type:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Template config provider type '{tpc.provider_type}' does not match tracker provider '{provider.type}'",
)
for field, value in updates.items():
setattr(tt, field, value)
@@ -183,15 +212,24 @@ async def test_notification_tracker_target(
# For periodic/scheduled/memory — fetch real data from provider
template_config = None
template_str = ""
if tt.template_config_id:
template_config = await session.get(TemplateConfig, tt.template_config_id)
slot_map = {
"periodic": "periodic_summary_message",
"scheduled": "scheduled_assets_message",
"memory": "memory_mode_message",
}
template_str = getattr(template_config, slot_map[test_type], "") if template_config else ""
if template_config:
slot_map = {
"periodic": "periodic_summary_message",
"scheduled": "scheduled_assets_message",
"memory": "memory_mode_message",
}
slot_name = slot_map[test_type]
slot_result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == template_config.id,
TemplateSlot.slot_name == slot_name,
)
)
slot = slot_result.first()
template_str = slot.template if slot else ""
# Load provider and tracker data eagerly before aiohttp context
provider = await session.get(ServiceProvider, tracker.provider_id)
@@ -94,6 +94,40 @@ async def create_provider(
return _provider_response(provider)
@router.get("/capabilities")
async def list_provider_capabilities():
"""List capabilities for all registered provider types."""
from notify_bridge_core.providers.capabilities import get_all_capabilities
result = {}
for pt, caps in get_all_capabilities().items():
result[pt] = {
"provider_type": caps.provider_type,
"display_name": caps.display_name,
"notification_slots": caps.notification_slots,
"command_slots": caps.command_slots,
"events": caps.events,
"commands": caps.commands,
}
return result
@router.get("/capabilities/{provider_type}")
async def get_provider_capabilities(provider_type: str):
"""Get capabilities for a provider type (events, slots, commands)."""
from notify_bridge_core.providers.capabilities import get_capabilities
caps = get_capabilities(provider_type)
if not caps:
raise HTTPException(status_code=404, detail=f"Unknown provider type: {provider_type}")
return {
"provider_type": caps.provider_type,
"display_name": caps.display_name,
"notification_slots": caps.notification_slots,
"command_slots": caps.command_slots,
"events": caps.events,
"commands": caps.commands,
}
@router.get("/{provider_id}")
async def get_provider(
provider_id: int,
@@ -0,0 +1,147 @@
"""Target receiver management API routes (nested under targets)."""
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import NotificationTarget, TargetReceiver, User
_LOGGER = logging.getLogger(__name__)
router = APIRouter(prefix="/api/targets/{target_id}/receivers", tags=["target-receivers"])
class ReceiverCreate(BaseModel):
name: str = ""
config: dict[str, Any] = {}
enabled: bool = True
class ReceiverUpdate(BaseModel):
name: str | None = None
config: dict[str, Any] | None = None
enabled: bool | None = None
def _receiver_key(target_type: str, config: dict[str, Any]) -> str:
"""Derive a unique key for deduplication from receiver config."""
if target_type == "telegram":
return str(config.get("chat_id", ""))
elif target_type == "webhook":
return config.get("url", "")
elif target_type == "email":
return config.get("email", "")
return ""
@router.get("")
async def list_receivers(
target_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
target = await _get_user_target(session, target_id, user.id)
result = await session.exec(
select(TargetReceiver).where(TargetReceiver.target_id == target.id)
)
return [_response(r) for r in result.all()]
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_receiver(
target_id: int,
body: ReceiverCreate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
target = await _get_user_target(session, target_id, user.id)
key = _receiver_key(target.type, body.config)
if not key:
raise HTTPException(status_code=400, detail="Receiver config must include a delivery endpoint (chat_id, url, or email)")
# Check for duplicate
existing = await session.exec(
select(TargetReceiver).where(
TargetReceiver.target_id == target.id,
TargetReceiver.receiver_key == key,
)
)
if existing.first():
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Receiver already exists for this target")
receiver = TargetReceiver(
target_id=target.id,
name=body.name,
config=body.config,
receiver_key=key,
enabled=body.enabled,
)
session.add(receiver)
await session.commit()
await session.refresh(receiver)
return _response(receiver)
@router.put("/{receiver_id}")
async def update_receiver(
target_id: int,
receiver_id: int,
body: ReceiverUpdate,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
await _get_user_target(session, target_id, user.id)
receiver = await session.get(TargetReceiver, receiver_id)
if not receiver or receiver.target_id != target_id:
raise HTTPException(status_code=404, detail="Receiver not found")
for field, value in body.model_dump(exclude_unset=True).items():
setattr(receiver, field, value)
# Update receiver_key if config changed
if body.config is not None:
target = await session.get(NotificationTarget, target_id)
receiver.receiver_key = _receiver_key(target.type, receiver.config)
session.add(receiver)
await session.commit()
await session.refresh(receiver)
return _response(receiver)
@router.delete("/{receiver_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_receiver(
target_id: int,
receiver_id: int,
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
await _get_user_target(session, target_id, user.id)
receiver = await session.get(TargetReceiver, receiver_id)
if not receiver or receiver.target_id != target_id:
raise HTTPException(status_code=404, detail="Receiver not found")
await session.delete(receiver)
await session.commit()
def _response(r: TargetReceiver) -> dict:
return {
"id": r.id,
"target_id": r.target_id,
"name": r.name,
"config": dict(r.config),
"receiver_key": r.receiver_key,
"enabled": r.enabled,
"created_at": r.created_at.isoformat(),
}
async def _get_user_target(session: AsyncSession, target_id: int, user_id: int) -> NotificationTarget:
target = await session.get(NotificationTarget, target_id)
if not target or target.user_id != user_id:
raise HTTPException(status_code=404, detail="Target not found")
return target
@@ -10,7 +10,7 @@ from typing import Any
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import NotificationTarget, NotificationTrackerTarget, TelegramBot, TelegramChat, User
from ..database.models import NotificationTarget, NotificationTrackerTarget, TargetReceiver, TelegramBot, TelegramChat, User
from ..services.notifier import send_test_notification
_LOGGER = logging.getLogger(__name__)
@@ -61,7 +61,15 @@ async def list_targets(
if chat:
chat_names[f"{bot_id}_{chat_id}"] = chat.title or chat.username or ""
return [_target_response(t, chat_names) for t in targets]
# Load receiver counts
receiver_counts: dict[int, int] = {}
for tgt in targets:
recv_result = await session.exec(
select(TargetReceiver).where(TargetReceiver.target_id == tgt.id)
)
receiver_counts[tgt.id] = len(recv_result.all())
return [_target_response(t, chat_names, receiver_counts.get(t.id, 0)) for t in targets]
@router.post("", status_code=status.HTTP_201_CREATED)
@@ -71,10 +79,10 @@ async def create_target(
session: AsyncSession = Depends(get_session),
):
"""Create a new notification target."""
if body.type not in ("telegram", "webhook"):
if body.type not in ("telegram", "webhook", "email"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Type must be 'telegram' or 'webhook'",
detail="Type must be 'telegram', 'webhook', or 'email'",
)
target = NotificationTarget(
user_id=user.id,
@@ -124,7 +132,7 @@ async def delete_target(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
"""Delete a notification target and its tracker links."""
"""Delete a notification target, its tracker links, and receivers."""
target = await _get_user_target(session, target_id, user.id)
# Delete associated tracker-target links
result = await session.exec(
@@ -132,6 +140,12 @@ async def delete_target(
)
for tt in result.all():
await session.delete(tt)
# Delete receivers
recv_result = await session.exec(
select(TargetReceiver).where(TargetReceiver.target_id == target_id)
)
for r in recv_result.all():
await session.delete(r)
await session.delete(target)
await session.commit()
@@ -149,7 +163,7 @@ async def test_target(
return result
def _target_response(target: NotificationTarget, chat_names: dict[str, str] | None = None) -> dict:
def _target_response(target: NotificationTarget, chat_names: dict[str, str] | None = None, receiver_count: int = 0) -> dict:
resp = {
"id": target.id,
"type": target.type,
@@ -157,6 +171,7 @@ def _target_response(target: NotificationTarget, chat_names: dict[str, str] | No
"icon": target.icon,
"config": _safe_config(target),
"chat_action": target.chat_action,
"receiver_count": receiver_count,
"created_at": target.created_at.isoformat(),
}
# Attach resolved chat name for telegram targets
@@ -1,6 +1,11 @@
"""Template configuration CRUD API routes."""
"""Template configuration CRUD API routes.
Template content is stored in TemplateSlot child rows (one per slot_name).
The API exposes slots as a flat dict in create/update/response payloads.
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
@@ -12,7 +17,7 @@ from jinja2 import TemplateSyntaxError, UndefinedError, StrictUndefined
from ..auth.dependencies import get_current_user
from ..database.engine import get_session
from ..database.models import TemplateConfig, User
from ..database.models import TemplateConfig, TemplateSlot, User
from ..services.sample_context import _SAMPLE_CONTEXT
_LOGGER = logging.getLogger(__name__)
@@ -25,21 +30,83 @@ class TemplateConfigCreate(BaseModel):
name: str
description: str | None = None
icon: str | None = None
message_assets_added: str | None = None
message_assets_removed: str | None = None
message_collection_renamed: str | None = None
message_collection_deleted: str | None = None
message_sharing_changed: str | None = None
periodic_summary_message: str | None = None
scheduled_assets_message: str | None = None
memory_mode_message: str | None = None
date_format: str | None = None
date_only_format: str | None = None
slots: dict[str, str] = {} # slot_name -> template text
TemplateConfigUpdate = TemplateConfigCreate # Same shape, all optional
class TemplateConfigUpdate(BaseModel):
name: str | None = None
description: str | None = None
icon: str | None = None
date_format: str | None = None
date_only_format: str | None = None
slots: dict[str, str] | None = None # partial update: only provided slots change
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _load_slots(session: AsyncSession, config_id: int) -> dict[str, str]:
"""Load all template slots for a config as a dict."""
result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == config_id)
)
return {s.slot_name: s.template for s in result.all()}
async def _save_slots(
session: AsyncSession, config_id: int, slots: dict[str, str]
) -> None:
"""Create or update template slots for a config."""
for slot_name, template_text in slots.items():
result = await session.exec(
select(TemplateSlot).where(
TemplateSlot.config_id == config_id,
TemplateSlot.slot_name == slot_name,
)
)
existing = result.first()
if existing:
existing.template = template_text
session.add(existing)
else:
session.add(TemplateSlot(
config_id=config_id,
slot_name=slot_name,
template=template_text,
))
async def _response(session: AsyncSession, c: TemplateConfig) -> dict[str, Any]:
"""Build API response dict for a TemplateConfig, including its slots."""
slots = await _load_slots(session, c.id)
return {
"id": c.id,
"user_id": c.user_id,
"provider_type": c.provider_type,
"name": c.name,
"description": c.description,
"icon": c.icon,
"date_format": c.date_format,
"date_only_format": c.date_only_format,
"slots": slots,
"created_at": c.created_at.isoformat(),
}
async def _get(session: AsyncSession, config_id: int, user_id: int) -> TemplateConfig:
config = await session.get(TemplateConfig, config_id)
if not config or (config.user_id != user_id and config.user_id != 0):
raise HTTPException(status_code=404, detail="Template config not found")
return config
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@router.get("")
async def list_configs(
provider_type: str | None = None,
@@ -53,7 +120,7 @@ async def list_configs(
if provider_type:
query = query.where(TemplateConfig.provider_type == provider_type)
result = await session.exec(query)
return [_response(c) for c in result.all()]
return [await _response(session, c) for c in result.all()]
@router.get("/variables")
@@ -180,12 +247,22 @@ async def create_config(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
data = {k: v for k, v in body.model_dump().items() if v is not None}
config = TemplateConfig(user_id=user.id, **data)
config = TemplateConfig(
user_id=user.id,
provider_type=body.provider_type,
name=body.name,
description=body.description or "",
icon=body.icon or "",
date_format=body.date_format or "%d.%m.%Y, %H:%M UTC",
date_only_format=body.date_only_format or "%d.%m.%Y",
)
session.add(config)
await session.flush() # get config.id
if body.slots:
await _save_slots(session, config.id, body.slots)
await session.commit()
await session.refresh(config)
return _response(config)
return await _response(session, config)
@router.get("/{config_id}")
@@ -194,7 +271,8 @@ async def get_config(
user: User = Depends(get_current_user),
session: AsyncSession = Depends(get_session),
):
return _response(await _get(session, config_id, user.id))
config = await _get(session, config_id, user.id)
return await _response(session, config)
@router.put("/{config_id}")
@@ -205,13 +283,15 @@ async def update_config(
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
for field, value in body.model_dump(exclude_unset=True).items():
for field, value in body.model_dump(exclude_unset=True, exclude={"slots"}).items():
if value is not None:
setattr(config, field, value)
session.add(config)
if body.slots is not None:
await _save_slots(session, config.id, body.slots)
await session.commit()
await session.refresh(config)
return _response(config)
return await _response(session, config)
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
@@ -221,6 +301,12 @@ async def delete_config(
session: AsyncSession = Depends(get_session),
):
config = await _get(session, config_id, user.id)
# Delete child slots first
slot_result = await session.exec(
select(TemplateSlot).where(TemplateSlot.config_id == config.id)
)
for slot in slot_result.all():
await session.delete(slot)
await session.delete(config)
await session.commit()
@@ -234,9 +320,10 @@ async def preview_config(
):
"""Render a specific template slot with sample data."""
config = await _get(session, config_id, user.id)
template_body = getattr(config, slot, None)
if template_body is None:
raise HTTPException(status_code=400, detail=f"Unknown slot: {slot}")
slots = await _load_slots(session, config.id)
template_body = slots.get(slot, "")
if not template_body:
raise HTTPException(status_code=400, detail=f"Slot '{slot}' has no template")
try:
env = SandboxedEnvironment(autoescape=False)
tmpl = env.from_string(template_body)
@@ -320,17 +407,3 @@ async def preview_raw(
return {"rendered": None, "error": str(e), "error_line": None, "error_type": "undefined"}
except Exception as e:
return {"rendered": None, "error": str(e), "error_line": None}
def _response(c: TemplateConfig) -> dict:
return {k: getattr(c, k) for k in TemplateConfig.model_fields if k not in ("user_id", "created_at")} | {
"user_id": c.user_id,
"created_at": c.created_at.isoformat(),
}
async def _get(session: AsyncSession, config_id: int, user_id: int) -> TemplateConfig:
config = await session.get(TemplateConfig, config_id)
if not config or (config.user_id != user_id and config.user_id != 0):
raise HTTPException(status_code=404, detail="Template config not found")
return config