feat: entity relationship refactor — notification trackers, command system, chat actions
Rework entity schema: rename Tracker→NotificationTracker, add CommandConfig/ CommandTracker/CommandTrackerListener entities for decoupled command handling. Commands now resolve through CommandTracker→CommandConfig instead of TelegramBot.commands_config. Smart ref-counted bot polling based on active listeners. Add chat_action to telegram targets. Full frontend CRUD pages for command configs and command trackers. Idempotent SQLite migrations. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,151 @@
|
||||
"""Command config management API routes."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import CommandConfig, CommandTracker, User
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/command-configs", tags=["command-configs"])
|
||||
|
||||
|
||||
class CommandConfigCreate(BaseModel):
|
||||
provider_type: str
|
||||
name: str
|
||||
icon: str = ""
|
||||
enabled_commands: list[str] = []
|
||||
locale: str = "en"
|
||||
response_mode: str = "media"
|
||||
default_count: int = 5
|
||||
rate_limits: dict[str, Any] = {}
|
||||
|
||||
|
||||
class CommandConfigUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
icon: str | None = None
|
||||
enabled_commands: list[str] | None = None
|
||||
locale: str | None = None
|
||||
response_mode: str | None = None
|
||||
default_count: int | None = None
|
||||
rate_limits: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_command_configs(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all command configs for the current user."""
|
||||
result = await session.exec(
|
||||
select(CommandConfig).where(CommandConfig.user_id == user.id)
|
||||
)
|
||||
return [_config_response(c) for c in result.all()]
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_command_config(
|
||||
body: CommandConfigCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Create a new command config."""
|
||||
# Validate provider_type
|
||||
valid_types = ("immich",)
|
||||
if body.provider_type not in valid_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid provider_type. Must be one of: {', '.join(valid_types)}",
|
||||
)
|
||||
|
||||
config = CommandConfig(user_id=user.id, **body.model_dump())
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return _config_response(config)
|
||||
|
||||
|
||||
@router.get("/{config_id}")
|
||||
async def get_command_config(
|
||||
config_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get a single command config."""
|
||||
config = await _get_user_config(session, config_id, user.id)
|
||||
return _config_response(config)
|
||||
|
||||
|
||||
@router.put("/{config_id}")
|
||||
async def update_command_config(
|
||||
config_id: int,
|
||||
body: CommandConfigUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Update a command config."""
|
||||
config = await _get_user_config(session, config_id, user.id)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(config, field, value)
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return _config_response(config)
|
||||
|
||||
|
||||
@router.delete("/{config_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_command_config(
|
||||
config_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Delete a command config. Fails if in use by any command tracker."""
|
||||
config = await _get_user_config(session, config_id, user.id)
|
||||
|
||||
# Check if any command tracker references this config
|
||||
result = await session.exec(
|
||||
select(CommandTracker).where(CommandTracker.command_config_id == config_id)
|
||||
)
|
||||
if result.first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Cannot delete: command config is in use by a command tracker",
|
||||
)
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
def _config_response(c: CommandConfig) -> dict:
|
||||
return {
|
||||
"id": c.id,
|
||||
"user_id": c.user_id,
|
||||
"provider_type": c.provider_type,
|
||||
"name": c.name,
|
||||
"icon": c.icon,
|
||||
"enabled_commands": c.enabled_commands or [],
|
||||
"locale": c.locale,
|
||||
"response_mode": c.response_mode,
|
||||
"default_count": c.default_count,
|
||||
"rate_limits": c.rate_limits or {},
|
||||
"created_at": c.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def _get_user_config(
|
||||
session: AsyncSession, config_id: int, user_id: int
|
||||
) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, config_id)
|
||||
if not config or config.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Command config not found")
|
||||
return config
|
||||
@@ -0,0 +1,371 @@
|
||||
"""Command tracker and listener management API routes."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import func, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import (
|
||||
CommandConfig,
|
||||
CommandTracker,
|
||||
CommandTrackerListener,
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
User,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/command-trackers", tags=["command-trackers"])
|
||||
|
||||
|
||||
class CommandTrackerCreate(BaseModel):
|
||||
provider_id: int
|
||||
command_config_id: int
|
||||
name: str
|
||||
icon: str = ""
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class CommandTrackerUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
icon: str | None = None
|
||||
enabled: bool | None = None
|
||||
command_config_id: int | None = None
|
||||
|
||||
|
||||
class ListenerCreate(BaseModel):
|
||||
listener_type: str
|
||||
listener_id: int
|
||||
|
||||
|
||||
# --- Command Tracker CRUD ---
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_command_trackers(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all command trackers for the current user, with listener counts."""
|
||||
result = await session.exec(
|
||||
select(CommandTracker).where(CommandTracker.user_id == user.id)
|
||||
)
|
||||
trackers = result.all()
|
||||
return [await _tracker_response(session, t) for t in trackers]
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_command_tracker(
|
||||
body: CommandTrackerCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Create a new command tracker."""
|
||||
# Validate provider exists and user owns it
|
||||
provider = await session.get(ServiceProvider, body.provider_id)
|
||||
if not provider or provider.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# Validate command config exists and user owns it
|
||||
config = await session.get(CommandConfig, body.command_config_id)
|
||||
if not config or config.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Command config not found")
|
||||
|
||||
# Validate provider_type matches
|
||||
if config.provider_type != provider.type:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider type mismatch: provider is '{provider.type}' but command config is for '{config.provider_type}'",
|
||||
)
|
||||
|
||||
tracker = CommandTracker(user_id=user.id, **body.model_dump())
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
@router.get("/{tracker_id}")
|
||||
async def get_command_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Get a single command tracker with its listeners."""
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
return await _tracker_response(session, tracker, include_listeners=True)
|
||||
|
||||
|
||||
@router.put("/{tracker_id}")
|
||||
async def update_command_tracker(
|
||||
tracker_id: int,
|
||||
body: CommandTrackerUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Update a command tracker."""
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
|
||||
updates = body.model_dump(exclude_unset=True)
|
||||
|
||||
# If changing command_config_id, validate ownership and provider_type match
|
||||
if "command_config_id" in updates and updates["command_config_id"] is not None:
|
||||
config = await session.get(CommandConfig, updates["command_config_id"])
|
||||
if not config or config.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Command config not found")
|
||||
provider = await session.get(ServiceProvider, tracker.provider_id)
|
||||
if provider and config.provider_type != provider.type:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Provider type mismatch: provider is '{provider.type}' but command config is for '{config.provider_type}'",
|
||||
)
|
||||
|
||||
for field, value in updates.items():
|
||||
setattr(tracker, field, value)
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
@router.delete("/{tracker_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_command_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Delete a command tracker and cascade delete its listeners."""
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
|
||||
# Delete associated listeners, collecting bot IDs for polling cleanup
|
||||
result = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == tracker_id
|
||||
)
|
||||
)
|
||||
bot_ids_to_check: set[int] = set()
|
||||
for listener in result.all():
|
||||
if listener.listener_type == "telegram_bot":
|
||||
bot_ids_to_check.add(listener.listener_id)
|
||||
await session.delete(listener)
|
||||
|
||||
await session.delete(tracker)
|
||||
await session.commit()
|
||||
|
||||
# Stop polling for bots that may no longer be needed
|
||||
if bot_ids_to_check:
|
||||
from ..services.telegram_poller import stop_bot_if_unused
|
||||
for bot_id in bot_ids_to_check:
|
||||
await stop_bot_if_unused(bot_id)
|
||||
|
||||
|
||||
@router.post("/{tracker_id}/enable")
|
||||
async def enable_command_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Enable a command tracker."""
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
tracker.enabled = True
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
|
||||
# Start polling for any telegram bot listeners
|
||||
lr = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == tracker_id,
|
||||
CommandTrackerListener.listener_type == "telegram_bot",
|
||||
)
|
||||
)
|
||||
from ..services.telegram_poller import start_bot_if_needed
|
||||
for listener in lr.all():
|
||||
await start_bot_if_needed(listener.listener_id)
|
||||
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
@router.post("/{tracker_id}/disable")
|
||||
async def disable_command_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Disable a command tracker."""
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
tracker.enabled = False
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
|
||||
# Stop polling for any telegram bot listeners that are no longer needed
|
||||
lr = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == tracker_id,
|
||||
CommandTrackerListener.listener_type == "telegram_bot",
|
||||
)
|
||||
)
|
||||
from ..services.telegram_poller import stop_bot_if_unused
|
||||
for listener in lr.all():
|
||||
await stop_bot_if_unused(listener.listener_id)
|
||||
|
||||
return await _tracker_response(session, tracker)
|
||||
|
||||
|
||||
# --- Listener Management ---
|
||||
|
||||
|
||||
@router.get("/{tracker_id}/listeners")
|
||||
async def list_listeners(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all listeners for a command tracker."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
result = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == tracker_id
|
||||
)
|
||||
)
|
||||
return [_listener_response(l) for l in result.all()]
|
||||
|
||||
|
||||
@router.post("/{tracker_id}/listeners", status_code=status.HTTP_201_CREATED)
|
||||
async def add_listener(
|
||||
tracker_id: int,
|
||||
body: ListenerCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Add a listener to a command tracker."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
|
||||
# Validate listener exists and user owns it
|
||||
if body.listener_type == "telegram_bot":
|
||||
bot = await session.get(TelegramBot, body.listener_id)
|
||||
if not bot or bot.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Telegram bot not found")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported listener type: {body.listener_type}",
|
||||
)
|
||||
|
||||
# Check for duplicate
|
||||
result = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == tracker_id,
|
||||
CommandTrackerListener.listener_type == body.listener_type,
|
||||
CommandTrackerListener.listener_id == body.listener_id,
|
||||
)
|
||||
)
|
||||
if result.first():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Listener is already linked to this command tracker",
|
||||
)
|
||||
|
||||
listener = CommandTrackerListener(
|
||||
command_tracker_id=tracker_id,
|
||||
listener_type=body.listener_type,
|
||||
listener_id=body.listener_id,
|
||||
)
|
||||
session.add(listener)
|
||||
await session.commit()
|
||||
await session.refresh(listener)
|
||||
|
||||
# Start polling for this bot if needed
|
||||
if body.listener_type == "telegram_bot":
|
||||
from ..services.telegram_poller import start_bot_if_needed
|
||||
await start_bot_if_needed(body.listener_id)
|
||||
|
||||
return _listener_response(listener)
|
||||
|
||||
|
||||
@router.delete("/{tracker_id}/listeners/{listener_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def remove_listener(
|
||||
tracker_id: int,
|
||||
listener_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Remove a listener from a command tracker."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
listener = await session.get(CommandTrackerListener, listener_id)
|
||||
if not listener or listener.command_tracker_id != tracker_id:
|
||||
raise HTTPException(status_code=404, detail="Listener not found")
|
||||
|
||||
removed_type = listener.listener_type
|
||||
removed_id = listener.listener_id
|
||||
|
||||
await session.delete(listener)
|
||||
await session.commit()
|
||||
|
||||
# Stop polling for this bot if no longer needed
|
||||
if removed_type == "telegram_bot":
|
||||
from ..services.telegram_poller import stop_bot_if_unused
|
||||
await stop_bot_if_unused(removed_id)
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
|
||||
|
||||
async def _tracker_response(
|
||||
session: AsyncSession, t: CommandTracker, include_listeners: bool = False
|
||||
) -> dict:
|
||||
"""Build command tracker response."""
|
||||
# Get listener count
|
||||
result = await session.exec(
|
||||
select(func.count()).select_from(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == t.id
|
||||
)
|
||||
)
|
||||
listeners_count = result.one()
|
||||
|
||||
resp = {
|
||||
"id": t.id,
|
||||
"user_id": t.user_id,
|
||||
"provider_id": t.provider_id,
|
||||
"command_config_id": t.command_config_id,
|
||||
"name": t.name,
|
||||
"icon": t.icon,
|
||||
"enabled": t.enabled,
|
||||
"listeners_count": listeners_count,
|
||||
"created_at": t.created_at.isoformat(),
|
||||
}
|
||||
|
||||
if include_listeners:
|
||||
lr = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.command_tracker_id == t.id
|
||||
)
|
||||
)
|
||||
resp["listeners"] = [_listener_response(l) for l in lr.all()]
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
def _listener_response(l: CommandTrackerListener) -> dict:
|
||||
return {
|
||||
"id": l.id,
|
||||
"command_tracker_id": l.command_tracker_id,
|
||||
"listener_type": l.listener_type,
|
||||
"listener_id": l.listener_id,
|
||||
"created_at": l.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> CommandTracker:
|
||||
tracker = await session.get(CommandTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Command tracker not found")
|
||||
return tracker
|
||||
+28
-31
@@ -1,4 +1,4 @@
|
||||
"""Tracker-Target link management API routes."""
|
||||
"""Notification tracker-target link management API routes."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -12,10 +12,10 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import (
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
TemplateConfig,
|
||||
Tracker,
|
||||
TrackerTarget,
|
||||
TrackingConfig,
|
||||
User,
|
||||
)
|
||||
@@ -23,50 +23,48 @@ from ..services.notifier import send_real_data_notification, send_test_notificat
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/trackers/{tracker_id}/targets", tags=["tracker-targets"])
|
||||
router = APIRouter(prefix="/api/notification-trackers/{tracker_id}/targets", tags=["notification-tracker-targets"])
|
||||
|
||||
|
||||
class TrackerTargetCreate(BaseModel):
|
||||
class NotificationTrackerTargetCreate(BaseModel):
|
||||
target_id: int
|
||||
tracking_config_id: int | None = None
|
||||
template_config_id: int | None = None
|
||||
enabled: bool = True
|
||||
quiet_hours_start: str | None = None
|
||||
quiet_hours_end: str | None = None
|
||||
commands_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TrackerTargetUpdate(BaseModel):
|
||||
class NotificationTrackerTargetUpdate(BaseModel):
|
||||
tracking_config_id: int | None = None
|
||||
template_config_id: int | None = None
|
||||
enabled: bool | None = None
|
||||
quiet_hours_start: str | None = None
|
||||
quiet_hours_end: str | None = None
|
||||
commands_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_tracker_targets(
|
||||
async def list_notification_tracker_targets(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""List all target links for a tracker."""
|
||||
"""List all target links for a notification tracker."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.tracker_id == tracker_id)
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
|
||||
)
|
||||
return [await _tt_response(session, tt) for tt in result.all()]
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_tracker_target(
|
||||
async def create_notification_tracker_target(
|
||||
tracker_id: int,
|
||||
body: TrackerTargetCreate,
|
||||
body: NotificationTrackerTargetCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Link a target to a tracker with per-link configuration."""
|
||||
"""Link a target to a notification tracker with per-link configuration."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
|
||||
# Validate target exists and belongs to user
|
||||
@@ -76,9 +74,9 @@ async def create_tracker_target(
|
||||
|
||||
# Check for duplicate link
|
||||
result = await session.exec(
|
||||
select(TrackerTarget).where(
|
||||
TrackerTarget.tracker_id == tracker_id,
|
||||
TrackerTarget.target_id == body.target_id,
|
||||
select(NotificationTrackerTarget).where(
|
||||
NotificationTrackerTarget.tracker_id == tracker_id,
|
||||
NotificationTrackerTarget.target_id == body.target_id,
|
||||
)
|
||||
)
|
||||
if result.first():
|
||||
@@ -97,7 +95,7 @@ async def create_tracker_target(
|
||||
if not tpc or (tpc.user_id != user.id and tpc.user_id != 0):
|
||||
raise HTTPException(status_code=404, detail="Template config not found")
|
||||
|
||||
tt = TrackerTarget(tracker_id=tracker_id, **body.model_dump())
|
||||
tt = NotificationTrackerTarget(tracker_id=tracker_id, **body.model_dump())
|
||||
session.add(tt)
|
||||
await session.commit()
|
||||
await session.refresh(tt)
|
||||
@@ -105,16 +103,16 @@ async def create_tracker_target(
|
||||
|
||||
|
||||
@router.put("/{tracker_target_id}")
|
||||
async def update_tracker_target(
|
||||
async def update_notification_tracker_target(
|
||||
tracker_id: int,
|
||||
tracker_target_id: int,
|
||||
body: TrackerTargetUpdate,
|
||||
body: NotificationTrackerTargetUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Update a tracker-target link's configuration."""
|
||||
"""Update a notification tracker-target link's configuration."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
tt = await session.get(TrackerTarget, tracker_target_id)
|
||||
tt = await session.get(NotificationTrackerTarget, tracker_target_id)
|
||||
if not tt or tt.tracker_id != tracker_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
||||
|
||||
@@ -138,15 +136,15 @@ async def update_tracker_target(
|
||||
|
||||
|
||||
@router.delete("/{tracker_target_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_tracker_target(
|
||||
async def delete_notification_tracker_target(
|
||||
tracker_id: int,
|
||||
tracker_target_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Remove a target link from a tracker."""
|
||||
"""Remove a target link from a notification tracker."""
|
||||
await _get_user_tracker(session, tracker_id, user.id)
|
||||
tt = await session.get(TrackerTarget, tracker_target_id)
|
||||
tt = await session.get(NotificationTrackerTarget, tracker_target_id)
|
||||
if not tt or tt.tracker_id != tracker_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
||||
await session.delete(tt)
|
||||
@@ -154,7 +152,7 @@ async def delete_tracker_target(
|
||||
|
||||
|
||||
@router.post("/{tracker_target_id}/test/{test_type}")
|
||||
async def test_tracker_target(
|
||||
async def test_notification_tracker_target(
|
||||
tracker_id: int,
|
||||
tracker_target_id: int,
|
||||
test_type: str,
|
||||
@@ -171,7 +169,7 @@ async def test_tracker_target(
|
||||
raise HTTPException(status_code=400, detail=f"Invalid test type. Must be one of: {', '.join(valid_types)}")
|
||||
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
tt = await session.get(TrackerTarget, tracker_target_id)
|
||||
tt = await session.get(NotificationTrackerTarget, tracker_target_id)
|
||||
if not tt or tt.tracker_id != tracker_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker-target link not found")
|
||||
|
||||
@@ -224,7 +222,7 @@ async def test_tracker_target(
|
||||
return {"target": target.name, **r}
|
||||
|
||||
|
||||
async def _tt_response(session: AsyncSession, tt: TrackerTarget) -> dict:
|
||||
async def _tt_response(session: AsyncSession, tt: NotificationTrackerTarget) -> dict:
|
||||
"""Build tracker-target response with target details."""
|
||||
target = await session.get(NotificationTarget, tt.target_id)
|
||||
return {
|
||||
@@ -239,15 +237,14 @@ async def _tt_response(session: AsyncSession, tt: TrackerTarget) -> dict:
|
||||
"enabled": tt.enabled,
|
||||
"quiet_hours_start": tt.quiet_hours_start,
|
||||
"quiet_hours_end": tt.quiet_hours_end,
|
||||
"commands_config": tt.commands_config,
|
||||
"created_at": tt.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> Tracker:
|
||||
tracker = await session.get(Tracker, tracker_id)
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
+26
-26
@@ -1,4 +1,4 @@
|
||||
"""Tracker management API routes."""
|
||||
"""Notification tracker management API routes."""
|
||||
|
||||
import logging
|
||||
|
||||
@@ -11,22 +11,21 @@ from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import (
|
||||
EventLog,
|
||||
NotificationTracker,
|
||||
NotificationTrackerState,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
Tracker,
|
||||
TrackerState,
|
||||
TrackerTarget,
|
||||
User,
|
||||
)
|
||||
from ..services.scheduler import schedule_tracker, unschedule_tracker
|
||||
from ..services.watcher import check_tracker
|
||||
from .tracker_targets import _tt_response
|
||||
from .notification_tracker_targets import _tt_response
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/trackers", tags=["trackers"])
|
||||
router = APIRouter(prefix="/api/notification-trackers", tags=["notification-trackers"])
|
||||
|
||||
|
||||
class TrackerCreate(BaseModel):
|
||||
class NotificationTrackerCreate(BaseModel):
|
||||
provider_id: int
|
||||
name: str
|
||||
icon: str = ""
|
||||
@@ -36,7 +35,7 @@ class TrackerCreate(BaseModel):
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class TrackerUpdate(BaseModel):
|
||||
class NotificationTrackerUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
icon: str | None = None
|
||||
collection_ids: list[str] | None = None
|
||||
@@ -46,20 +45,20 @@ class TrackerUpdate(BaseModel):
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_trackers(
|
||||
async def list_notification_trackers(
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
result = await session.exec(
|
||||
select(Tracker).where(Tracker.user_id == user.id)
|
||||
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
||||
)
|
||||
trackers = result.all()
|
||||
return [await _tracker_response(session, t) for t in trackers]
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_tracker(
|
||||
body: TrackerCreate,
|
||||
async def create_notification_tracker(
|
||||
body: NotificationTrackerCreate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
@@ -67,7 +66,7 @@ async def create_tracker(
|
||||
if not provider or provider.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
tracker = Tracker(user_id=user.id, **body.model_dump())
|
||||
tracker = NotificationTracker(user_id=user.id, **body.model_dump())
|
||||
session.add(tracker)
|
||||
await session.commit()
|
||||
await session.refresh(tracker)
|
||||
@@ -77,7 +76,7 @@ async def create_tracker(
|
||||
|
||||
|
||||
@router.get("/{tracker_id}")
|
||||
async def get_tracker(
|
||||
async def get_notification_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
@@ -87,9 +86,9 @@ async def get_tracker(
|
||||
|
||||
|
||||
@router.put("/{tracker_id}")
|
||||
async def update_tracker(
|
||||
async def update_notification_tracker(
|
||||
tracker_id: int,
|
||||
body: TrackerUpdate,
|
||||
body: NotificationTrackerUpdate,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
@@ -107,7 +106,7 @@ async def update_tracker(
|
||||
|
||||
|
||||
@router.delete("/{tracker_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_tracker(
|
||||
async def delete_notification_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
@@ -115,13 +114,13 @@ async def delete_tracker(
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
# Delete associated tracker-target links
|
||||
result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.tracker_id == tracker_id)
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
|
||||
)
|
||||
for tt in result.all():
|
||||
await session.delete(tt)
|
||||
# Delete associated tracker state
|
||||
state_result = await session.exec(
|
||||
select(TrackerState).where(TrackerState.tracker_id == tracker_id)
|
||||
select(NotificationTrackerState).where(NotificationTrackerState.tracker_id == tracker_id)
|
||||
)
|
||||
for ts in state_result.all():
|
||||
await session.delete(ts)
|
||||
@@ -138,18 +137,19 @@ async def delete_tracker(
|
||||
|
||||
|
||||
@router.post("/{tracker_id}/trigger")
|
||||
async def trigger_tracker(
|
||||
async def trigger_notification_tracker(
|
||||
tracker_id: int,
|
||||
user: User = Depends(get_current_user),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
):
|
||||
tracker = await _get_user_tracker(session, tracker_id, user.id)
|
||||
from ..services.watcher import check_tracker
|
||||
result = await check_tracker(tracker.id)
|
||||
return {"triggered": True, "result": result}
|
||||
|
||||
|
||||
@router.get("/{tracker_id}/history")
|
||||
async def tracker_history(
|
||||
async def notification_tracker_history(
|
||||
tracker_id: int,
|
||||
limit: int = Query(default=20, ge=1, le=500),
|
||||
user: User = Depends(get_current_user),
|
||||
@@ -175,10 +175,10 @@ async def tracker_history(
|
||||
]
|
||||
|
||||
|
||||
async def _tracker_response(session: AsyncSession, t: Tracker) -> dict:
|
||||
async def _tracker_response(session: AsyncSession, t: NotificationTracker) -> dict:
|
||||
"""Build tracker response with nested tracker_targets."""
|
||||
result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.tracker_id == t.id)
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == t.id)
|
||||
)
|
||||
tracker_targets = [await _tt_response(session, tt) for tt in result.all()]
|
||||
|
||||
@@ -198,8 +198,8 @@ async def _tracker_response(session: AsyncSession, t: Tracker) -> dict:
|
||||
|
||||
async def _get_user_tracker(
|
||||
session: AsyncSession, tracker_id: int, user_id: int
|
||||
) -> Tracker:
|
||||
tracker = await session.get(Tracker, tracker_id)
|
||||
) -> NotificationTracker:
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or tracker.user_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="Tracker not found")
|
||||
return tracker
|
||||
@@ -8,7 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, ServiceProvider, Tracker, EventLog, User
|
||||
from ..database.models import NotificationTarget, NotificationTracker, ServiceProvider, EventLog, User
|
||||
|
||||
router = APIRouter(prefix="/api/status", tags=["status"])
|
||||
|
||||
@@ -31,7 +31,7 @@ async def get_status(
|
||||
)).one()
|
||||
|
||||
trackers_result = await session.exec(
|
||||
select(Tracker).where(Tracker.user_id == user.id)
|
||||
select(NotificationTracker).where(NotificationTracker.user_id == user.id)
|
||||
)
|
||||
trackers = trackers_result.all()
|
||||
active_count = sum(1 for t in trackers if t.enabled)
|
||||
@@ -43,8 +43,8 @@ async def get_status(
|
||||
# Build events query with filters
|
||||
events_query = (
|
||||
select(EventLog)
|
||||
.join(Tracker, EventLog.tracker_id == Tracker.id)
|
||||
.where(Tracker.user_id == user.id)
|
||||
.join(NotificationTracker, EventLog.tracker_id == NotificationTracker.id)
|
||||
.where(NotificationTracker.user_id == user.id)
|
||||
)
|
||||
|
||||
if event_type:
|
||||
@@ -110,8 +110,8 @@ async def get_event_chart(
|
||||
EventLog.event_type,
|
||||
func.count().label("total"),
|
||||
)
|
||||
.join(Tracker, EventLog.tracker_id == Tracker.id)
|
||||
.where(Tracker.user_id == user.id, EventLog.created_at >= cutoff)
|
||||
.join(NotificationTracker, EventLog.tracker_id == NotificationTracker.id)
|
||||
.where(NotificationTracker.user_id == user.id, EventLog.created_at >= cutoff)
|
||||
.group_by(day_col, EventLog.event_type)
|
||||
.order_by(day_col)
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any
|
||||
|
||||
from ..auth.dependencies import get_current_user
|
||||
from ..database.engine import get_session
|
||||
from ..database.models import NotificationTarget, TelegramBot, TelegramChat, TrackerTarget, User
|
||||
from ..database.models import NotificationTarget, NotificationTrackerTarget, TelegramBot, TelegramChat, User
|
||||
from ..services.notifier import send_test_notification
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -23,12 +23,14 @@ class TargetCreate(BaseModel):
|
||||
name: str
|
||||
icon: str = ""
|
||||
config: dict[str, Any] = {}
|
||||
chat_action: str | None = None
|
||||
|
||||
|
||||
class TargetUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
icon: str | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
chat_action: str | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
@@ -80,6 +82,7 @@ async def create_target(
|
||||
name=body.name,
|
||||
icon=body.icon,
|
||||
config=body.config,
|
||||
chat_action=body.chat_action,
|
||||
)
|
||||
session.add(target)
|
||||
await session.commit()
|
||||
@@ -125,7 +128,7 @@ async def delete_target(
|
||||
target = await _get_user_target(session, target_id, user.id)
|
||||
# Delete associated tracker-target links
|
||||
result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.target_id == target_id)
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.target_id == target_id)
|
||||
)
|
||||
for tt in result.all():
|
||||
await session.delete(tt)
|
||||
@@ -153,6 +156,7 @@ def _target_response(target: NotificationTarget, chat_names: dict[str, str] | No
|
||||
"name": target.name,
|
||||
"icon": target.icon,
|
||||
"config": _safe_config(target),
|
||||
"chat_action": target.chat_action,
|
||||
"created_at": target.created_at.isoformat(),
|
||||
}
|
||||
# Attach resolved chat name for telegram targets
|
||||
|
||||
@@ -34,7 +34,6 @@ class BotUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
icon: str | None = None
|
||||
update_mode: str | None = None
|
||||
commands_config: dict | None = None
|
||||
|
||||
|
||||
@router.get("")
|
||||
@@ -86,9 +85,6 @@ async def update_bot(
|
||||
bot.name = body.name
|
||||
if body.icon is not None:
|
||||
bot.icon = body.icon
|
||||
if body.commands_config is not None:
|
||||
bot.commands_config = body.commands_config
|
||||
|
||||
# Handle mode switching
|
||||
if body.update_mode is not None and body.update_mode != bot.update_mode:
|
||||
if body.update_mode == "webhook":
|
||||
@@ -403,7 +399,6 @@ def _bot_response(b: TelegramBot) -> dict:
|
||||
"bot_id": b.bot_id,
|
||||
"webhook_path_id": b.webhook_path_id,
|
||||
"update_mode": b.update_mode or "polling",
|
||||
"commands_config": b.commands_config or {},
|
||||
"token_preview": f"{b.token[:8]}...{b.token[-4:]}" if len(b.token) > 12 else "***",
|
||||
"created_at": b.created_at.isoformat(),
|
||||
}
|
||||
|
||||
@@ -16,12 +16,15 @@ from notify_bridge_core.notifications.telegram.media import TELEGRAM_API_BASE_UR
|
||||
from ..database.engine import get_engine
|
||||
from ..services import make_immich_provider
|
||||
from ..database.models import (
|
||||
CommandConfig,
|
||||
CommandTracker,
|
||||
CommandTrackerListener,
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
Tracker,
|
||||
TrackerTarget,
|
||||
TrackingConfig,
|
||||
)
|
||||
from .parser import parse_command
|
||||
@@ -48,6 +51,70 @@ def _check_rate_limit(bot_id: int, chat_id: str, cmd: str, limits: dict[str, int
|
||||
return None
|
||||
|
||||
|
||||
async def _resolve_command_context(
|
||||
bot: TelegramBot,
|
||||
) -> list[tuple[CommandTracker, CommandConfig, ServiceProvider]]:
|
||||
"""Resolve all enabled command trackers, configs, and providers for a bot.
|
||||
|
||||
Finds CommandTrackerListener rows where listener_type="telegram_bot"
|
||||
and listener_id=bot.id, then loads the full chain:
|
||||
CommandTrackerListener -> CommandTracker (enabled) -> CommandConfig + ServiceProvider.
|
||||
"""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
# Find all listeners for this bot
|
||||
result = await session.exec(
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.listener_type == "telegram_bot",
|
||||
CommandTrackerListener.listener_id == bot.id,
|
||||
)
|
||||
)
|
||||
listeners = result.all()
|
||||
|
||||
if not listeners:
|
||||
return []
|
||||
|
||||
tuples: list[tuple[CommandTracker, CommandConfig, ServiceProvider]] = []
|
||||
for listener in listeners:
|
||||
tracker = await session.get(CommandTracker, listener.command_tracker_id)
|
||||
if not tracker or not tracker.enabled:
|
||||
continue
|
||||
config = await session.get(CommandConfig, tracker.command_config_id)
|
||||
if not config:
|
||||
continue
|
||||
provider = await session.get(ServiceProvider, tracker.provider_id)
|
||||
if not provider:
|
||||
continue
|
||||
tuples.append((tracker, config, provider))
|
||||
|
||||
return tuples
|
||||
|
||||
|
||||
def _merge_command_context(
|
||||
ctx: list[tuple[CommandTracker, CommandConfig, ServiceProvider]],
|
||||
) -> tuple[list[str], str, str, int, dict[str, Any]]:
|
||||
"""Merge enabled_commands from all configs and pick defaults from first config.
|
||||
|
||||
Returns (enabled_commands, locale, response_mode, default_count, rate_limits).
|
||||
"""
|
||||
if not ctx:
|
||||
return [], "en", "media", 5, {}
|
||||
|
||||
# Union of all enabled commands across configs
|
||||
enabled: set[str] = set()
|
||||
for _, config, _ in ctx:
|
||||
enabled.update(config.enabled_commands or [])
|
||||
|
||||
# Use first config's settings as defaults
|
||||
first_config = ctx[0][1]
|
||||
locale = first_config.locale or "en"
|
||||
response_mode = first_config.response_mode or "media"
|
||||
default_count = first_config.default_count or 5
|
||||
rate_limits = first_config.rate_limits or {}
|
||||
|
||||
return sorted(enabled), locale, response_mode, default_count, rate_limits
|
||||
|
||||
|
||||
async def handle_command(
|
||||
bot: TelegramBot,
|
||||
chat_id: str,
|
||||
@@ -58,11 +125,8 @@ async def handle_command(
|
||||
if not cmd:
|
||||
return None
|
||||
|
||||
config = bot.commands_config or {}
|
||||
enabled = config.get("enabled", [])
|
||||
default_count = min(config.get("default_count", 5), 20)
|
||||
locale = config.get("locale", "en")
|
||||
rate_limits = config.get("rate_limits", {})
|
||||
ctx = await _resolve_command_context(bot)
|
||||
enabled, locale, response_mode, default_count, rate_limits = _merge_command_context(ctx)
|
||||
|
||||
if cmd == "start":
|
||||
msgs = {
|
||||
@@ -85,20 +149,25 @@ async def handle_command(
|
||||
|
||||
count = min(count_override or default_count, 20)
|
||||
|
||||
# Build providers map from command context
|
||||
providers_map: dict[int, ServiceProvider] = {}
|
||||
for _, _, provider in ctx:
|
||||
providers_map[provider.id] = provider
|
||||
|
||||
# Dispatch
|
||||
if cmd == "help":
|
||||
return _cmd_help(enabled, locale)
|
||||
if cmd == "status":
|
||||
return await _cmd_status(bot, locale)
|
||||
return await _cmd_status(bot, providers_map, locale)
|
||||
if cmd == "albums":
|
||||
return await _cmd_albums(bot, locale)
|
||||
return await _cmd_albums(bot, providers_map, locale)
|
||||
if cmd == "events":
|
||||
return await _cmd_events(bot, count, locale)
|
||||
return await _cmd_events(bot, providers_map, count, locale)
|
||||
if cmd == "people":
|
||||
return await _cmd_people(bot, locale)
|
||||
return await _cmd_people(providers_map, locale)
|
||||
if cmd in ("search", "find", "person", "place", "latest", "random",
|
||||
"favorites", "summary", "memory"):
|
||||
return await _cmd_immich(bot, cmd, args, count, locale)
|
||||
return await _cmd_immich(bot, cmd, args, count, locale, response_mode, providers_map)
|
||||
|
||||
return None
|
||||
|
||||
@@ -112,50 +181,24 @@ def _cmd_help(enabled: list[str], locale: str) -> str:
|
||||
return header.get(locale, header["en"]) + "\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _get_bot_context(bot: TelegramBot) -> tuple[
|
||||
list[Tracker], dict[int, ServiceProvider]
|
||||
]:
|
||||
"""Get trackers and providers associated with a bot via its targets."""
|
||||
async def _get_notification_trackers_for_providers(
|
||||
provider_ids: set[int],
|
||||
) -> list[NotificationTracker]:
|
||||
"""Get notification trackers for the given provider IDs.
|
||||
|
||||
Used by commands like albums, events, status that need notification
|
||||
tracker data (collection_ids, event logs).
|
||||
"""
|
||||
if not provider_ids:
|
||||
return []
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
# Find targets that use this bot's token
|
||||
result = await session.exec(
|
||||
select(NotificationTarget).where(
|
||||
NotificationTarget.type == "telegram",
|
||||
NotificationTarget.user_id == bot.user_id,
|
||||
select(NotificationTracker).where(
|
||||
NotificationTracker.provider_id.in_(provider_ids)
|
||||
)
|
||||
)
|
||||
targets = result.all()
|
||||
bot_target_ids = {t.id for t in targets if t.config.get("bot_token") == bot.token}
|
||||
|
||||
if not bot_target_ids:
|
||||
return [], {}
|
||||
|
||||
# Find trackers linked to these targets via TrackerTarget
|
||||
tt_result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.target_id.in_(bot_target_ids))
|
||||
)
|
||||
all_links = tt_result.all()
|
||||
tracker_ids = {tt.tracker_id for tt in all_links}
|
||||
|
||||
if not tracker_ids:
|
||||
return [], {}
|
||||
|
||||
trackers = []
|
||||
provider_ids = set()
|
||||
for tid in tracker_ids:
|
||||
tracker = await session.get(Tracker, tid)
|
||||
if tracker:
|
||||
trackers.append(tracker)
|
||||
provider_ids.add(tracker.provider_id)
|
||||
|
||||
providers_map: dict[int, ServiceProvider] = {}
|
||||
for pid in provider_ids:
|
||||
provider = await session.get(ServiceProvider, pid)
|
||||
if provider:
|
||||
providers_map[pid] = provider
|
||||
|
||||
return trackers, providers_map
|
||||
return list(result.all())
|
||||
|
||||
|
||||
async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
@@ -173,7 +216,7 @@ async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
if not bot_target_ids:
|
||||
return False
|
||||
tt_result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.target_id.in_(bot_target_ids))
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.target_id.in_(bot_target_ids))
|
||||
)
|
||||
for tt in tt_result.all():
|
||||
if tt.tracking_config_id:
|
||||
@@ -183,8 +226,9 @@ async def _check_native_memory(bot: TelegramBot) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def _cmd_status(bot: TelegramBot, locale: str) -> str:
|
||||
trackers, _ = await _get_bot_context(bot)
|
||||
async def _cmd_status(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> str:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
active = sum(1 for t in trackers if t.enabled)
|
||||
total = len(trackers)
|
||||
total_albums = sum(len(t.collection_ids or []) for t in trackers)
|
||||
@@ -212,8 +256,9 @@ async def _cmd_status(bot: TelegramBot, locale: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
async def _cmd_albums(bot: TelegramBot, locale: str) -> str:
|
||||
trackers, providers_map = await _get_bot_context(bot)
|
||||
async def _cmd_albums(bot: TelegramBot, providers_map: dict[int, ServiceProvider], locale: str) -> str:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
if not trackers:
|
||||
return "No tracked albums." if locale == "en" else "Нет отслеживаемых альбомов."
|
||||
|
||||
@@ -236,8 +281,9 @@ async def _cmd_albums(bot: TelegramBot, locale: str) -> str:
|
||||
return header + "\n" + "\n".join(lines) if lines else header + "\n (none)"
|
||||
|
||||
|
||||
async def _cmd_events(bot: TelegramBot, count: int, locale: str) -> str:
|
||||
trackers, _ = await _get_bot_context(bot)
|
||||
async def _cmd_events(bot: TelegramBot, providers_map: dict[int, ServiceProvider], count: int, locale: str) -> str:
|
||||
provider_ids = set(providers_map.keys())
|
||||
trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
tracker_ids = [t.id for t in trackers]
|
||||
if not tracker_ids:
|
||||
return "No events." if locale == "en" else "Нет событий."
|
||||
@@ -263,8 +309,7 @@ async def _cmd_events(bot: TelegramBot, count: int, locale: str) -> str:
|
||||
return header + "\n" + "\n".join(lines)
|
||||
|
||||
|
||||
async def _cmd_people(bot: TelegramBot, locale: str) -> str:
|
||||
_, providers_map = await _get_bot_context(bot)
|
||||
async def _cmd_people(providers_map: dict[int, ServiceProvider], locale: str) -> str:
|
||||
all_people: dict[str, str] = {}
|
||||
|
||||
async with aiohttp.ClientSession() as http:
|
||||
@@ -285,23 +330,28 @@ async def _cmd_people(bot: TelegramBot, locale: str) -> str:
|
||||
|
||||
async def _cmd_immich(
|
||||
bot: TelegramBot, cmd: str, args: str, count: int, locale: str,
|
||||
response_mode: str, providers_map: dict[int, ServiceProvider],
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Handle commands that need Immich API access and may return media."""
|
||||
trackers, providers_map = await _get_bot_context(bot)
|
||||
if not trackers:
|
||||
if not providers_map:
|
||||
return "No trackers configured." if locale == "en" else "Трекеры не настроены."
|
||||
|
||||
# Get notification trackers for album data
|
||||
provider_ids = set(providers_map.keys())
|
||||
notification_trackers = await _get_notification_trackers_for_providers(provider_ids)
|
||||
|
||||
all_album_ids: list[str] = []
|
||||
for t in trackers:
|
||||
for t in notification_trackers:
|
||||
all_album_ids.extend(t.collection_ids or [])
|
||||
|
||||
first_tracker = trackers[0]
|
||||
provider = providers_map.get(first_tracker.provider_id)
|
||||
if not provider or provider.type != "immich":
|
||||
# Pick the first immich provider
|
||||
provider: ServiceProvider | None = None
|
||||
for p in providers_map.values():
|
||||
if p.type == "immich":
|
||||
provider = p
|
||||
break
|
||||
if not provider:
|
||||
return "Server not found." if locale == "en" else "Сервер не найден."
|
||||
|
||||
config = bot.commands_config or {}
|
||||
response_mode = config.get("response_mode", "media")
|
||||
async with aiohttp.ClientSession() as http:
|
||||
immich = make_immich_provider(http, provider)
|
||||
client = immich.client
|
||||
@@ -578,10 +628,13 @@ async def send_media_group(
|
||||
|
||||
|
||||
async def register_commands_with_telegram(bot: TelegramBot) -> bool:
|
||||
"""Register enabled commands with Telegram BotFather API."""
|
||||
config = bot.commands_config or {}
|
||||
enabled = config.get("enabled", [])
|
||||
locale = config.get("locale", "en")
|
||||
"""Register enabled commands with Telegram BotFather API.
|
||||
|
||||
Resolves all command trackers and configs for this bot, merges
|
||||
enabled commands (union), and calls setMyCommands.
|
||||
"""
|
||||
ctx = await _resolve_command_context(bot)
|
||||
enabled, locale, _, _, _ = _merge_command_context(ctx)
|
||||
|
||||
commands = []
|
||||
for cmd in enabled:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Data migrations for schema changes.
|
||||
|
||||
Handles converting legacy JSON-array relationships to proper junction tables.
|
||||
Handles converting legacy JSON-array relationships to proper junction tables,
|
||||
and the Phase 1 entity refactor (tracker → notification_tracker, etc.).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from sqlalchemy import text
|
||||
@@ -11,97 +13,133 @@ from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _has_column(conn, table: str, column: str) -> bool:
|
||||
"""Check if a column exists in a SQLite table."""
|
||||
cols = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
for row in sync_conn.execute(
|
||||
text(f"PRAGMA table_info('{table}')")
|
||||
).fetchall()
|
||||
]
|
||||
)
|
||||
return column in cols
|
||||
|
||||
|
||||
async def _has_table(conn, table: str) -> bool:
|
||||
"""Check if a table exists in the SQLite database."""
|
||||
result = await conn.run_sync(
|
||||
lambda sync_conn: sync_conn.execute(
|
||||
text(
|
||||
"SELECT name FROM sqlite_master "
|
||||
"WHERE type='table' AND name=:name"
|
||||
),
|
||||
{"name": table},
|
||||
).fetchone()
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy schema migrations (pre-Phase 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def migrate_schema(engine: AsyncEngine) -> None:
|
||||
"""Add missing columns to existing tables (SQLite ALTER TABLE ADD COLUMN)."""
|
||||
async with engine.begin() as conn:
|
||||
# Helper to check if column exists
|
||||
async def _has_column(table: str, column: str) -> bool:
|
||||
cols = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
for row in sync_conn.execute(
|
||||
text(f"PRAGMA table_info('{table}')")
|
||||
).fetchall()
|
||||
]
|
||||
)
|
||||
return column in cols
|
||||
# --- Tracker table (may still be named "tracker" or already renamed) ---
|
||||
tracker_table = "notification_tracker" if await _has_table(conn, "notification_tracker") else "tracker"
|
||||
|
||||
# Add batch_duration to tracker if missing
|
||||
if not await _has_column("tracker", "batch_duration"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE tracker ADD COLUMN batch_duration INTEGER DEFAULT 0")
|
||||
)
|
||||
logger.info("Added batch_duration column to tracker table")
|
||||
if await _has_table(conn, tracker_table):
|
||||
if not await _has_column(conn, tracker_table, "batch_duration"):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE {tracker_table} ADD COLUMN batch_duration INTEGER DEFAULT 0")
|
||||
)
|
||||
logger.info("Added batch_duration column to %s table", tracker_table)
|
||||
|
||||
# Add enriched fields to event_log if missing
|
||||
for col, sql in [
|
||||
("tracker_name", "ALTER TABLE event_log ADD COLUMN tracker_name TEXT DEFAULT ''"),
|
||||
("provider_id", "ALTER TABLE event_log ADD COLUMN provider_id INTEGER"),
|
||||
("provider_name", "ALTER TABLE event_log ADD COLUMN provider_name TEXT DEFAULT ''"),
|
||||
("assets_count", "ALTER TABLE event_log ADD COLUMN assets_count INTEGER DEFAULT 0"),
|
||||
]:
|
||||
if not await _has_column("event_log", col):
|
||||
await conn.execute(text(sql))
|
||||
logger.info("Added %s column to event_log table", col)
|
||||
if await _has_table(conn, "event_log"):
|
||||
for col, sql in [
|
||||
("tracker_name", "ALTER TABLE event_log ADD COLUMN tracker_name TEXT DEFAULT ''"),
|
||||
("provider_id", "ALTER TABLE event_log ADD COLUMN provider_id INTEGER"),
|
||||
("provider_name", "ALTER TABLE event_log ADD COLUMN provider_name TEXT DEFAULT ''"),
|
||||
("assets_count", "ALTER TABLE event_log ADD COLUMN assets_count INTEGER DEFAULT 0"),
|
||||
]:
|
||||
if not await _has_column(conn, "event_log", col):
|
||||
await conn.execute(text(sql))
|
||||
logger.info("Added %s column to event_log table", col)
|
||||
|
||||
# Add commands_config to telegram_bot if missing
|
||||
if not await _has_column("telegram_bot", "commands_config"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN commands_config TEXT DEFAULT '{}'")
|
||||
)
|
||||
logger.info("Added commands_config column to telegram_bot table")
|
||||
|
||||
# Add webhook_path_id to telegram_bot if missing
|
||||
if not await _has_column("telegram_bot", "webhook_path_id"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN webhook_path_id TEXT DEFAULT ''")
|
||||
)
|
||||
logger.info("Added webhook_path_id column to telegram_bot table")
|
||||
# Backfill existing bots with unique IDs
|
||||
import uuid
|
||||
bots = (await conn.execute(text("SELECT id FROM telegram_bot"))).fetchall()
|
||||
for bot in bots:
|
||||
if await _has_table(conn, "telegram_bot"):
|
||||
if not await _has_column(conn, "telegram_bot", "commands_config"):
|
||||
await conn.execute(
|
||||
text("UPDATE telegram_bot SET webhook_path_id = :wid WHERE id = :bid"),
|
||||
{"wid": uuid.uuid4().hex, "bid": bot[0]},
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN commands_config TEXT DEFAULT '{}'")
|
||||
)
|
||||
if bots:
|
||||
logger.info("Backfilled webhook_path_id for %d existing bots", len(bots))
|
||||
logger.info("Added commands_config column to telegram_bot table")
|
||||
|
||||
# Add webhook_path_id to telegram_bot if missing
|
||||
if not await _has_column(conn, "telegram_bot", "webhook_path_id"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN webhook_path_id TEXT DEFAULT ''")
|
||||
)
|
||||
logger.info("Added webhook_path_id column to telegram_bot table")
|
||||
# Backfill existing bots with unique IDs
|
||||
import uuid
|
||||
bots = (await conn.execute(text("SELECT id FROM telegram_bot"))).fetchall()
|
||||
for bot in bots:
|
||||
await conn.execute(
|
||||
text("UPDATE telegram_bot SET webhook_path_id = :wid WHERE id = :bid"),
|
||||
{"wid": uuid.uuid4().hex, "bid": bot[0]},
|
||||
)
|
||||
if bots:
|
||||
logger.info("Backfilled webhook_path_id for %d existing bots", len(bots))
|
||||
|
||||
# Add update_mode to telegram_bot if missing
|
||||
if not await _has_column(conn, "telegram_bot", "update_mode"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN update_mode TEXT DEFAULT 'polling'")
|
||||
)
|
||||
logger.info("Added update_mode column to telegram_bot table")
|
||||
|
||||
# Add date_only_format to template_config if missing
|
||||
if not await _has_column("template_config", "date_only_format"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE template_config ADD COLUMN date_only_format TEXT DEFAULT '%d.%m.%Y'")
|
||||
)
|
||||
logger.info("Added date_only_format column to template_config table")
|
||||
|
||||
# Add update_mode to telegram_bot if missing
|
||||
if not await _has_column("telegram_bot", "update_mode"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE telegram_bot ADD COLUMN update_mode TEXT DEFAULT 'polling'")
|
||||
)
|
||||
logger.info("Added update_mode column to telegram_bot table")
|
||||
if await _has_table(conn, "template_config"):
|
||||
if not await _has_column(conn, "template_config", "date_only_format"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE template_config ADD COLUMN date_only_format TEXT DEFAULT '%d.%m.%Y'")
|
||||
)
|
||||
logger.info("Added date_only_format column to template_config table")
|
||||
|
||||
# Add memory_source to tracking_config if missing
|
||||
if not await _has_column("tracking_config", "memory_source"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE tracking_config ADD COLUMN memory_source TEXT DEFAULT 'albums'")
|
||||
)
|
||||
logger.info("Added memory_source column to tracking_config table")
|
||||
if await _has_table(conn, "tracking_config"):
|
||||
if not await _has_column(conn, "tracking_config", "memory_source"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE tracking_config ADD COLUMN memory_source TEXT DEFAULT 'albums'")
|
||||
)
|
||||
logger.info("Added memory_source column to tracking_config table")
|
||||
|
||||
# Add collection_name and shared to tracker_state if missing
|
||||
if not await _has_column("tracker_state", "collection_name"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE tracker_state ADD COLUMN collection_name TEXT DEFAULT ''")
|
||||
)
|
||||
logger.info("Added collection_name column to tracker_state table")
|
||||
if not await _has_column("tracker_state", "shared"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE tracker_state ADD COLUMN shared INTEGER DEFAULT 0")
|
||||
)
|
||||
logger.info("Added shared column to tracker_state table")
|
||||
state_table = "notification_tracker_state" if await _has_table(conn, "notification_tracker_state") else "tracker_state"
|
||||
if await _has_table(conn, state_table):
|
||||
if not await _has_column(conn, state_table, "collection_name"):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE {state_table} ADD COLUMN collection_name TEXT DEFAULT ''")
|
||||
)
|
||||
logger.info("Added collection_name column to %s table", state_table)
|
||||
if not await _has_column(conn, state_table, "shared"):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE {state_table} ADD COLUMN shared INTEGER DEFAULT 0")
|
||||
)
|
||||
logger.info("Added shared column to %s table", state_table)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Legacy tracker_target migration (pre-Phase 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
"""Migrate legacy Tracker.target_ids JSON arrays to TrackerTarget rows.
|
||||
|
||||
@@ -114,36 +152,42 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
Idempotent: skips if legacy columns don't exist or data already migrated.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
# Check if legacy target_ids column exists on tracker table
|
||||
columns = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
for row in sync_conn.execute(
|
||||
text("PRAGMA table_info('tracker')")
|
||||
).fetchall()
|
||||
]
|
||||
)
|
||||
if "target_ids" not in columns:
|
||||
# Determine which table name exists (pre- or post-rename)
|
||||
if await _has_table(conn, "tracker"):
|
||||
tracker_table = "tracker"
|
||||
tt_table = "tracker_target"
|
||||
tracker_id_col = "tracker_id"
|
||||
elif await _has_table(conn, "notification_tracker"):
|
||||
tracker_table = "notification_tracker"
|
||||
tt_table = "notification_tracker_target"
|
||||
tracker_id_col = "notification_tracker_id"
|
||||
else:
|
||||
logger.debug("No tracker table found — skipping migration")
|
||||
return
|
||||
|
||||
# Check if legacy target_ids column exists
|
||||
if not await _has_column(conn, tracker_table, "target_ids"):
|
||||
logger.debug("No legacy target_ids column found — skipping migration")
|
||||
return
|
||||
|
||||
# Check if tracker_target table already has data (previous migration ran)
|
||||
tt_count = (
|
||||
await conn.execute(text("SELECT COUNT(*) FROM tracker_target"))
|
||||
).scalar()
|
||||
if tt_count and tt_count > 0:
|
||||
logger.debug(
|
||||
"tracker_target table already has %d rows — skipping migration",
|
||||
tt_count,
|
||||
)
|
||||
return
|
||||
# Check if junction table already has data
|
||||
if await _has_table(conn, tt_table):
|
||||
tt_count = (
|
||||
await conn.execute(text(f"SELECT COUNT(*) FROM {tt_table}"))
|
||||
).scalar()
|
||||
if tt_count and tt_count > 0:
|
||||
logger.debug(
|
||||
"%s table already has %d rows — skipping migration",
|
||||
tt_table, tt_count,
|
||||
)
|
||||
return
|
||||
|
||||
# Load legacy data
|
||||
trackers = (
|
||||
await conn.execute(
|
||||
text(
|
||||
"SELECT id, target_ids, tracking_config_id, "
|
||||
"quiet_hours_start, quiet_hours_end FROM tracker"
|
||||
f"SELECT id, target_ids, tracking_config_id, "
|
||||
f"quiet_hours_start, quiet_hours_end FROM {tracker_table}"
|
||||
)
|
||||
)
|
||||
).fetchall()
|
||||
@@ -154,20 +198,10 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
|
||||
# Load template_config_id from targets (legacy field)
|
||||
target_template_map: dict[int, int | None] = {}
|
||||
target_cols = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
for row in sync_conn.execute(
|
||||
text("PRAGMA table_info('notification_target')")
|
||||
).fetchall()
|
||||
]
|
||||
)
|
||||
if "template_config_id" in target_cols:
|
||||
if await _has_column(conn, "notification_target", "template_config_id"):
|
||||
targets = (
|
||||
await conn.execute(
|
||||
text(
|
||||
"SELECT id, template_config_id FROM notification_target"
|
||||
)
|
||||
text("SELECT id, template_config_id FROM notification_target")
|
||||
)
|
||||
).fetchall()
|
||||
for t in targets:
|
||||
@@ -175,15 +209,7 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
|
||||
# Load commands_config from telegram_bots (legacy field)
|
||||
bot_commands_map: dict[int, str | None] = {}
|
||||
bot_cols = await conn.run_sync(
|
||||
lambda sync_conn: [
|
||||
row[1]
|
||||
for row in sync_conn.execute(
|
||||
text("PRAGMA table_info('telegram_bot')")
|
||||
).fetchall()
|
||||
]
|
||||
)
|
||||
if "commands_config" in bot_cols:
|
||||
if await _has_column(conn, "telegram_bot", "commands_config"):
|
||||
bots = (
|
||||
await conn.execute(
|
||||
text("SELECT id, commands_config FROM telegram_bot")
|
||||
@@ -195,8 +221,6 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
# Build target → bot mapping for commands_config migration
|
||||
target_bot_map: dict[int, int] = {}
|
||||
if bot_commands_map:
|
||||
import json
|
||||
|
||||
tgt_rows = (
|
||||
await conn.execute(
|
||||
text("SELECT id, config FROM notification_target WHERE type='telegram'")
|
||||
@@ -207,35 +231,21 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
cfg = json.loads(tgt[1]) if isinstance(tgt[1], str) else tgt[1]
|
||||
if cfg and "bot_token" in cfg:
|
||||
for bot_id, _ in bot_commands_map.items():
|
||||
bot_row = (
|
||||
bot_token_row = (
|
||||
await conn.execute(
|
||||
text("SELECT id FROM telegram_bot WHERE id=:bid"),
|
||||
text("SELECT token FROM telegram_bot WHERE id=:bid"),
|
||||
{"bid": bot_id},
|
||||
)
|
||||
).fetchone()
|
||||
if bot_row:
|
||||
# Match by checking if this target uses this bot's token
|
||||
bot_token_row = (
|
||||
await conn.execute(
|
||||
text(
|
||||
"SELECT token FROM telegram_bot WHERE id=:bid"
|
||||
),
|
||||
{"bid": bot_id},
|
||||
)
|
||||
).fetchone()
|
||||
if bot_token_row and bot_token_row[0] == cfg.get(
|
||||
"bot_token"
|
||||
):
|
||||
target_bot_map[tgt[0]] = bot_id
|
||||
if bot_token_row and bot_token_row[0] == cfg.get("bot_token"):
|
||||
target_bot_map[tgt[0]] = bot_id
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to match bot token for target %s", tgt[0],
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Create TrackerTarget rows
|
||||
import json
|
||||
|
||||
# Create junction rows
|
||||
migrated = 0
|
||||
for tracker in trackers:
|
||||
tracker_id = tracker[0]
|
||||
@@ -244,7 +254,6 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
quiet_hours_start = tracker[3]
|
||||
quiet_hours_end = tracker[4]
|
||||
|
||||
# Parse target_ids JSON
|
||||
if isinstance(raw_target_ids, str):
|
||||
try:
|
||||
target_ids = json.loads(raw_target_ids)
|
||||
@@ -258,25 +267,22 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
for target_id in target_ids:
|
||||
template_config_id = target_template_map.get(target_id)
|
||||
|
||||
# Get commands_config if this is a telegram target with a known bot
|
||||
commands_config = None
|
||||
if target_id in target_bot_map:
|
||||
bot_id = target_bot_map[target_id]
|
||||
raw_cmd = bot_commands_map.get(bot_id)
|
||||
if raw_cmd:
|
||||
commands_config = (
|
||||
raw_cmd
|
||||
if isinstance(raw_cmd, str)
|
||||
else json.dumps(raw_cmd)
|
||||
raw_cmd if isinstance(raw_cmd, str) else json.dumps(raw_cmd)
|
||||
)
|
||||
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO tracker_target "
|
||||
"(tracker_id, target_id, tracking_config_id, "
|
||||
"template_config_id, enabled, quiet_hours_start, "
|
||||
"quiet_hours_end, commands_config) "
|
||||
"VALUES (:tid, :tgtid, :tcid, :tmplid, 1, :qhs, :qhe, :cmd)"
|
||||
f"INSERT INTO {tt_table} "
|
||||
f"({tracker_id_col}, target_id, tracking_config_id, "
|
||||
f"template_config_id, enabled, quiet_hours_start, "
|
||||
f"quiet_hours_end, commands_config) "
|
||||
f"VALUES (:tid, :tgtid, :tcid, :tmplid, 1, :qhs, :qhe, :cmd)"
|
||||
),
|
||||
{
|
||||
"tid": tracker_id,
|
||||
@@ -291,3 +297,243 @@ async def migrate_tracker_targets(engine: AsyncEngine) -> None:
|
||||
migrated += 1
|
||||
|
||||
logger.info("Migrated %d tracker-target links", migrated)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1: Entity refactor migration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def migrate_entity_refactor(engine: AsyncEngine) -> None:
|
||||
"""Phase 1 entity refactor — rename tables, add columns, create new tables.
|
||||
|
||||
Fully idempotent: every operation checks preconditions before acting.
|
||||
"""
|
||||
async with engine.begin() as conn:
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Rename table: tracker → notification_tracker
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "tracker") and not await _has_table(conn, "notification_tracker"):
|
||||
await conn.execute(text("ALTER TABLE tracker RENAME TO notification_tracker"))
|
||||
logger.info("Renamed table tracker → notification_tracker")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Rename table: tracker_target → notification_tracker_target
|
||||
# and rename column tracker_id → notification_tracker_id
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "tracker_target") and not await _has_table(conn, "notification_tracker_target"):
|
||||
# SQLite doesn't support RENAME COLUMN in older versions, so we
|
||||
# recreate the table with the new column name.
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE notification_tracker_target ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" notification_tracker_id INTEGER REFERENCES notification_tracker(id),"
|
||||
" target_id INTEGER REFERENCES notification_target(id),"
|
||||
" tracking_config_id INTEGER REFERENCES tracking_config(id),"
|
||||
" template_config_id INTEGER REFERENCES template_config(id),"
|
||||
" enabled INTEGER DEFAULT 1,"
|
||||
" quiet_hours_start TEXT,"
|
||||
" quiet_hours_end TEXT,"
|
||||
" commands_config TEXT,"
|
||||
" created_at TIMESTAMP"
|
||||
")"
|
||||
))
|
||||
await conn.execute(text(
|
||||
"INSERT INTO notification_tracker_target "
|
||||
"(id, notification_tracker_id, target_id, tracking_config_id, "
|
||||
"template_config_id, enabled, quiet_hours_start, quiet_hours_end, "
|
||||
"commands_config, created_at) "
|
||||
"SELECT id, tracker_id, target_id, tracking_config_id, "
|
||||
"template_config_id, enabled, quiet_hours_start, quiet_hours_end, "
|
||||
"commands_config, created_at "
|
||||
"FROM tracker_target"
|
||||
))
|
||||
await conn.execute(text("DROP TABLE tracker_target"))
|
||||
logger.info("Renamed table tracker_target → notification_tracker_target (with column rename tracker_id → notification_tracker_id)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Rename table: tracker_state → notification_tracker_state
|
||||
# and rename column tracker_id → notification_tracker_id
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "tracker_state") and not await _has_table(conn, "notification_tracker_state"):
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE notification_tracker_state ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" notification_tracker_id INTEGER REFERENCES notification_tracker(id),"
|
||||
" collection_id TEXT,"
|
||||
" collection_name TEXT DEFAULT '',"
|
||||
" shared INTEGER DEFAULT 0,"
|
||||
" asset_ids TEXT,"
|
||||
" pending_asset_ids TEXT,"
|
||||
" last_updated TIMESTAMP"
|
||||
")"
|
||||
))
|
||||
await conn.execute(text(
|
||||
"INSERT INTO notification_tracker_state "
|
||||
"(id, notification_tracker_id, collection_id, collection_name, "
|
||||
"shared, asset_ids, pending_asset_ids, last_updated) "
|
||||
"SELECT id, tracker_id, collection_id, collection_name, "
|
||||
"shared, asset_ids, pending_asset_ids, last_updated "
|
||||
"FROM tracker_state"
|
||||
))
|
||||
await conn.execute(text("DROP TABLE tracker_state"))
|
||||
logger.info("Renamed table tracker_state → notification_tracker_state (with column rename tracker_id → notification_tracker_id)")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Add chat_action column to notification_target
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "notification_target"):
|
||||
if not await _has_column(conn, "notification_target", "chat_action"):
|
||||
await conn.execute(
|
||||
text("ALTER TABLE notification_target ADD COLUMN chat_action TEXT")
|
||||
)
|
||||
logger.info("Added chat_action column to notification_target table")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Rename tracker_id → notification_tracker_id in event_log
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "event_log"):
|
||||
if await _has_column(conn, "event_log", "tracker_id") and not await _has_column(conn, "event_log", "notification_tracker_id"):
|
||||
# Recreate event_log with renamed column
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE event_log_new ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" notification_tracker_id INTEGER REFERENCES notification_tracker(id),"
|
||||
" tracker_name TEXT DEFAULT '',"
|
||||
" provider_id INTEGER,"
|
||||
" provider_name TEXT DEFAULT '',"
|
||||
" event_type TEXT,"
|
||||
" collection_id TEXT,"
|
||||
" collection_name TEXT,"
|
||||
" assets_count INTEGER DEFAULT 0,"
|
||||
" details TEXT,"
|
||||
" created_at TIMESTAMP"
|
||||
")"
|
||||
))
|
||||
await conn.execute(text(
|
||||
"INSERT INTO event_log_new "
|
||||
"(id, notification_tracker_id, tracker_name, provider_id, "
|
||||
"provider_name, event_type, collection_id, collection_name, "
|
||||
"assets_count, details, created_at) "
|
||||
"SELECT id, tracker_id, tracker_name, provider_id, "
|
||||
"provider_name, event_type, collection_id, collection_name, "
|
||||
"assets_count, details, created_at "
|
||||
"FROM event_log"
|
||||
))
|
||||
await conn.execute(text("DROP TABLE event_log"))
|
||||
await conn.execute(text("ALTER TABLE event_log_new RENAME TO event_log"))
|
||||
logger.info("Renamed column tracker_id → notification_tracker_id in event_log")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Create command_config table
|
||||
# ------------------------------------------------------------------
|
||||
if not await _has_table(conn, "command_config"):
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE command_config ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" user_id INTEGER NOT NULL REFERENCES user(id),"
|
||||
" provider_type TEXT NOT NULL,"
|
||||
" name TEXT NOT NULL,"
|
||||
" icon TEXT DEFAULT '',"
|
||||
" enabled_commands TEXT DEFAULT '[]',"
|
||||
" locale TEXT DEFAULT 'en',"
|
||||
" response_mode TEXT DEFAULT 'media',"
|
||||
" default_count INTEGER DEFAULT 5,"
|
||||
" rate_limits TEXT DEFAULT '{}',"
|
||||
" created_at TIMESTAMP"
|
||||
")"
|
||||
))
|
||||
logger.info("Created command_config table")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 7. Create command_tracker table
|
||||
# ------------------------------------------------------------------
|
||||
if not await _has_table(conn, "command_tracker"):
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE command_tracker ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" user_id INTEGER NOT NULL REFERENCES user(id),"
|
||||
" provider_id INTEGER NOT NULL REFERENCES service_provider(id),"
|
||||
" command_config_id INTEGER NOT NULL REFERENCES command_config(id),"
|
||||
" name TEXT NOT NULL,"
|
||||
" icon TEXT DEFAULT '',"
|
||||
" enabled INTEGER DEFAULT 1,"
|
||||
" created_at TIMESTAMP"
|
||||
")"
|
||||
))
|
||||
logger.info("Created command_tracker table")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 8. Create command_tracker_listener table
|
||||
# ------------------------------------------------------------------
|
||||
if not await _has_table(conn, "command_tracker_listener"):
|
||||
await conn.execute(text(
|
||||
"CREATE TABLE command_tracker_listener ("
|
||||
" id INTEGER PRIMARY KEY,"
|
||||
" command_tracker_id INTEGER NOT NULL REFERENCES command_tracker(id),"
|
||||
" listener_type TEXT NOT NULL,"
|
||||
" listener_id INTEGER NOT NULL,"
|
||||
" created_at TIMESTAMP,"
|
||||
" UNIQUE(command_tracker_id, listener_type, listener_id)"
|
||||
")"
|
||||
))
|
||||
logger.info("Created command_tracker_listener table")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 9. Migrate TelegramBot.commands_config → CommandConfig rows
|
||||
# ------------------------------------------------------------------
|
||||
if await _has_table(conn, "telegram_bot") and await _has_column(conn, "telegram_bot", "commands_config"):
|
||||
# Only migrate if command_config table is empty (idempotent)
|
||||
cc_count = (await conn.execute(text("SELECT COUNT(*) FROM command_config"))).scalar()
|
||||
if cc_count == 0:
|
||||
bots = (await conn.execute(text(
|
||||
"SELECT id, user_id, commands_config FROM telegram_bot"
|
||||
))).fetchall()
|
||||
migrated = 0
|
||||
for bot in bots:
|
||||
bot_id, user_id, raw_config = bot[0], bot[1], bot[2]
|
||||
if not raw_config:
|
||||
continue
|
||||
try:
|
||||
cfg = json.loads(raw_config) if isinstance(raw_config, str) else raw_config
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
continue
|
||||
# Skip empty/default configs
|
||||
if not cfg or cfg == {}:
|
||||
continue
|
||||
|
||||
# Extract fields from legacy commands_config
|
||||
enabled_commands = json.dumps(cfg.get("enabled_commands", []))
|
||||
locale = cfg.get("locale", "en")
|
||||
response_mode = cfg.get("response_mode", "media")
|
||||
default_count = cfg.get("default_count", 5)
|
||||
rate_limits = json.dumps(cfg.get("rate_limits", {}))
|
||||
provider_type = cfg.get("provider_type", "immich")
|
||||
|
||||
await conn.execute(
|
||||
text(
|
||||
"INSERT INTO command_config "
|
||||
"(user_id, provider_type, name, enabled_commands, locale, "
|
||||
"response_mode, default_count, rate_limits, created_at) "
|
||||
"VALUES (:uid, :pt, :name, :ec, :locale, :rm, :dc, :rl, CURRENT_TIMESTAMP)"
|
||||
),
|
||||
{
|
||||
"uid": user_id,
|
||||
"pt": provider_type,
|
||||
"name": f"Bot #{bot_id} Commands",
|
||||
"ec": enabled_commands,
|
||||
"locale": locale,
|
||||
"rm": response_mode,
|
||||
"dc": default_count,
|
||||
"rl": rate_limits,
|
||||
},
|
||||
)
|
||||
migrated += 1
|
||||
|
||||
if migrated:
|
||||
logger.info("Migrated %d bot commands_config → command_config rows", migrated)
|
||||
|
||||
# NOTE: We intentionally do NOT drop commands_config from telegram_bot
|
||||
# or notification_tracker_target. SQLite doesn't support DROP COLUMN in
|
||||
# all versions, and SQLModel will simply ignore columns not defined on
|
||||
# the model class. The columns will remain in the DB but are unused.
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlmodel import JSON, Column, Field, SQLModel
|
||||
|
||||
|
||||
@@ -47,7 +48,8 @@ class TelegramBot(SQLModel, table=True):
|
||||
bot_id: int = Field(default=0)
|
||||
webhook_path_id: str = Field(default_factory=lambda: uuid4().hex)
|
||||
update_mode: str = Field(default="polling") # "polling" or "webhook"
|
||||
commands_config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
# NOTE: commands_config column remains in the DB for backward compat,
|
||||
# but is no longer part of the SQLModel class. Data migrated to CommandConfig.
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@@ -162,13 +164,14 @@ class NotificationTarget(SQLModel, table=True):
|
||||
name: str
|
||||
icon: str = Field(default="")
|
||||
config: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
chat_action: str | None = Field(default=None) # e.g. "typing", "upload_photo"
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class Tracker(SQLModel, table=True):
|
||||
class NotificationTracker(SQLModel, table=True):
|
||||
"""Watches a provider's collections for changes."""
|
||||
|
||||
__tablename__ = "tracker"
|
||||
__tablename__ = "notification_tracker"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
@@ -182,13 +185,18 @@ class Tracker(SQLModel, table=True):
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class TrackerTarget(SQLModel, table=True):
|
||||
"""Junction between Tracker and NotificationTarget with per-link config."""
|
||||
class NotificationTrackerTarget(SQLModel, table=True):
|
||||
"""Junction between NotificationTracker and NotificationTarget with per-link config."""
|
||||
|
||||
__tablename__ = "tracker_target"
|
||||
__tablename__ = "notification_tracker_target"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
tracker_id: int = Field(foreign_key="tracker.id", index=True)
|
||||
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
|
||||
tracker_id: int = Field(
|
||||
foreign_key="notification_tracker.id",
|
||||
index=True,
|
||||
sa_column_kwargs={"name": "notification_tracker_id"},
|
||||
)
|
||||
target_id: int = Field(foreign_key="notification_target.id", index=True)
|
||||
tracking_config_id: int | None = Field(
|
||||
default=None, foreign_key="tracking_config.id"
|
||||
@@ -199,19 +207,22 @@ class TrackerTarget(SQLModel, table=True):
|
||||
enabled: bool = Field(default=True)
|
||||
quiet_hours_start: str | None = None
|
||||
quiet_hours_end: str | None = None
|
||||
commands_config: dict[str, Any] | None = Field(
|
||||
default=None, sa_column=Column(JSON)
|
||||
)
|
||||
# NOTE: commands_config column remains in the DB for backward compat,
|
||||
# but is no longer part of the SQLModel class. Data migrated to CommandConfig.
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class TrackerState(SQLModel, table=True):
|
||||
class NotificationTrackerState(SQLModel, table=True):
|
||||
"""Persisted state for change detection."""
|
||||
|
||||
__tablename__ = "tracker_state"
|
||||
__tablename__ = "notification_tracker_state"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
tracker_id: int = Field(foreign_key="tracker.id")
|
||||
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
|
||||
tracker_id: int = Field(
|
||||
foreign_key="notification_tracker.id",
|
||||
sa_column_kwargs={"name": "notification_tracker_id"},
|
||||
)
|
||||
collection_id: str
|
||||
collection_name: str = Field(default="")
|
||||
shared: bool = Field(default=False)
|
||||
@@ -220,13 +231,70 @@ class TrackerState(SQLModel, table=True):
|
||||
last_updated: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Configuration for bot commands (e.g., which commands are enabled, rate limits)."""
|
||||
|
||||
__tablename__ = "command_config"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
provider_type: str
|
||||
name: str
|
||||
icon: str = Field(default="")
|
||||
enabled_commands: list[str] = Field(default_factory=list, sa_column=Column(JSON))
|
||||
locale: str = Field(default="en")
|
||||
response_mode: str = Field(default="media") # "media" or "text"
|
||||
default_count: int = Field(default=5)
|
||||
rate_limits: dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class CommandTracker(SQLModel, table=True):
|
||||
"""Links a provider to a command config for interactive bot commands."""
|
||||
|
||||
__tablename__ = "command_tracker"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
provider_id: int = Field(foreign_key="service_provider.id")
|
||||
command_config_id: int = Field(foreign_key="command_config.id")
|
||||
name: str
|
||||
icon: str = Field(default="")
|
||||
enabled: bool = Field(default=True)
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class CommandTrackerListener(SQLModel, table=True):
|
||||
"""Links a CommandTracker to a listener (e.g., a telegram bot chat)."""
|
||||
|
||||
__tablename__ = "command_tracker_listener"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"command_tracker_id", "listener_type", "listener_id",
|
||||
name="uq_command_tracker_listener",
|
||||
),
|
||||
)
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
command_tracker_id: int = Field(foreign_key="command_tracker.id")
|
||||
listener_type: str # e.g. "telegram_bot"
|
||||
listener_id: int
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class EventLog(SQLModel, table=True):
|
||||
"""Log of detected events."""
|
||||
|
||||
__tablename__ = "event_log"
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
tracker_id: int | None = Field(default=None, foreign_key="tracker.id", index=True)
|
||||
# Python attr stays as tracker_id for backward compat; DB column is notification_tracker_id
|
||||
tracker_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key="notification_tracker.id",
|
||||
index=True,
|
||||
sa_column_kwargs={"name": "notification_tracker_id"},
|
||||
)
|
||||
tracker_name: str = Field(default="")
|
||||
provider_id: int | None = Field(default=None, index=True)
|
||||
provider_name: str = Field(default="")
|
||||
|
||||
@@ -15,8 +15,8 @@ from .database.models import * # noqa: F401,F403 — ensure all models register
|
||||
|
||||
from .auth.routes import router as auth_router
|
||||
from .api.providers import router as providers_router
|
||||
from .api.trackers import router as trackers_router
|
||||
from .api.tracker_targets import router as tracker_targets_router
|
||||
from .api.notification_trackers import router as notification_trackers_router
|
||||
from .api.notification_tracker_targets import router as notification_tracker_targets_router
|
||||
from .api.tracking_configs import router as tracking_configs_router
|
||||
from .api.template_configs import router as template_configs_router
|
||||
from .api.targets import router as targets_router
|
||||
@@ -25,6 +25,8 @@ from .api.users import router as users_router
|
||||
from .api.status import router as status_router
|
||||
from .api.template_vars import router as template_vars_router
|
||||
from .api.app_settings import router as app_settings_router
|
||||
from .api.command_configs import router as command_configs_router
|
||||
from .api.command_trackers import router as command_trackers_router
|
||||
from .commands.webhook import router as webhook_router, set_webhook_secret
|
||||
|
||||
|
||||
@@ -33,10 +35,11 @@ async def lifespan(app: FastAPI):
|
||||
await init_db()
|
||||
# Run data migrations (idempotent)
|
||||
from .database.engine import get_engine
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets
|
||||
from .database.migrations import migrate_schema, migrate_tracker_targets, migrate_entity_refactor
|
||||
engine = get_engine()
|
||||
await migrate_schema(engine)
|
||||
await migrate_tracker_targets(engine)
|
||||
await migrate_entity_refactor(engine)
|
||||
await _seed_default_templates()
|
||||
# Configure webhook secret from DB setting (falls back to env var)
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession as _AS
|
||||
@@ -55,8 +58,8 @@ app = FastAPI(title="Notify Bridge", version="0.1.0", lifespan=lifespan)
|
||||
app.include_router(auth_router)
|
||||
app.include_router(template_vars_router)
|
||||
app.include_router(providers_router)
|
||||
app.include_router(trackers_router)
|
||||
app.include_router(tracker_targets_router)
|
||||
app.include_router(notification_trackers_router)
|
||||
app.include_router(notification_tracker_targets_router)
|
||||
app.include_router(tracking_configs_router)
|
||||
app.include_router(template_configs_router)
|
||||
app.include_router(targets_router)
|
||||
@@ -64,6 +67,8 @@ app.include_router(telegram_bots_router)
|
||||
app.include_router(users_router)
|
||||
app.include_router(status_router)
|
||||
app.include_router(app_settings_router)
|
||||
app.include_router(command_configs_router)
|
||||
app.include_router(command_trackers_router)
|
||||
app.include_router(webhook_router)
|
||||
|
||||
|
||||
|
||||
@@ -26,9 +26,9 @@ async def start_scheduler() -> None:
|
||||
|
||||
await _load_tracker_jobs()
|
||||
|
||||
# Start Telegram bot polling for bots in polling mode
|
||||
from .telegram_poller import start_bot_polling
|
||||
await start_bot_polling()
|
||||
# Start Telegram bot polling for bots with active command listeners
|
||||
from .telegram_poller import start_command_listener_polling
|
||||
await start_command_listener_polling()
|
||||
|
||||
|
||||
async def _load_tracker_jobs() -> None:
|
||||
@@ -36,13 +36,13 @@ async def _load_tracker_jobs() -> None:
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import Tracker
|
||||
from ..database.models import NotificationTracker
|
||||
|
||||
engine = get_engine()
|
||||
scheduler = get_scheduler()
|
||||
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(select(Tracker).where(Tracker.enabled == True))
|
||||
result = await session.exec(select(NotificationTracker).where(NotificationTracker.enabled == True))
|
||||
trackers = result.all()
|
||||
|
||||
for tracker in trackers:
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
Uses APScheduler to run getUpdates periodically for each bot
|
||||
with update_mode == "polling". Processes updates identically
|
||||
to the webhook handler (auto-save chat, dispatch commands).
|
||||
|
||||
Ref-counted: only starts/stops polling for bots that have active
|
||||
CommandTrackerListeners with enabled CommandTrackers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -17,7 +20,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from notify_bridge_core.notifications.telegram.media import TELEGRAM_API_BASE_URL
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import TelegramBot
|
||||
from ..database.models import CommandTracker, CommandTrackerListener, TelegramBot
|
||||
from ..services.telegram import save_chat_from_webhook
|
||||
from .scheduler import get_scheduler
|
||||
|
||||
@@ -27,18 +30,82 @@ _LOGGER = logging.getLogger(__name__)
|
||||
_last_update_id: dict[int, int] = {}
|
||||
|
||||
|
||||
async def start_bot_polling() -> None:
|
||||
"""Schedule polling jobs for all bots with update_mode == 'polling'."""
|
||||
async def _get_bot_ids_with_active_listeners() -> set[int]:
|
||||
"""Return bot IDs that have at least one active command tracker listener.
|
||||
|
||||
A bot is "active" if there is a CommandTrackerListener with
|
||||
listener_type="telegram_bot" pointing to it, AND the associated
|
||||
CommandTracker is enabled.
|
||||
"""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(TelegramBot).where(TelegramBot.update_mode == "polling")
|
||||
select(CommandTrackerListener).where(
|
||||
CommandTrackerListener.listener_type == "telegram_bot"
|
||||
)
|
||||
)
|
||||
listeners = result.all()
|
||||
|
||||
active_bot_ids: set[int] = set()
|
||||
for listener in listeners:
|
||||
tracker = await session.get(CommandTracker, listener.command_tracker_id)
|
||||
if tracker and tracker.enabled:
|
||||
active_bot_ids.add(listener.listener_id)
|
||||
|
||||
return active_bot_ids
|
||||
|
||||
|
||||
async def start_command_listener_polling() -> None:
|
||||
"""Schedule polling jobs only for bots with active command tracker listeners."""
|
||||
active_bot_ids = await _get_bot_ids_with_active_listeners()
|
||||
if not active_bot_ids:
|
||||
_LOGGER.info("No bots with active command listeners to poll")
|
||||
return
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(TelegramBot).where(
|
||||
TelegramBot.update_mode == "polling",
|
||||
TelegramBot.id.in_(active_bot_ids),
|
||||
)
|
||||
)
|
||||
bots = result.all()
|
||||
|
||||
for bot in bots:
|
||||
schedule_bot_polling(bot.id)
|
||||
|
||||
_LOGGER.info("Started command listener polling for %d bot(s)", len(bots))
|
||||
|
||||
|
||||
async def start_bot_polling() -> None:
|
||||
"""Schedule polling jobs for all bots with update_mode == 'polling'.
|
||||
|
||||
Deprecated: prefer start_command_listener_polling() which only starts
|
||||
bots with active command tracker listeners.
|
||||
"""
|
||||
await start_command_listener_polling()
|
||||
|
||||
|
||||
async def start_bot_if_needed(bot_id: int) -> None:
|
||||
"""Start polling for a bot if it has active listeners and is not already running."""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
bot = await session.get(TelegramBot, bot_id)
|
||||
if not bot or bot.update_mode != "polling":
|
||||
return
|
||||
|
||||
active_bot_ids = await _get_bot_ids_with_active_listeners()
|
||||
if bot_id in active_bot_ids:
|
||||
schedule_bot_polling(bot_id)
|
||||
|
||||
|
||||
async def stop_bot_if_unused(bot_id: int) -> None:
|
||||
"""Stop polling for a bot if it has no enabled command tracker listeners."""
|
||||
active_bot_ids = await _get_bot_ids_with_active_listeners()
|
||||
if bot_id not in active_bot_ids:
|
||||
unschedule_bot_polling(bot_id)
|
||||
|
||||
|
||||
def schedule_bot_polling(bot_id: int) -> None:
|
||||
"""Add a polling job for a bot (idempotent)."""
|
||||
@@ -70,76 +137,82 @@ def unschedule_bot_polling(bot_id: int) -> None:
|
||||
async def _poll_bot(bot_id: int) -> None:
|
||||
"""Fetch updates from Telegram and process them."""
|
||||
engine = get_engine()
|
||||
|
||||
# Eagerly load bot data and close session before aiohttp work
|
||||
# (cannot nest aiohttp inside active SQLAlchemy async session)
|
||||
async with AsyncSession(engine) as session:
|
||||
bot = await session.get(TelegramBot, bot_id)
|
||||
if not bot or bot.update_mode != "polling":
|
||||
unschedule_bot_polling(bot_id)
|
||||
return
|
||||
# Extract what we need before closing session
|
||||
bot_token = bot.token
|
||||
bot_obj = bot
|
||||
|
||||
offset = _last_update_id.get(bot_id, 0)
|
||||
params: dict[str, Any] = {
|
||||
"timeout": 0,
|
||||
"limit": 50,
|
||||
"allowed_updates": '["message"]',
|
||||
}
|
||||
if offset:
|
||||
params["offset"] = offset + 1
|
||||
offset = _last_update_id.get(bot_id, 0)
|
||||
params: dict[str, Any] = {
|
||||
"timeout": 0,
|
||||
"limit": 50,
|
||||
"allowed_updates": '["message"]',
|
||||
}
|
||||
if offset:
|
||||
params["offset"] = offset + 1
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
async with http.get(
|
||||
f"{TELEGRAM_API_BASE_URL}{bot_token}/getUpdates",
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
if not data.get("ok"):
|
||||
return
|
||||
updates = data.get("result", [])
|
||||
except Exception as e:
|
||||
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
|
||||
return
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
# Update offset to latest
|
||||
_last_update_id[bot_id] = updates[-1]["update_id"]
|
||||
|
||||
# Process each update
|
||||
from ..commands.handler import handle_command, send_media_group
|
||||
|
||||
for update in updates:
|
||||
message = update.get("message")
|
||||
if not message:
|
||||
continue
|
||||
|
||||
chat_info = message.get("chat", {})
|
||||
chat_id = str(chat_info.get("id", ""))
|
||||
text = message.get("text", "")
|
||||
|
||||
if not chat_id:
|
||||
continue
|
||||
|
||||
# Auto-persist chat (fresh session per save)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as http:
|
||||
async with http.get(
|
||||
f"{TELEGRAM_API_BASE_URL}{bot.token}/getUpdates",
|
||||
params=params,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
if not data.get("ok"):
|
||||
return
|
||||
updates = data.get("result", [])
|
||||
except Exception as e:
|
||||
_LOGGER.debug("Polling error for bot %d: %s", bot_id, e)
|
||||
return
|
||||
async with AsyncSession(engine) as save_session:
|
||||
await save_chat_from_webhook(save_session, bot_obj.id, chat_info)
|
||||
await save_session.commit()
|
||||
except Exception:
|
||||
_LOGGER.debug("Failed to auto-save chat %s", chat_id, exc_info=True)
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
# Update offset to latest
|
||||
_last_update_id[bot_id] = updates[-1]["update_id"]
|
||||
|
||||
# Process each update
|
||||
from ..commands.handler import handle_command, send_media_group
|
||||
|
||||
for update in updates:
|
||||
message = update.get("message")
|
||||
if not message:
|
||||
continue
|
||||
|
||||
chat_info = message.get("chat", {})
|
||||
chat_id = str(chat_info.get("id", ""))
|
||||
text = message.get("text", "")
|
||||
|
||||
if not chat_id:
|
||||
continue
|
||||
|
||||
# Auto-persist chat
|
||||
# Dispatch commands
|
||||
if text and text.startswith("/"):
|
||||
try:
|
||||
async with AsyncSession(engine) as save_session:
|
||||
await save_chat_from_webhook(save_session, bot.id, chat_info)
|
||||
await save_session.commit()
|
||||
cmd_response = await handle_command(bot_obj, chat_id, text)
|
||||
if cmd_response is not None:
|
||||
if isinstance(cmd_response, list):
|
||||
await send_media_group(bot_token, chat_id, cmd_response)
|
||||
else:
|
||||
await _send_reply(bot_token, chat_id, cmd_response)
|
||||
except Exception:
|
||||
_LOGGER.debug("Failed to auto-save chat %s", chat_id, exc_info=True)
|
||||
|
||||
# Dispatch commands
|
||||
if text and text.startswith("/"):
|
||||
try:
|
||||
cmd_response = await handle_command(bot, chat_id, text)
|
||||
if cmd_response is not None:
|
||||
if isinstance(cmd_response, list):
|
||||
await send_media_group(bot.token, chat_id, cmd_response)
|
||||
else:
|
||||
await _send_reply(bot.token, chat_id, cmd_response)
|
||||
except Exception:
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
_LOGGER.error("Error handling command from bot %d", bot_id, exc_info=True)
|
||||
|
||||
|
||||
async def _send_reply(bot_token: str, chat_id: str, text: str) -> None:
|
||||
|
||||
@@ -19,11 +19,11 @@ from ..database.engine import get_engine
|
||||
from ..database.models import (
|
||||
EventLog,
|
||||
NotificationTarget,
|
||||
NotificationTracker,
|
||||
NotificationTrackerState,
|
||||
NotificationTrackerTarget,
|
||||
ServiceProvider,
|
||||
TemplateConfig,
|
||||
Tracker,
|
||||
TrackerState,
|
||||
TrackerTarget,
|
||||
TrackingConfig,
|
||||
)
|
||||
|
||||
@@ -89,7 +89,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
# Load all DB data eagerly before entering aiohttp context
|
||||
async with AsyncSession(engine) as session:
|
||||
tracker = await session.get(Tracker, tracker_id)
|
||||
tracker = await session.get(NotificationTracker, tracker_id)
|
||||
if not tracker or not tracker.enabled:
|
||||
return {"status": "skipped", "reason": "disabled or not found"}
|
||||
|
||||
@@ -99,7 +99,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
# Load tracker state
|
||||
result = await session.exec(
|
||||
select(TrackerState).where(TrackerState.tracker_id == tracker_id)
|
||||
select(NotificationTrackerState).where(NotificationTrackerState.tracker_id == tracker_id)
|
||||
)
|
||||
states = result.all()
|
||||
state_dict: dict[str, Any] = {}
|
||||
@@ -113,7 +113,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
|
||||
# Load tracker-target links (replaces old target_ids JSON array)
|
||||
tt_result = await session.exec(
|
||||
select(TrackerTarget).where(TrackerTarget.tracker_id == tracker_id)
|
||||
select(NotificationTrackerTarget).where(NotificationTrackerTarget.tracker_id == tracker_id)
|
||||
)
|
||||
tracker_targets = tt_result.all()
|
||||
|
||||
@@ -188,7 +188,7 @@ async def check_tracker(tracker_id: int) -> dict[str, Any]:
|
||||
existing.shared = cstate.get("shared", False)
|
||||
session.add(existing)
|
||||
else:
|
||||
new_ts = TrackerState(
|
||||
new_ts = NotificationTrackerState(
|
||||
tracker_id=tracker_id,
|
||||
collection_id=cid,
|
||||
collection_name=cstate.get("name", ""),
|
||||
|
||||
Reference in New Issue
Block a user