feat: Home Assistant provider — WebSocket subscription + bot commands
Adds Home Assistant as a service provider with two coordinated surfaces: Notifications (subscription): - Long-lived WebSocket client (aiohttp ws_connect) with auth handshake, exponential-backoff reconnect, bounded event queue, and area-registry enrichment cached per (re)connect - ServiceProvider ABC gains an optional `subscribe()` method for push-style providers; HomeAssistantServiceProvider uses it via a per-provider supervisor task started in the FastAPI lifespan - 4 event types (state_changed, automation_triggered, call_service, event_fired), 4 default Jinja templates (en + ru), HA-specific tracker filters (entity_glob, domain_allowlist, exact entity ids) - Extracted shared dispatch pipeline (api/webhooks.py → services/ event_dispatch.py) so subscription and webhook ingest share the same event_log + deferred-dispatch + quiet-hours code path Bot commands: - /status, /entities [glob], /state <entity_id>, /areas - Multi-command WS session so /status and /areas cost one handshake - Sensitive-attribute blocklist (camera access_token, entity_picture, etc.) and 30-attribute cap to keep /state output safe and within Telegram's message size - Error-message redaction strips URL userinfo before surfacing to chat Frontend: - HA descriptor with toggle ConfigField type (new) and tag-input filter mode for free-text glob/domain lists (new TagInput component) - 15 command slots + 4 notification slots wired into the existing template-config UI
This commit is contained in:
@@ -565,6 +565,63 @@ async def preview_raw(
|
||||
"count": 2,
|
||||
# /rate_limited
|
||||
"wait": 15,
|
||||
# --- Home Assistant: /status, /entities, /state, /areas ---
|
||||
"ok": True,
|
||||
"message": "OK",
|
||||
"provider_name": "Home Assistant",
|
||||
"url": "http://homeassistant.local:8123",
|
||||
"entity_count": 142,
|
||||
"area_count": 8,
|
||||
"entities": [
|
||||
{
|
||||
"entity_id": "binary_sensor.front_door",
|
||||
"friendly_name": "Front Door",
|
||||
"domain": "binary_sensor",
|
||||
"state": "off",
|
||||
"attributes": {"device_class": "door", "friendly_name": "Front Door"},
|
||||
"device_class": "door",
|
||||
"unit_of_measurement": None,
|
||||
"last_changed": "2026-05-13T12:34:56.789+00:00",
|
||||
"last_updated": "2026-05-13T12:34:56.789+00:00",
|
||||
},
|
||||
{
|
||||
"entity_id": "sensor.kitchen_temperature",
|
||||
"friendly_name": "Kitchen Temperature",
|
||||
"domain": "sensor",
|
||||
"state": "21.4",
|
||||
"attributes": {"unit_of_measurement": "°C", "friendly_name": "Kitchen Temperature"},
|
||||
"device_class": "temperature",
|
||||
"unit_of_measurement": "°C",
|
||||
"last_changed": "2026-05-13T12:30:00+00:00",
|
||||
"last_updated": "2026-05-13T12:30:00+00:00",
|
||||
},
|
||||
],
|
||||
"glob": "binary_sensor.*",
|
||||
"total": 12,
|
||||
"shown": 2,
|
||||
# /state — single entity drill-down. ``found`` controls which branch
|
||||
# of the template renders.
|
||||
"found": True,
|
||||
"entity_id": "light.kitchen",
|
||||
"friendly_name": "Kitchen Light",
|
||||
"domain": "light",
|
||||
"state": "on",
|
||||
"attributes": {
|
||||
"brightness": 200,
|
||||
"color_mode": "brightness",
|
||||
},
|
||||
"hidden_attr_count": 0,
|
||||
"device_class": None,
|
||||
"unit_of_measurement": None,
|
||||
"last_changed": "2026-05-13T12:34:56.789+00:00",
|
||||
"last_updated": "2026-05-13T12:34:56.789+00:00",
|
||||
"reason": "",
|
||||
"error": "",
|
||||
# /areas
|
||||
"areas": [
|
||||
{"area_id": "kitchen", "name": "Kitchen", "entity_count": 14},
|
||||
{"area_id": "entrance", "name": "Entrance", "entity_count": 4},
|
||||
],
|
||||
}
|
||||
|
||||
return render_template_preview(body.template, sample_ctx)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from typing import Any
|
||||
@@ -103,6 +103,54 @@ class WebhookProviderConfig(BaseModel):
|
||||
max_stored_payloads: int = 20 # 1-100
|
||||
|
||||
|
||||
class HomeAssistantProviderConfig(BaseModel):
|
||||
url: str
|
||||
access_token: str
|
||||
verify_tls: bool = True
|
||||
event_types: list[str] | None = None
|
||||
|
||||
@field_validator("url")
|
||||
@classmethod
|
||||
def _validate_url(cls, raw: str) -> str:
|
||||
"""Reject malformed URLs early so the user sees a clear error.
|
||||
|
||||
``AnyHttpUrl`` accepts the homelab-friendly forms
|
||||
(``http://homeassistant.local:8123``) while rejecting garbage like
|
||||
``not-a-url`` or ``ftp://...``. Validation is best-effort; we still
|
||||
re-derive the WebSocket URL at runtime.
|
||||
"""
|
||||
try:
|
||||
AnyHttpUrl(raw)
|
||||
except ValueError as err:
|
||||
raise ValueError(f"url must be a valid http(s) URL: {err}") from err
|
||||
return raw
|
||||
|
||||
@field_validator("event_types")
|
||||
@classmethod
|
||||
def _validate_event_types(cls, raw: list[str] | None) -> list[str] | None:
|
||||
"""Cap list size and per-entry length; reject obvious junk.
|
||||
|
||||
We don't whitelist event names — HA has unbounded custom event types
|
||||
from third-party integrations. Length and count caps are enough to
|
||||
keep a misconfiguration from blowing up the subscription handshake.
|
||||
"""
|
||||
if raw is None:
|
||||
return None
|
||||
if len(raw) > 50:
|
||||
raise ValueError("event_types accepts at most 50 entries")
|
||||
cleaned: list[str] = []
|
||||
for entry in raw:
|
||||
if not isinstance(entry, str):
|
||||
raise ValueError("event_types entries must be strings")
|
||||
entry = entry.strip()
|
||||
if not entry:
|
||||
continue
|
||||
if len(entry) > 100:
|
||||
raise ValueError("event_types entries must be <=100 chars")
|
||||
cleaned.append(entry)
|
||||
return cleaned or None
|
||||
|
||||
|
||||
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"immich": ImmichProviderConfig,
|
||||
"gitea": GiteaProviderConfig,
|
||||
@@ -111,6 +159,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
|
||||
"nut": NutProviderConfig,
|
||||
"google_photos": GooglePhotosProviderConfig,
|
||||
"webhook": WebhookProviderConfig,
|
||||
"home_assistant": HomeAssistantProviderConfig,
|
||||
}
|
||||
|
||||
|
||||
@@ -160,6 +209,18 @@ async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]
|
||||
gp = make_google_photos_provider(http_session, provider)
|
||||
return await gp.test_connection()
|
||||
|
||||
if provider.type == "home_assistant":
|
||||
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
|
||||
ha = HomeAssistantServiceProvider(
|
||||
session=http_session,
|
||||
url=provider.config.get("url", ""),
|
||||
access_token=provider.config.get("access_token", ""),
|
||||
verify_tls=bool(provider.config.get("verify_tls", True)),
|
||||
event_types=provider.config.get("event_types") or None,
|
||||
name=provider.name,
|
||||
)
|
||||
return await ha.test_connection()
|
||||
|
||||
if provider.type in ("scheduler", "webhook"):
|
||||
return {"ok": True, "message": "Virtual provider — always available"}
|
||||
|
||||
|
||||
@@ -285,6 +285,8 @@ async def get_template_variables(
|
||||
**_planka_variables(),
|
||||
# --- NUT (UPS) slots ---
|
||||
**_nut_variables(),
|
||||
# --- Home Assistant slots ---
|
||||
**_home_assistant_variables(),
|
||||
# --- Scheduler slots ---
|
||||
"message_scheduled_message": {
|
||||
"description": "Notification for scheduled message events",
|
||||
@@ -433,6 +435,58 @@ def _nut_variables() -> dict:
|
||||
}
|
||||
|
||||
|
||||
def _home_assistant_variables() -> dict:
|
||||
common = {
|
||||
"entity_id": "HA entity id (e.g. light.kitchen)",
|
||||
"friendly_name": "Human-readable entity name from attributes.friendly_name",
|
||||
"domain": "HA domain prefix (light, sensor, binary_sensor, ...)",
|
||||
"attributes": "Full attributes dict of the new state",
|
||||
"device_class": "Device class (motion, door, temperature, ...)",
|
||||
"unit_of_measurement": "Unit suffix for numeric sensors",
|
||||
"area": "Area name from the HA area registry (empty when not assigned)",
|
||||
"ha_event_type": "Raw HA event_type (state_changed, automation_triggered, ...)",
|
||||
"last_changed": "ISO timestamp of last state change",
|
||||
"last_updated": "ISO timestamp of last attribute or state update",
|
||||
}
|
||||
return {
|
||||
"message_ha_state_changed": {
|
||||
"description": "Entity state changed",
|
||||
"variables": {
|
||||
**common,
|
||||
"old_state": "Previous state string",
|
||||
"new_state": "New state string ('removed' if entity deleted)",
|
||||
},
|
||||
},
|
||||
"message_ha_automation_triggered": {
|
||||
"description": "Automation triggered",
|
||||
"variables": {
|
||||
"entity_id": common["entity_id"],
|
||||
"automation_name": "Automation name",
|
||||
"trigger_source": "Why the automation fired",
|
||||
"ha_event_type": common["ha_event_type"],
|
||||
},
|
||||
},
|
||||
"message_ha_service_called": {
|
||||
"description": "HA service called",
|
||||
"variables": {
|
||||
"service_called": "Qualified service name (e.g. light.turn_on)",
|
||||
"service_domain": "Service domain",
|
||||
"service_name": "Service name within domain",
|
||||
"service_data": "Service payload dict",
|
||||
"target_entity": "entity_id targeted by the call (comma-joined for multi-target)",
|
||||
"ha_event_type": common["ha_event_type"],
|
||||
},
|
||||
},
|
||||
"message_ha_event_fired": {
|
||||
"description": "Other HA event fired (catch-all)",
|
||||
"variables": {
|
||||
"ha_event_type": common["ha_event_type"],
|
||||
"event_data": "Raw event data dict (use {{ event_data | tojson }} to render)",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
async def create_config(
|
||||
body: TemplateConfigCreate,
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher, TargetConfig
|
||||
from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook
|
||||
from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook
|
||||
from notify_bridge_core.providers.webhook.event_parser import parse_webhook as parse_generic_webhook
|
||||
@@ -27,13 +26,7 @@ from ..database.models import (
|
||||
ServiceProvider,
|
||||
WebhookPayloadLog,
|
||||
)
|
||||
from ..services.dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
from ..services.event_dispatch import dispatch_provider_event
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -131,7 +124,7 @@ def _passes_filters(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared dispatch helper
|
||||
# Shared dispatch helper (legacy wrapper — body moved to services/event_dispatch.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _dispatch_webhook_event(
|
||||
@@ -142,185 +135,16 @@ async def _dispatch_webhook_event(
|
||||
event: ServiceEvent,
|
||||
detail_keys: tuple[str, ...],
|
||||
) -> int:
|
||||
"""Load trackers, filter, create EventLogs, dispatch notifications, and commit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine:
|
||||
SQLAlchemy async engine.
|
||||
provider_id:
|
||||
ID of the ServiceProvider that received the webhook.
|
||||
provider_name:
|
||||
Human-readable name of the provider (for logging).
|
||||
provider_config:
|
||||
The provider's ``config`` dict (passed through to target config builder).
|
||||
event:
|
||||
Parsed :class:`ServiceEvent` to dispatch.
|
||||
detail_keys:
|
||||
Keys from ``event.extra`` to include in the EventLog ``details`` dict.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of successfully dispatched notifications.
|
||||
"""
|
||||
dispatched = 0
|
||||
# ``defers_to_schedule`` is collected during the loop and flushed AFTER the
|
||||
# main session commits — the only side-effect of failing to schedule is a
|
||||
# delayed delivery (the startup loader / catch-up scan will reschedule),
|
||||
# so this is best-effort and must not roll back the DB writes.
|
||||
defers_to_schedule: set[Any] = set()
|
||||
async with AsyncSession(engine) as session:
|
||||
# App timezone is identical across trackers within one webhook request;
|
||||
# pull it once.
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
tracker_result = await session.exec(
|
||||
select(NotificationTracker).where(
|
||||
NotificationTracker.provider_id == provider_id,
|
||||
NotificationTracker.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
trackers = tracker_result.all()
|
||||
|
||||
from ..services.deferred_dispatch import defer_event, is_deferrable
|
||||
|
||||
for tracker in trackers:
|
||||
filters = tracker.filters or {}
|
||||
if not _passes_filters(event, filters):
|
||||
_LOGGER.debug(
|
||||
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
|
||||
)
|
||||
continue
|
||||
|
||||
link_data = await load_link_data(session, tracker.id)
|
||||
if not link_data:
|
||||
continue
|
||||
|
||||
# Log event
|
||||
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
|
||||
event_log_row = EventLog(
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
tracker_name=tracker.name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
event_type=event.event_type.value,
|
||||
collection_id=event.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=0,
|
||||
details={
|
||||
"provider_type": event.provider_type.value,
|
||||
**extra_details,
|
||||
},
|
||||
)
|
||||
session.add(event_log_row)
|
||||
await session.flush()
|
||||
event_log_id = event_log_row.id
|
||||
|
||||
# Dedupe defers by parent ``link_id``: broadcast links emit one
|
||||
# ``link_data`` entry per child, all sharing the same parent id —
|
||||
# the deferred row is one-per-link, so we only call ``defer_event``
|
||||
# once per distinct id (earliest fire_at wins on ties).
|
||||
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
|
||||
defers_for_event: dict[int, Any] = {}
|
||||
for ld in link_data:
|
||||
tc = ld["tracking_config"]
|
||||
if tc is not None:
|
||||
outcome = evaluate_event_gate(event, tc, app_tz)
|
||||
if outcome.reason is GateReason.QUIET_HOURS:
|
||||
if is_deferrable(event.event_type.value) and outcome.quiet_hours_end_at is not None:
|
||||
link_id = ld.get("link_id")
|
||||
if link_id is not None:
|
||||
prior = defers_for_event.get(link_id)
|
||||
if prior is None or outcome.quiet_hours_end_at < prior:
|
||||
defers_for_event[link_id] = outcome.quiet_hours_end_at
|
||||
continue
|
||||
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
|
||||
continue
|
||||
|
||||
tmpl = ld["template_config"]
|
||||
target_cfg = TargetConfig(
|
||||
type=ld["target_type"],
|
||||
config=ld["target_config"],
|
||||
template_slots=ld["template_slots"],
|
||||
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
|
||||
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
|
||||
provider_api_key=provider_config.get("api_token"),
|
||||
provider_internal_url=provider_config.get("url", ""),
|
||||
provider_external_url=provider_config.get("url", ""),
|
||||
receivers=ld["receivers"],
|
||||
)
|
||||
key = id(tc) if tc is not None else 0
|
||||
if key not in groups:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
|
||||
# Persist defers + stamp event_log dispatch_status in the same
|
||||
# session that holds the EventLog row, so the "deferred" badge
|
||||
# only appears if the underlying queue rows actually exist.
|
||||
if defers_for_event:
|
||||
earliest = min(defers_for_event.values())
|
||||
for link_id, fire_at in defers_for_event.items():
|
||||
await defer_event(
|
||||
session,
|
||||
event=event,
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
fire_at=fire_at,
|
||||
)
|
||||
details = dict(event_log_row.details or {})
|
||||
if not details.get("dispatch_status"):
|
||||
details["dispatch_status"] = "deferred"
|
||||
details["deferred_until"] = earliest.isoformat()
|
||||
event_log_row.details = details
|
||||
session.add(event_log_row)
|
||||
defers_to_schedule.update(defers_for_event.values())
|
||||
|
||||
# Dispatch to targets. Isolate dispatcher exceptions per group so
|
||||
# a failed remote call doesn't bubble out, abort the surrounding
|
||||
# transaction, and roll back the just-written defers/event_log.
|
||||
from ..services.http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
for tc, target_configs in groups.values():
|
||||
if not target_configs:
|
||||
continue
|
||||
shaped_event = apply_tracking_display_filters(event, tc)
|
||||
if shaped_event is None:
|
||||
continue
|
||||
try:
|
||||
results = await dispatcher.dispatch(shaped_event, target_configs)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Dispatcher raised for tracker %d: %s", tracker.id, err,
|
||||
)
|
||||
continue
|
||||
for r in results:
|
||||
if r.get("success"):
|
||||
dispatched += 1
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Notification failed for tracker %d: %s",
|
||||
tracker.id, r.get("error", "unknown"),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
|
||||
# can't roll back the persisted defer rows.
|
||||
if defers_to_schedule:
|
||||
from ..services.scheduler import schedule_deferred_drain
|
||||
for fire_at in defers_to_schedule:
|
||||
try:
|
||||
schedule_deferred_drain(fire_at)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to schedule deferred drain for %s", fire_at,
|
||||
)
|
||||
|
||||
return dispatched
|
||||
"""Webhook-flavoured dispatch — thin wrapper over ``dispatch_provider_event``."""
|
||||
return await dispatch_provider_event(
|
||||
engine=engine,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
provider_config=provider_config,
|
||||
event=event,
|
||||
detail_keys=detail_keys,
|
||||
filter_fn=_passes_filters,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -34,12 +34,14 @@ def _auto_register() -> None:
|
||||
from .planka_handler import PlankaCommandHandler
|
||||
from .nut_handler import NutCommandHandler
|
||||
from .webhook_handler import WebhookCommandHandler
|
||||
from .home_assistant_handler import HomeAssistantCommandHandler
|
||||
|
||||
register_handler(ImmichCommandHandler())
|
||||
register_handler(GiteaCommandHandler())
|
||||
register_handler(PlankaCommandHandler())
|
||||
register_handler(NutCommandHandler())
|
||||
register_handler(WebhookCommandHandler())
|
||||
register_handler(HomeAssistantCommandHandler())
|
||||
|
||||
|
||||
# Auto-register on import
|
||||
|
||||
@@ -0,0 +1,375 @@
|
||||
"""Home Assistant bot command handler.
|
||||
|
||||
Phase 2 of the HA integration. Each command opens a fresh WebSocket
|
||||
connection to HA — same approach used by ``HomeAssistantServiceProvider.
|
||||
list_collections`` — so the handler does not need to coordinate with the
|
||||
long-lived subscription supervisor.
|
||||
|
||||
Commands:
|
||||
|
||||
* ``/status`` — connection health, subscribed area / entity counts.
|
||||
* ``/entities [glob]`` — list matching entities with their current state.
|
||||
* ``/state <entity_id>`` — full state + attributes for one entity.
|
||||
* ``/areas`` — area registry summary with entity counts per area.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from fnmatch import fnmatchcase
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from notify_bridge_core.providers.home_assistant import (
|
||||
HomeAssistantApiError,
|
||||
HomeAssistantAuthError,
|
||||
HomeAssistantWSClient,
|
||||
redact_ha_message,
|
||||
)
|
||||
|
||||
from ..database.models import (
|
||||
CommandConfig,
|
||||
CommandTracker,
|
||||
CommandTrackerListener,
|
||||
ServiceProvider,
|
||||
TelegramBot,
|
||||
)
|
||||
from ..services.http_session import get_http_session
|
||||
from .base import CommandResponse, ProviderCommandHandler
|
||||
from .handler import _render_cmd_template
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_HA_COMMANDS = {"status", "entities", "state", "areas"}
|
||||
|
||||
|
||||
# HA exposes credentials and tokens through state attributes for some
|
||||
# integrations — most notably ``camera.*`` entities surface a working
|
||||
# ``access_token`` for the camera proxy URL, and ``entity_picture`` can
|
||||
# carry signed URLs. Filtering keys by substring blocklist before rendering
|
||||
# protects the chat user from seeing those values in /state output.
|
||||
#
|
||||
# Match is case-insensitive substring; tokens are intentionally generic so
|
||||
# custom integrations that follow the obvious naming conventions are also
|
||||
# covered. Anything not matched still renders.
|
||||
_SENSITIVE_ATTR_TOKENS: tuple[str, ...] = (
|
||||
"access_token",
|
||||
"token",
|
||||
"secret",
|
||||
"password",
|
||||
"passwd",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"private_key",
|
||||
"session_id",
|
||||
"authorization",
|
||||
"bearer",
|
||||
"cookie",
|
||||
# ``entity_picture`` is a URL that often embeds a signed token in its
|
||||
# query string (HA generates these for camera and media_player entities).
|
||||
# The key itself doesn't match the credential token blocklist, so it
|
||||
# gets its own explicit entry.
|
||||
"entity_picture",
|
||||
)
|
||||
|
||||
# Attributes already rendered as top-level fields by the state template; no
|
||||
# point repeating them in the "Attributes" iteration.
|
||||
_TOP_LEVEL_ATTRS: frozenset[str] = frozenset({
|
||||
"friendly_name", "unit_of_measurement", "device_class",
|
||||
})
|
||||
|
||||
# Hard cap on the number of attributes shown in /state to prevent message
|
||||
# truncation when an entity has dozens (e.g. weather hourly forecasts,
|
||||
# light supported features). After the cap, an "and N more" line is added
|
||||
# by the template logic.
|
||||
_MAX_ATTRIBUTES_RENDERED = 30
|
||||
|
||||
|
||||
def _is_sensitive_attr(key: str) -> bool:
|
||||
lowered = str(key).lower()
|
||||
return any(tok in lowered for tok in _SENSITIVE_ATTR_TOKENS)
|
||||
|
||||
|
||||
def _filter_attributes(attrs: dict[str, Any]) -> tuple[dict[str, Any], int]:
|
||||
"""Drop sensitive keys, cap count, return ``(visible_attrs, hidden_count)``.
|
||||
|
||||
Hidden count covers both the security filter (blocklisted keys) and the
|
||||
size cap (entries beyond ``_MAX_ATTRIBUTES_RENDERED``). The template can
|
||||
surface "and N more hidden" so users know the view is incomplete.
|
||||
"""
|
||||
if not isinstance(attrs, dict):
|
||||
return {}, 0
|
||||
safe: dict[str, Any] = {}
|
||||
redacted = 0
|
||||
for key, value in attrs.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
if key in _TOP_LEVEL_ATTRS:
|
||||
continue
|
||||
if _is_sensitive_attr(key):
|
||||
redacted += 1
|
||||
continue
|
||||
safe[key] = value
|
||||
overflow = max(0, len(safe) - _MAX_ATTRIBUTES_RENDERED)
|
||||
if overflow > 0:
|
||||
# Stable order — sort by key so the truncation point is deterministic.
|
||||
capped = dict(sorted(safe.items())[:_MAX_ATTRIBUTES_RENDERED])
|
||||
return capped, redacted + overflow
|
||||
return safe, redacted
|
||||
|
||||
|
||||
def _make_ws_client(provider: ServiceProvider, session: aiohttp.ClientSession) -> HomeAssistantWSClient:
|
||||
"""Build a one-shot WS client from the provider row."""
|
||||
config = provider.config or {}
|
||||
return HomeAssistantWSClient(
|
||||
session=session,
|
||||
base_url=config.get("url", ""),
|
||||
access_token=config.get("access_token", ""),
|
||||
verify_tls=bool(config.get("verify_tls", True)),
|
||||
)
|
||||
|
||||
|
||||
def _domain_of(entity_id: str) -> str:
|
||||
return entity_id.split(".", 1)[0] if "." in entity_id else ""
|
||||
|
||||
|
||||
def _normalize_state(state_row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Flatten an HA state dict into the shape templates consume.
|
||||
|
||||
``attributes`` is filtered through ``_filter_attributes`` to drop
|
||||
credential-like keys (e.g. ``camera.access_token``) and cap the rendered
|
||||
count. ``hidden_attr_count`` is exposed so the template can surface
|
||||
"and N more hidden" if the user wants to see everything they need to
|
||||
use a different tool (or the HA UI itself).
|
||||
"""
|
||||
entity_id = state_row.get("entity_id") or ""
|
||||
raw_attrs = state_row.get("attributes") or {}
|
||||
visible_attrs, hidden_count = _filter_attributes(raw_attrs)
|
||||
return {
|
||||
"entity_id": entity_id,
|
||||
"friendly_name": raw_attrs.get("friendly_name") or entity_id,
|
||||
"domain": _domain_of(entity_id),
|
||||
"state": state_row.get("state"),
|
||||
"attributes": visible_attrs,
|
||||
"hidden_attr_count": hidden_count,
|
||||
"device_class": raw_attrs.get("device_class"),
|
||||
"unit_of_measurement": raw_attrs.get("unit_of_measurement"),
|
||||
"last_changed": state_row.get("last_changed"),
|
||||
"last_updated": state_row.get("last_updated"),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Command implementations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _cmd_status(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""``/status`` — connection health + counts.
|
||||
|
||||
Health is derived from a live connection rather than the supervisor's
|
||||
in-memory state so the user sees what's happening *right now* if they
|
||||
just edited the token / URL. Connection + entity-count + area-count run
|
||||
on a single WS session so a healthy /status costs one TCP + TLS + WS +
|
||||
auth handshake instead of three.
|
||||
"""
|
||||
session = await get_http_session()
|
||||
client = _make_ws_client(provider, session)
|
||||
ok = True
|
||||
message = "OK"
|
||||
entity_count = 0
|
||||
area_count = 0
|
||||
|
||||
try:
|
||||
async with client.session() as sess:
|
||||
# Reaching here proves connect + auth succeeded.
|
||||
try:
|
||||
entity_count = len(await sess.get_states())
|
||||
except HomeAssistantApiError as err:
|
||||
_LOGGER.debug("HA /status get_states failed: %s", err)
|
||||
try:
|
||||
area_count = len(await sess.get_area_registry())
|
||||
except HomeAssistantApiError as err:
|
||||
_LOGGER.debug("HA /status get_area_registry failed: %s", err)
|
||||
except HomeAssistantAuthError as err:
|
||||
ok = False
|
||||
message = f"Auth failed: {redact_ha_message(str(err))}"
|
||||
except (aiohttp.ClientError, HomeAssistantApiError) as err:
|
||||
ok = False
|
||||
message = redact_ha_message(str(err)) or "Connection failed"
|
||||
|
||||
return {
|
||||
"ok": ok,
|
||||
"message": message,
|
||||
"provider_name": provider.name or "",
|
||||
"url": (provider.config or {}).get("url", ""),
|
||||
"entity_count": entity_count,
|
||||
"area_count": area_count,
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_entities(provider: ServiceProvider, args: str, count: int) -> dict[str, Any]:
|
||||
"""``/entities [glob]`` — entities filtered by glob, capped at ``count``.
|
||||
|
||||
Empty args returns the first ``count`` entities in entity_id order. A
|
||||
glob pattern is matched against the entity_id (case-insensitive). The
|
||||
normalization step (which walks the attribute dict to redact secrets)
|
||||
runs **only on the survivors** — sorting and slicing happen on raw
|
||||
state rows first, so an HA install with 1000+ entities doesn't
|
||||
materialize 1000 normalized dicts just to discard most of them.
|
||||
"""
|
||||
session = await get_http_session()
|
||||
client = _make_ws_client(provider, session)
|
||||
try:
|
||||
async with client.session() as sess:
|
||||
states = await sess.get_states()
|
||||
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
|
||||
redacted = redact_ha_message(str(err))
|
||||
_LOGGER.warning("HA /entities failed: %s", redacted)
|
||||
return {"entities": [], "glob": args.strip(), "total": 0, "error": redacted}
|
||||
|
||||
glob = args.strip()
|
||||
if glob:
|
||||
lower_glob = glob.lower()
|
||||
raw_matches = [
|
||||
s for s in states
|
||||
if isinstance(s.get("entity_id"), str)
|
||||
and fnmatchcase(s["entity_id"].lower(), lower_glob)
|
||||
]
|
||||
else:
|
||||
raw_matches = [s for s in states if isinstance(s.get("entity_id"), str)]
|
||||
|
||||
total = len(raw_matches)
|
||||
raw_matches.sort(key=lambda s: s.get("entity_id", ""))
|
||||
return {
|
||||
"entities": [_normalize_state(s) for s in raw_matches[:count]],
|
||||
"glob": glob,
|
||||
"total": total,
|
||||
"shown": min(count, total),
|
||||
}
|
||||
|
||||
|
||||
async def _cmd_state(provider: ServiceProvider, args: str) -> dict[str, Any]:
|
||||
"""``/state <entity_id>`` — single-entity drill-down.
|
||||
|
||||
Returns ``found=False`` when the entity_id is missing or not present.
|
||||
Templates render the no-results fallback in that case. Uses the session
|
||||
context manager for consistency with the other commands even though
|
||||
there's only one underlying WS call today — leaves the door open for
|
||||
Phase 3 (service calls) to chain a follow-up on the same socket.
|
||||
"""
|
||||
target = args.strip()
|
||||
if not target:
|
||||
return {"found": False, "entity_id": "", "reason": "missing_arg"}
|
||||
|
||||
session = await get_http_session()
|
||||
client = _make_ws_client(provider, session)
|
||||
try:
|
||||
async with client.session() as sess:
|
||||
states = await sess.get_states()
|
||||
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
|
||||
redacted = redact_ha_message(str(err))
|
||||
_LOGGER.warning("HA /state failed: %s", redacted)
|
||||
return {"found": False, "entity_id": target, "reason": "api_error", "error": redacted}
|
||||
|
||||
for s in states:
|
||||
if s.get("entity_id") == target:
|
||||
normalized = _normalize_state(s)
|
||||
return {"found": True, **normalized}
|
||||
|
||||
return {"found": False, "entity_id": target, "reason": "not_found"}
|
||||
|
||||
|
||||
async def _cmd_areas(provider: ServiceProvider) -> dict[str, Any]:
|
||||
"""``/areas`` — area registry with per-area entity counts.
|
||||
|
||||
Areas without entities are still listed so users can see which areas
|
||||
exist in HA but haven't been assigned anything. The entity counts come
|
||||
from the entity registry, not the state list — the registry includes
|
||||
disabled entities, which matches what users see in the HA UI. Both
|
||||
registry calls share a single WS session so /areas costs one handshake.
|
||||
"""
|
||||
session = await get_http_session()
|
||||
client = _make_ws_client(provider, session)
|
||||
try:
|
||||
async with client.session() as sess:
|
||||
areas = await sess.get_area_registry()
|
||||
# Entity registry failure is non-fatal — areas can still be
|
||||
# listed without per-area counts.
|
||||
try:
|
||||
entities = await sess.get_entity_registry()
|
||||
except HomeAssistantApiError:
|
||||
entities = []
|
||||
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
|
||||
redacted = redact_ha_message(str(err))
|
||||
_LOGGER.warning("HA /areas failed: %s", redacted)
|
||||
return {"areas": [], "total": 0, "error": redacted}
|
||||
|
||||
counts: dict[str, int] = {}
|
||||
for ent in entities:
|
||||
area_id = ent.get("area_id")
|
||||
if isinstance(area_id, str):
|
||||
counts[area_id] = counts.get(area_id, 0) + 1
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for a in areas:
|
||||
area_id = a.get("area_id")
|
||||
if not isinstance(area_id, str):
|
||||
continue
|
||||
rows.append({
|
||||
"area_id": area_id,
|
||||
"name": a.get("name") or area_id,
|
||||
"entity_count": counts.get(area_id, 0),
|
||||
})
|
||||
rows.sort(key=lambda r: r.get("name", "").lower())
|
||||
return {"areas": rows, "total": len(rows)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HomeAssistantCommandHandler(ProviderCommandHandler):
|
||||
"""Routes ``/status``, ``/entities``, ``/state``, ``/areas`` to the WS client."""
|
||||
|
||||
provider_type = "home_assistant"
|
||||
|
||||
def get_provider_commands(self) -> set[str]:
|
||||
return _HA_COMMANDS
|
||||
|
||||
def get_rate_categories(self) -> dict[str, str]:
|
||||
# All HA commands hit the WS API and share an "api" rate-limit bucket.
|
||||
return {cmd: "api" for cmd in _HA_COMMANDS}
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
cmd: str,
|
||||
args: str,
|
||||
count: int,
|
||||
locale: str,
|
||||
response_mode: str, # noqa: ARG002 — HA has no media commands; always text
|
||||
provider: ServiceProvider,
|
||||
cmd_templates: dict[str, dict[str, str]],
|
||||
bot: TelegramBot, # noqa: ARG002
|
||||
tracker: CommandTracker, # noqa: ARG002
|
||||
config: CommandConfig, # noqa: ARG002
|
||||
*,
|
||||
listener: CommandTrackerListener | None = None, # noqa: ARG002
|
||||
allowed_album_ids: set[str] | None = None, # noqa: ARG002 — HA has no album scope
|
||||
page: int = 1, # noqa: ARG002 — no pagination in v1
|
||||
) -> CommandResponse | None:
|
||||
if cmd == "status":
|
||||
ctx = await _cmd_status(provider)
|
||||
elif cmd == "entities":
|
||||
ctx = await _cmd_entities(provider, args, count)
|
||||
elif cmd == "state":
|
||||
ctx = await _cmd_state(provider, args)
|
||||
elif cmd == "areas":
|
||||
ctx = await _cmd_areas(provider)
|
||||
else:
|
||||
return None
|
||||
|
||||
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
|
||||
@@ -11,6 +11,8 @@ _RATE_CATEGORY: dict[str, str] = {
|
||||
"repos": "api", "issues": "api", "prs": "api", "commits": "api",
|
||||
# Planka (API calls share a category)
|
||||
"boards": "api", "cards": "api", "lists": "api",
|
||||
# Home Assistant (WebSocket queries share a category)
|
||||
"entities": "api", "state": "api", "areas": "api",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -292,6 +292,23 @@ async def migrate_schema(engine: AsyncEngine) -> None:
|
||||
)
|
||||
logger.info("Added track_webhook_received column to tracking_config table")
|
||||
|
||||
# Add Home Assistant tracking flags to tracking_config if missing.
|
||||
# state_changed defaults ON to match the canonical "watch the state bus"
|
||||
# use case; the other three are loud and opt-in (defaults 0).
|
||||
if await _has_table(conn, "tracking_config"):
|
||||
ha_flags = [
|
||||
("track_ha_state_changed", "INTEGER DEFAULT 1"),
|
||||
("track_ha_automation_triggered", "INTEGER DEFAULT 0"),
|
||||
("track_ha_service_called", "INTEGER DEFAULT 0"),
|
||||
("track_ha_event_fired", "INTEGER DEFAULT 0"),
|
||||
]
|
||||
for col_name, col_type in ha_flags:
|
||||
if not await _has_column(conn, "tracking_config", col_name):
|
||||
await conn.execute(
|
||||
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} {col_type}")
|
||||
)
|
||||
logger.info("Added %s column to tracking_config table", col_name)
|
||||
|
||||
# Add quiet hours to tracking_config if missing.
|
||||
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
|
||||
if await _has_table(conn, "tracking_config"):
|
||||
|
||||
@@ -165,6 +165,12 @@ class TrackingConfig(SQLModel, table=True):
|
||||
# Generic Webhook event tracking
|
||||
track_webhook_received: bool = Field(default=True)
|
||||
|
||||
# Home Assistant event tracking
|
||||
track_ha_state_changed: bool = Field(default=True)
|
||||
track_ha_automation_triggered: bool = Field(default=False)
|
||||
track_ha_service_called: bool = Field(default=False)
|
||||
track_ha_event_fired: bool = Field(default=False)
|
||||
|
||||
# Immich asset display
|
||||
track_images: bool = Field(default=True)
|
||||
track_videos: bool = Field(default=True)
|
||||
|
||||
@@ -158,6 +158,7 @@ async def _seed_default_templates() -> None:
|
||||
await _seed_provider_template(session, "nut", "NUT")
|
||||
await _seed_provider_template(session, "google_photos", "Google Photos")
|
||||
await _seed_provider_template(session, "webhook", "Generic Webhook")
|
||||
await _seed_provider_template(session, "home_assistant", "Home Assistant")
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -187,6 +188,10 @@ async def _seed_default_command_templates() -> None:
|
||||
await _seed_provider_command_template(
|
||||
session, "webhook", "Default Webhook Commands", "Default Generic Webhook command templates",
|
||||
)
|
||||
await _seed_provider_command_template(
|
||||
session, "home_assistant", "Default Home Assistant Commands",
|
||||
"Default Home Assistant command templates",
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -272,6 +277,14 @@ async def _seed_default_tracking_configs() -> None:
|
||||
"track_ups_replace_battery": True,
|
||||
"track_ups_overload": True,
|
||||
},
|
||||
{
|
||||
"provider_type": "home_assistant",
|
||||
"name": "Default Home Assistant",
|
||||
"track_ha_state_changed": True,
|
||||
"track_ha_automation_triggered": False,
|
||||
"track_ha_service_called": False,
|
||||
"track_ha_event_fired": False,
|
||||
},
|
||||
]
|
||||
|
||||
for cfg in defaults:
|
||||
|
||||
@@ -139,12 +139,20 @@ async def lifespan(app: FastAPI):
|
||||
set_webhook_secret(_secret or None)
|
||||
from .services.scheduler import start_scheduler, get_scheduler
|
||||
await start_scheduler()
|
||||
# Phase 1 of the Home Assistant provider: subscription-based ingest runs
|
||||
# outside the polling scheduler. ``start_all`` spawns one supervisor task
|
||||
# per enabled HA provider row. No-op when no HA providers are configured.
|
||||
from .services.ha_subscription import start_all as start_ha_subscriptions
|
||||
await start_ha_subscriptions()
|
||||
_READY = True
|
||||
yield
|
||||
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
|
||||
# before we close their HTTP session. Then close the shared session and
|
||||
# dispose the DB engine.
|
||||
# Graceful shutdown — cancel HA supervisors FIRST so they release their
|
||||
# WS connections before the shared HTTP session is closed. Then stop the
|
||||
# polling scheduler. Order matters: scheduler.shutdown(wait=True) drains
|
||||
# in-flight jobs that may also use the shared session.
|
||||
_READY = False
|
||||
from .services.ha_subscription import stop_all as stop_ha_subscriptions
|
||||
await stop_ha_subscriptions()
|
||||
scheduler = get_scheduler()
|
||||
if scheduler.running:
|
||||
scheduler.shutdown(wait=True)
|
||||
|
||||
@@ -115,6 +115,16 @@ def _make_collection_provider(
|
||||
return make_planka_provider(http_session, provider)
|
||||
if ptype == "google_photos":
|
||||
return make_google_photos_provider(http_session, provider)
|
||||
if ptype == "home_assistant":
|
||||
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
|
||||
return HomeAssistantServiceProvider(
|
||||
session=http_session,
|
||||
url=config.get("url", ""),
|
||||
access_token=config.get("access_token", ""),
|
||||
verify_tls=bool(config.get("verify_tls", True)),
|
||||
event_types=config.get("event_types") or None,
|
||||
name=provider.name,
|
||||
)
|
||||
# NUT provider needs no http_session
|
||||
if ptype == "nut":
|
||||
return make_nut_provider(provider) # type: ignore[return-value]
|
||||
@@ -122,7 +132,7 @@ def _make_collection_provider(
|
||||
|
||||
|
||||
# Set of provider types that need an aiohttp session for collection listing.
|
||||
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
|
||||
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos", "home_assistant"}
|
||||
|
||||
|
||||
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -204,6 +204,12 @@ def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
|
||||
"ups_comms_restored": tc.track_ups_comms_restored,
|
||||
"ups_replace_battery": tc.track_ups_replace_battery,
|
||||
"ups_overload": tc.track_ups_overload,
|
||||
# Home Assistant events — use getattr so legacy DB rows / test mocks
|
||||
# that pre-date the columns still pass the gate (default to tracked).
|
||||
"ha_state_changed": getattr(tc, "track_ha_state_changed", True),
|
||||
"ha_automation_triggered": getattr(tc, "track_ha_automation_triggered", False),
|
||||
"ha_service_called": getattr(tc, "track_ha_service_called", False),
|
||||
"ha_event_fired": getattr(tc, "track_ha_event_fired", False),
|
||||
}
|
||||
return flag_map.get(event_type, True)
|
||||
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
"""Shared dispatch helper for push-style providers.
|
||||
|
||||
Push-style providers (webhook receivers in ``api/webhooks.py`` and the
|
||||
Home Assistant subscription manager in ``services/ha_subscription.py``)
|
||||
share the same downstream pipeline: write an :class:`EventLog`, evaluate
|
||||
quiet hours / event-type gates, defer if needed, otherwise hand off to the
|
||||
:class:`NotificationDispatcher`.
|
||||
|
||||
This module extracts that pipeline so both callers can reuse it without
|
||||
either side importing from the other (which would create a server/api ->
|
||||
services -> api cycle).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.notifications.dispatcher import (
|
||||
NotificationDispatcher,
|
||||
TargetConfig,
|
||||
)
|
||||
|
||||
from ..database.models import EventLog, NotificationTracker
|
||||
from .deferred_dispatch import defer_event, is_deferrable
|
||||
from .dispatch_helpers import (
|
||||
GateReason,
|
||||
apply_tracking_display_filters,
|
||||
evaluate_event_gate,
|
||||
get_app_timezone,
|
||||
load_link_data,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Filter signature: ``(event, tracker.filters dict) -> bool``. Returning False
|
||||
# drops the event for that tracker before any DB writes happen. Callers pass
|
||||
# provider-specific logic (Gitea sender allowlist, HA entity glob, etc.).
|
||||
FilterFn = Callable[[ServiceEvent, dict[str, Any]], bool]
|
||||
|
||||
|
||||
async def dispatch_provider_event(
|
||||
engine: Any,
|
||||
provider_id: int,
|
||||
provider_name: str,
|
||||
provider_config: dict[str, Any],
|
||||
event: ServiceEvent,
|
||||
detail_keys: tuple[str, ...],
|
||||
filter_fn: FilterFn,
|
||||
) -> int:
|
||||
"""Load matching trackers, log, gate, defer, and dispatch one event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine:
|
||||
SQLAlchemy async engine.
|
||||
provider_id:
|
||||
ID of the :class:`ServiceProvider` the event came from.
|
||||
provider_name:
|
||||
Human-readable name (for logging only).
|
||||
provider_config:
|
||||
``ServiceProvider.config`` dict; flowed into :class:`TargetConfig`.
|
||||
event:
|
||||
Parsed :class:`ServiceEvent` to dispatch.
|
||||
detail_keys:
|
||||
Keys from ``event.extra`` to copy into ``EventLog.details``.
|
||||
filter_fn:
|
||||
Per-event tracker-level filter. Returning False drops the event for
|
||||
that tracker before any DB writes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of successfully dispatched notifications across all trackers.
|
||||
"""
|
||||
dispatched = 0
|
||||
# Drain-scheduling is best-effort: a scheduling failure must not roll
|
||||
# back the persisted defer rows (startup catch-up re-establishes them).
|
||||
defers_to_schedule: set[Any] = set()
|
||||
async with AsyncSession(engine) as session:
|
||||
# App timezone is identical across trackers in one inbound event;
|
||||
# pull it once.
|
||||
app_tz = await get_app_timezone(session)
|
||||
|
||||
tracker_result = await session.exec(
|
||||
select(NotificationTracker).where(
|
||||
NotificationTracker.provider_id == provider_id,
|
||||
NotificationTracker.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
trackers = tracker_result.all()
|
||||
|
||||
for tracker in trackers:
|
||||
filters = tracker.filters or {}
|
||||
if not filter_fn(event, filters):
|
||||
_LOGGER.debug(
|
||||
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
|
||||
)
|
||||
continue
|
||||
|
||||
link_data = await load_link_data(session, tracker.id)
|
||||
if not link_data:
|
||||
continue
|
||||
|
||||
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
|
||||
event_log_row = EventLog(
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
tracker_name=tracker.name,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
event_type=event.event_type.value,
|
||||
collection_id=event.collection_id,
|
||||
collection_name=event.collection_name,
|
||||
assets_count=0,
|
||||
details={
|
||||
"provider_type": event.provider_type.value,
|
||||
**extra_details,
|
||||
},
|
||||
)
|
||||
session.add(event_log_row)
|
||||
await session.flush()
|
||||
event_log_id = event_log_row.id
|
||||
|
||||
# Dedupe defers by parent link_id: broadcast links emit one
|
||||
# link_data entry per child, sharing the same parent id — the
|
||||
# deferred row is one-per-link, so we call defer_event only
|
||||
# once per distinct id (earliest fire_at wins on ties).
|
||||
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
|
||||
defers_for_event: dict[int, Any] = {}
|
||||
for ld in link_data:
|
||||
tc = ld["tracking_config"]
|
||||
if tc is not None:
|
||||
outcome = evaluate_event_gate(event, tc, app_tz)
|
||||
if outcome.reason is GateReason.QUIET_HOURS:
|
||||
if (
|
||||
is_deferrable(event.event_type.value)
|
||||
and outcome.quiet_hours_end_at is not None
|
||||
):
|
||||
link_id = ld.get("link_id")
|
||||
if link_id is not None:
|
||||
prior = defers_for_event.get(link_id)
|
||||
if prior is None or outcome.quiet_hours_end_at < prior:
|
||||
defers_for_event[link_id] = outcome.quiet_hours_end_at
|
||||
continue
|
||||
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
|
||||
continue
|
||||
|
||||
tmpl = ld["template_config"]
|
||||
target_cfg = TargetConfig(
|
||||
type=ld["target_type"],
|
||||
config=ld["target_config"],
|
||||
template_slots=ld["template_slots"],
|
||||
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
|
||||
date_only_format=(
|
||||
tmpl.date_only_format
|
||||
if tmpl and tmpl.date_only_format
|
||||
else "%d.%m.%Y"
|
||||
),
|
||||
provider_api_key=provider_config.get("api_token"),
|
||||
provider_internal_url=provider_config.get("url", ""),
|
||||
provider_external_url=provider_config.get("url", ""),
|
||||
receivers=ld["receivers"],
|
||||
)
|
||||
key = id(tc) if tc is not None else 0
|
||||
if key not in groups:
|
||||
groups[key] = (tc, [])
|
||||
groups[key][1].append(target_cfg)
|
||||
|
||||
# Persist defers + stamp event_log dispatch_status in the same
|
||||
# session that holds the EventLog row, so the "deferred" badge
|
||||
# only appears if the underlying queue rows actually exist.
|
||||
if defers_for_event:
|
||||
earliest = min(defers_for_event.values())
|
||||
for link_id, fire_at in defers_for_event.items():
|
||||
await defer_event(
|
||||
session,
|
||||
event=event,
|
||||
user_id=tracker.user_id,
|
||||
tracker_id=tracker.id,
|
||||
link_id=link_id,
|
||||
event_log_id=event_log_id,
|
||||
fire_at=fire_at,
|
||||
)
|
||||
details = dict(event_log_row.details or {})
|
||||
if not details.get("dispatch_status"):
|
||||
details["dispatch_status"] = "deferred"
|
||||
details["deferred_until"] = earliest.isoformat()
|
||||
event_log_row.details = details
|
||||
session.add(event_log_row)
|
||||
defers_to_schedule.update(defers_for_event.values())
|
||||
|
||||
# Dispatch to targets. Isolate dispatcher exceptions per group so
|
||||
# a failed remote call doesn't bubble out, abort the surrounding
|
||||
# transaction, and roll back the just-written defers / event_log.
|
||||
from .http_session import get_http_session
|
||||
dispatcher = NotificationDispatcher(session=await get_http_session())
|
||||
for tc, target_configs in groups.values():
|
||||
if not target_configs:
|
||||
continue
|
||||
shaped_event = apply_tracking_display_filters(event, tc)
|
||||
if shaped_event is None:
|
||||
continue
|
||||
try:
|
||||
results = await dispatcher.dispatch(shaped_event, target_configs)
|
||||
except Exception as err: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Dispatcher raised for tracker %d: %s", tracker.id, err,
|
||||
)
|
||||
continue
|
||||
for r in results:
|
||||
if r.get("success"):
|
||||
dispatched += 1
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Notification failed for tracker %d: %s",
|
||||
tracker.id, r.get("error", "unknown"),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
|
||||
# can't roll back the persisted defer rows.
|
||||
if defers_to_schedule:
|
||||
from .scheduler import schedule_deferred_drain
|
||||
for fire_at in defers_to_schedule:
|
||||
try:
|
||||
schedule_deferred_drain(fire_at)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to schedule deferred drain for %s", fire_at,
|
||||
)
|
||||
|
||||
return dispatched
|
||||
@@ -0,0 +1,293 @@
|
||||
"""Home Assistant subscription manager.
|
||||
|
||||
Phase 1 of the HA provider lives here. For every enabled ``home_assistant``
|
||||
:class:`ServiceProvider` row in the DB, this module spawns one long-running
|
||||
asyncio task that:
|
||||
|
||||
1. Builds an :class:`HomeAssistantServiceProvider` from the provider row.
|
||||
2. Calls ``provider.subscribe(emit)`` which loops forever — connect,
|
||||
authenticate, subscribe, drain events through ``emit`` — and reconnects
|
||||
with exponential backoff on any drop.
|
||||
3. Each ``emit`` call hands the parsed :class:`ServiceEvent` to
|
||||
:func:`dispatch_provider_event` (the shared dispatch helper that webhook
|
||||
providers also use), so quiet hours, deferred dispatch, and event-log
|
||||
writes all behave identically to the rest of the system.
|
||||
|
||||
Lifecycle is owned by ``main.py`` via :func:`start_all` and :func:`stop_all`.
|
||||
Phase 1 does not reconcile against DB changes after boot — adding,
|
||||
modifying, or removing a HA provider requires a server restart. Phase 1.5
|
||||
will add a CRUD-triggered :func:`reload_provider` hook.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlmodel import select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from notify_bridge_core.models.events import ServiceEvent
|
||||
from notify_bridge_core.providers.home_assistant import (
|
||||
HomeAssistantAuthError,
|
||||
HomeAssistantServiceProvider,
|
||||
)
|
||||
|
||||
from ..database.engine import get_engine
|
||||
from ..database.models import ServiceProvider
|
||||
from .event_dispatch import dispatch_provider_event
|
||||
from .http_session import get_http_session
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Per-provider running task. Keyed by provider_id so reload_provider() can
|
||||
# find and replace a single task without disturbing the rest.
|
||||
_running_tasks: dict[int, asyncio.Task[None]] = {}
|
||||
|
||||
|
||||
# Keys from ``event.extra`` to copy into ``EventLog.details``. Anything not
|
||||
# in this list is still available to templates via the merged extras, but
|
||||
# the event-log row stays slim.
|
||||
_HA_DETAIL_KEYS: tuple[str, ...] = (
|
||||
"entity_id",
|
||||
"friendly_name",
|
||||
"domain",
|
||||
"old_state",
|
||||
"new_state",
|
||||
"device_class",
|
||||
"unit_of_measurement",
|
||||
"area",
|
||||
"ha_event_type",
|
||||
"automation_name",
|
||||
"service_called",
|
||||
"target_entity",
|
||||
)
|
||||
|
||||
|
||||
def _ha_passes_filters(event: ServiceEvent, filters: dict[str, Any]) -> bool:
|
||||
"""HA-specific tracker filter.
|
||||
|
||||
Three filter keys, all optional, evaluated as a union: if the entity
|
||||
matches any one of them, the event passes. Empty filters mean "accept
|
||||
everything" — different from the Gitea filter which is an intersection.
|
||||
|
||||
Filter shape:
|
||||
|
||||
* ``collections`` — list of exact ``entity_id`` matches.
|
||||
* ``entity_glob`` — list of glob patterns (``light.*``, ``*_motion``).
|
||||
* ``domain_allowlist`` — list of HA domain prefixes (``light``).
|
||||
"""
|
||||
collections = filters.get("collections") or []
|
||||
entity_globs = filters.get("entity_glob") or []
|
||||
domain_allowlist = filters.get("domain_allowlist") or []
|
||||
|
||||
# No filters configured = accept everything.
|
||||
if not collections and not entity_globs and not domain_allowlist:
|
||||
return True
|
||||
|
||||
entity_id = event.collection_id
|
||||
domain = event.extra.get("domain") or (
|
||||
entity_id.split(".", 1)[0] if "." in entity_id else ""
|
||||
)
|
||||
|
||||
if collections and entity_id in collections:
|
||||
return True
|
||||
if domain_allowlist and domain in domain_allowlist:
|
||||
return True
|
||||
if entity_globs:
|
||||
from fnmatch import fnmatchcase
|
||||
for pattern in entity_globs:
|
||||
if isinstance(pattern, str) and fnmatchcase(entity_id, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def _run_provider(provider_id: int) -> None:
|
||||
"""One per-provider supervisor loop.
|
||||
|
||||
Reloads the provider row each iteration so config changes (URL, token,
|
||||
event types) take effect on the next reconnect cycle — no need for a
|
||||
full restart in the simple case where only credentials changed.
|
||||
|
||||
The ``_emit`` closure is rebuilt every iteration. Its lifetime equals
|
||||
one ``subscribe()`` call: the callback only runs while the HA client's
|
||||
drain task is alive. ``provider_name`` is snapshotted at the start of
|
||||
each (re)connect cycle, so renames take effect on the next reconnect —
|
||||
chatty enough for v1; revisit if longer-lived WS sessions need fresher
|
||||
names mid-stream.
|
||||
"""
|
||||
assert provider_id is not None, "_run_provider requires a real provider id"
|
||||
engine = get_engine()
|
||||
while True:
|
||||
try:
|
||||
async with AsyncSession(engine) as session:
|
||||
row = await session.get(ServiceProvider, provider_id)
|
||||
if row is None or row.type != "home_assistant":
|
||||
_LOGGER.info(
|
||||
"HA provider %s removed or retyped, stopping supervisor",
|
||||
provider_id,
|
||||
)
|
||||
return
|
||||
config = dict(row.config or {})
|
||||
provider_name = row.name
|
||||
|
||||
url = config.get("url", "")
|
||||
access_token = config.get("access_token", "")
|
||||
verify_tls = bool(config.get("verify_tls", True))
|
||||
event_types = config.get("event_types") or None
|
||||
|
||||
if not url or not access_token:
|
||||
_LOGGER.warning(
|
||||
"HA provider %s missing url or access_token; retrying in 60s",
|
||||
provider_id,
|
||||
)
|
||||
await asyncio.sleep(60)
|
||||
continue
|
||||
|
||||
session_http = await get_http_session()
|
||||
ha_provider = HomeAssistantServiceProvider(
|
||||
session=session_http,
|
||||
url=url,
|
||||
access_token=access_token,
|
||||
verify_tls=verify_tls,
|
||||
event_types=event_types,
|
||||
name=provider_name,
|
||||
)
|
||||
|
||||
async def _emit(event: ServiceEvent) -> None:
|
||||
# Shield the DB-writing dispatch from external cancellation
|
||||
# (shutdown, supervisor restart). The shield ensures that
|
||||
# once a transaction is mid-flight, it commits or rolls back
|
||||
# cleanly instead of being torn down with the asyncio task
|
||||
# at a write boundary. Worst case: shutdown waits up to one
|
||||
# dispatch latency longer.
|
||||
#
|
||||
# Perf note (Phase 2 follow-up): dispatch_provider_event
|
||||
# opens a fresh AsyncSession per call. For HA's chatty
|
||||
# state_changed bus this hammers the pool; batch in a
|
||||
# follow-up.
|
||||
try:
|
||||
await asyncio.shield(dispatch_provider_event(
|
||||
engine=engine,
|
||||
provider_id=provider_id,
|
||||
provider_name=provider_name,
|
||||
provider_config=config,
|
||||
event=event,
|
||||
detail_keys=_HA_DETAIL_KEYS,
|
||||
filter_fn=_ha_passes_filters,
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
# Shield re-raises CancelledError to the caller; let it
|
||||
# propagate so the drain task exits cleanly.
|
||||
raise
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"Failed to dispatch HA event for provider %s",
|
||||
provider_id,
|
||||
)
|
||||
|
||||
_LOGGER.info(
|
||||
"Starting HA subscription for provider %s (%s)",
|
||||
provider_id, provider_name,
|
||||
)
|
||||
await ha_provider.subscribe(_emit)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except HomeAssistantAuthError as err:
|
||||
# Fatal at the provider level — bad token. Sleep long and retry
|
||||
# so the user has time to fix the token without us hammering HA.
|
||||
# Error string already redacted by the client before re-raise.
|
||||
_LOGGER.error(
|
||||
"HA provider %s auth failed: %s — retrying in 5 minutes",
|
||||
provider_id, err,
|
||||
)
|
||||
await asyncio.sleep(300)
|
||||
except Exception: # noqa: BLE001
|
||||
_LOGGER.exception(
|
||||
"HA supervisor for provider %s crashed; restarting in 30s",
|
||||
provider_id,
|
||||
)
|
||||
await asyncio.sleep(30)
|
||||
|
||||
|
||||
def _make_done_callback(provider_id: int):
|
||||
"""Return a done-callback that prunes the task from ``_running_tasks``.
|
||||
|
||||
Without this, finished supervisors (provider deleted, fatal auth error
|
||||
after long sleep) leave stale entries in the dict — across many
|
||||
reload cycles the dict would grow unboundedly. The callback is
|
||||
registered on every task spawned via ``start_all`` / ``reload_provider``.
|
||||
"""
|
||||
def _cb(task: asyncio.Task[None]) -> None:
|
||||
current = _running_tasks.get(provider_id)
|
||||
if current is task:
|
||||
_running_tasks.pop(provider_id, None)
|
||||
return _cb
|
||||
|
||||
|
||||
async def start_all() -> None:
|
||||
"""Start a supervisor task for every enabled HA provider."""
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
result = await session.exec(
|
||||
select(ServiceProvider).where(
|
||||
ServiceProvider.type == "home_assistant",
|
||||
)
|
||||
)
|
||||
providers = result.all()
|
||||
|
||||
for prov in providers:
|
||||
if prov.id in _running_tasks and not _running_tasks[prov.id].done():
|
||||
continue
|
||||
task = asyncio.create_task(
|
||||
_run_provider(prov.id),
|
||||
name=f"ha-subscription-{prov.id}",
|
||||
)
|
||||
task.add_done_callback(_make_done_callback(prov.id))
|
||||
_running_tasks[prov.id] = task
|
||||
if providers:
|
||||
_LOGGER.info(
|
||||
"Started HA subscription manager: %d provider(s)", len(providers),
|
||||
)
|
||||
|
||||
|
||||
async def stop_all() -> None:
|
||||
"""Cancel every HA supervisor task and wait for clean shutdown."""
|
||||
if not _running_tasks:
|
||||
return
|
||||
for task in _running_tasks.values():
|
||||
task.cancel()
|
||||
# Wait for all to drain; swallow cancellation errors.
|
||||
await asyncio.gather(*_running_tasks.values(), return_exceptions=True)
|
||||
_running_tasks.clear()
|
||||
_LOGGER.info("Stopped all HA subscription supervisors")
|
||||
|
||||
|
||||
async def reload_provider(provider_id: int) -> None:
|
||||
"""Stop and restart the supervisor for a single provider id.
|
||||
|
||||
Hook for the provider CRUD routes — Phase 1.5 will wire it in. For Phase
|
||||
1, configure-then-restart-backend is the supported flow.
|
||||
"""
|
||||
task = _running_tasks.pop(provider_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception): # noqa: BLE001
|
||||
pass
|
||||
|
||||
engine = get_engine()
|
||||
async with AsyncSession(engine) as session:
|
||||
prov = await session.get(ServiceProvider, provider_id)
|
||||
if prov is None or prov.type != "home_assistant":
|
||||
return
|
||||
|
||||
new_task = asyncio.create_task(
|
||||
_run_provider(provider_id),
|
||||
name=f"ha-subscription-{provider_id}",
|
||||
)
|
||||
new_task.add_done_callback(_make_done_callback(provider_id))
|
||||
_running_tasks[provider_id] = new_task
|
||||
@@ -213,4 +213,25 @@ _SAMPLE_CONTEXT = {
|
||||
"raw_payload": {"action": "opened", "issue": {"title": "Bug report", "number": 1}, "sender": {"login": "user1"}},
|
||||
"event_type_raw": "webhook_received",
|
||||
"source_ip": "192.168.1.100",
|
||||
# Home Assistant variables (for home_assistant provider templates)
|
||||
"friendly_name": "Front Door",
|
||||
"entity_id": "binary_sensor.front_door",
|
||||
"domain": "binary_sensor",
|
||||
"old_state": "off",
|
||||
"new_state": "on",
|
||||
"attributes": {"friendly_name": "Front Door", "device_class": "door"},
|
||||
"device_class": "door",
|
||||
"unit_of_measurement": "",
|
||||
"area": "Entrance",
|
||||
"last_changed": "2026-05-13T12:34:56.789+00:00",
|
||||
"last_updated": "2026-05-13T12:34:56.789+00:00",
|
||||
"automation_name": "Front door notification",
|
||||
"trigger_source": "state of binary_sensor.front_door",
|
||||
"service_called": "light.turn_on",
|
||||
"service_domain": "light",
|
||||
"service_name": "turn_on",
|
||||
"service_data": {"entity_id": "light.kitchen"},
|
||||
"target_entity": "light.kitchen",
|
||||
"ha_event_type": "state_changed",
|
||||
"event_data": {"foo": "bar"},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
"""Unit tests for HA bot command helpers — Phase 2.
|
||||
|
||||
Focus on the security-sensitive bits the reviewer flagged: attribute
|
||||
filtering, error-message redaction, and the sample-context shape that
|
||||
flows through Jinja preview rendering.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from notify_bridge_server.commands.home_assistant_handler import (
|
||||
_filter_attributes,
|
||||
_is_sensitive_attr,
|
||||
_normalize_state,
|
||||
)
|
||||
|
||||
|
||||
def test_filter_attributes_drops_credential_keys() -> None:
|
||||
"""HA camera entities expose an ``access_token`` attribute. The handler
|
||||
MUST NOT surface it to the chat user via /state."""
|
||||
raw = {
|
||||
"friendly_name": "Front Camera",
|
||||
"access_token": "real-camera-proxy-token",
|
||||
"entity_picture": "/api/camera_proxy/...?token=abc",
|
||||
"brightness": 200,
|
||||
}
|
||||
safe, hidden = _filter_attributes(raw)
|
||||
assert "access_token" not in safe
|
||||
# entity_picture contains 'token' substring → blocked.
|
||||
assert "entity_picture" not in safe
|
||||
# friendly_name is rendered as a top-level field, not iterated.
|
||||
assert "friendly_name" not in safe
|
||||
# brightness is a normal attribute, passes through.
|
||||
assert safe["brightness"] == 200
|
||||
assert hidden == 2
|
||||
|
||||
|
||||
def test_filter_attributes_caps_count() -> None:
|
||||
"""When an entity has dozens of attributes the renderer would overflow
|
||||
Telegram's 4096-char message limit. Cap at 30 with overflow surfaced."""
|
||||
raw = {f"attr_{i:03d}": i for i in range(50)}
|
||||
safe, hidden = _filter_attributes(raw)
|
||||
assert len(safe) == 30
|
||||
assert hidden == 20
|
||||
|
||||
|
||||
def test_is_sensitive_attr_case_insensitive() -> None:
|
||||
"""Match should not depend on key casing — custom integrations are
|
||||
inconsistent about capitalization."""
|
||||
assert _is_sensitive_attr("Access_Token") is True
|
||||
assert _is_sensitive_attr("API_KEY") is True
|
||||
assert _is_sensitive_attr("password") is True
|
||||
assert _is_sensitive_attr("brightness") is False
|
||||
assert _is_sensitive_attr("color_mode") is False
|
||||
|
||||
|
||||
def test_normalize_state_filters_attrs() -> None:
|
||||
"""End-to-end: feed _normalize_state a malicious state row, verify the
|
||||
output has redacted attributes + hidden_attr_count surfaced."""
|
||||
state_row = {
|
||||
"entity_id": "camera.front_door",
|
||||
"state": "idle",
|
||||
"attributes": {
|
||||
"friendly_name": "Front Door Camera",
|
||||
"access_token": "leaked",
|
||||
"brand": "Reolink",
|
||||
},
|
||||
"last_changed": "2026-05-13T12:00:00+00:00",
|
||||
"last_updated": "2026-05-13T12:00:00+00:00",
|
||||
}
|
||||
out = _normalize_state(state_row)
|
||||
assert out["entity_id"] == "camera.front_door"
|
||||
assert out["friendly_name"] == "Front Door Camera"
|
||||
assert out["domain"] == "camera"
|
||||
# Top-level fields preserved.
|
||||
assert out["state"] == "idle"
|
||||
# Attributes dict is filtered.
|
||||
assert "access_token" not in out["attributes"]
|
||||
assert out["attributes"].get("brand") == "Reolink"
|
||||
# Hidden count reflects access_token (friendly_name is top-level, not redacted).
|
||||
assert out["hidden_attr_count"] == 1
|
||||
|
||||
|
||||
def test_normalize_state_handles_missing_attributes() -> None:
|
||||
"""A state row with no attributes dict should not crash."""
|
||||
out = _normalize_state({"entity_id": "sensor.x", "state": "1"})
|
||||
assert out["attributes"] == {}
|
||||
assert out["hidden_attr_count"] == 0
|
||||
|
||||
|
||||
def test_redact_ha_message_strips_userinfo() -> None:
|
||||
"""The Phase 1 redact helper is re-exported via the HA package and used
|
||||
by /entities, /state, /areas before surfacing errors. Make sure the
|
||||
re-export still works and the contract is what we expect."""
|
||||
from notify_bridge_core.providers.home_assistant import redact_ha_message
|
||||
msg = "Cannot connect to https://leak-token@homeassistant.local:8123/api/websocket"
|
||||
out = redact_ha_message(msg)
|
||||
assert "leak-token@" not in out
|
||||
assert "homeassistant.local:8123" in out
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Tests for the HA-specific tracker filter (entity_glob, domain_allowlist).
|
||||
|
||||
The Gitea filter is an intersection of senders/collections. The HA filter
|
||||
is intentionally a *union* across the three keys — any match passes — so a
|
||||
user can mix exact entity ids with glob patterns and domain allowlists
|
||||
without each one narrowing the others.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from notify_bridge_core.models.events import EventType, ServiceEvent
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
from notify_bridge_server.services.ha_subscription import _ha_passes_filters
|
||||
|
||||
|
||||
def _ha_event(entity_id: str, domain: str | None = None) -> ServiceEvent:
|
||||
return ServiceEvent(
|
||||
event_type=EventType.HA_STATE_CHANGED,
|
||||
provider_type=ServiceProviderType.HOME_ASSISTANT,
|
||||
provider_name="HA",
|
||||
collection_id=entity_id,
|
||||
collection_name=entity_id,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
extra={"domain": domain or (entity_id.split(".", 1)[0] if "." in entity_id else "")},
|
||||
)
|
||||
|
||||
|
||||
def test_empty_filters_accept_everything() -> None:
|
||||
assert _ha_passes_filters(_ha_event("light.kitchen"), {}) is True
|
||||
|
||||
|
||||
def test_exact_entity_match() -> None:
|
||||
filters = {"collections": ["light.kitchen", "switch.lamp"]}
|
||||
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
|
||||
|
||||
|
||||
def test_entity_glob_match() -> None:
|
||||
filters = {"entity_glob": ["binary_sensor.*_motion", "light.kitchen*"]}
|
||||
assert _ha_passes_filters(_ha_event("binary_sensor.hallway_motion"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("light.kitchen_main"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
|
||||
|
||||
|
||||
def test_domain_allowlist() -> None:
|
||||
filters = {"domain_allowlist": ["light", "switch"]}
|
||||
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("switch.lamp"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("sensor.temp"), filters) is False
|
||||
|
||||
|
||||
def test_union_across_keys() -> None:
|
||||
"""If collections names a specific sensor.* but domain_allowlist names
|
||||
'light', BOTH should be acceptable — that's the difference from the
|
||||
Gitea-style intersection filter."""
|
||||
filters = {
|
||||
"collections": ["sensor.outdoor_temp"],
|
||||
"domain_allowlist": ["light"],
|
||||
}
|
||||
assert _ha_passes_filters(_ha_event("sensor.outdoor_temp"), filters) is True
|
||||
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
|
||||
# Neither matches:
|
||||
assert _ha_passes_filters(_ha_event("binary_sensor.door"), filters) is False
|
||||
|
||||
|
||||
def test_domain_derived_when_extra_missing() -> None:
|
||||
"""If the parser didn't populate extra.domain (e.g. malformed event),
|
||||
the filter must still infer it from the entity_id prefix."""
|
||||
evt = ServiceEvent(
|
||||
event_type=EventType.HA_STATE_CHANGED,
|
||||
provider_type=ServiceProviderType.HOME_ASSISTANT,
|
||||
provider_name="HA",
|
||||
collection_id="light.kitchen",
|
||||
collection_name="light.kitchen",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
extra={}, # No 'domain' key.
|
||||
)
|
||||
assert _ha_passes_filters(evt, {"domain_allowlist": ["light"]}) is True
|
||||
@@ -0,0 +1,187 @@
|
||||
"""Unit tests for the Home Assistant event parser.
|
||||
|
||||
These tests don't need a database or HA server — the parser is a pure
|
||||
function from ``ha_event_dict`` to :class:`ServiceEvent`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from notify_bridge_core.models.events import EventType
|
||||
from notify_bridge_core.providers.base import ServiceProviderType
|
||||
from notify_bridge_core.providers.home_assistant.event_parser import parse_event
|
||||
|
||||
|
||||
def _ha_event_envelope(event_type: str, data: dict) -> dict:
|
||||
return {
|
||||
"event_type": event_type,
|
||||
"data": data,
|
||||
"time_fired": "2026-05-13T12:34:56.789Z",
|
||||
}
|
||||
|
||||
|
||||
def test_state_changed_basic() -> None:
|
||||
payload = _ha_event_envelope(
|
||||
"state_changed",
|
||||
{
|
||||
"entity_id": "binary_sensor.front_door",
|
||||
"old_state": {"state": "off", "attributes": {}},
|
||||
"new_state": {
|
||||
"state": "on",
|
||||
"attributes": {
|
||||
"friendly_name": "Front Door",
|
||||
"device_class": "door",
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.HA_STATE_CHANGED
|
||||
assert evt.provider_type is ServiceProviderType.HOME_ASSISTANT
|
||||
assert evt.collection_id == "binary_sensor.front_door"
|
||||
assert evt.collection_name == "Front Door"
|
||||
assert evt.extra["old_state"] == "off"
|
||||
assert evt.extra["new_state"] == "on"
|
||||
assert evt.extra["domain"] == "binary_sensor"
|
||||
assert evt.extra["device_class"] == "door"
|
||||
# Area was not provided in lookup -> None.
|
||||
assert evt.extra["area"] is None
|
||||
|
||||
|
||||
def test_state_changed_with_area_lookup() -> None:
|
||||
payload = _ha_event_envelope(
|
||||
"state_changed",
|
||||
{
|
||||
"entity_id": "light.kitchen",
|
||||
"old_state": {"state": "off", "attributes": {}},
|
||||
"new_state": {
|
||||
"state": "on",
|
||||
"attributes": {"friendly_name": "Kitchen Light"},
|
||||
},
|
||||
},
|
||||
)
|
||||
evt = parse_event(
|
||||
payload,
|
||||
provider_name="HA",
|
||||
area_lookup={"light.kitchen": "Kitchen"},
|
||||
)
|
||||
assert evt is not None
|
||||
assert evt.extra["area"] == "Kitchen"
|
||||
|
||||
|
||||
def test_state_changed_entity_removed() -> None:
|
||||
"""new_state=None means HA removed the entity. Surface as 'removed' so
|
||||
templates can branch on it; collection_name falls back to old_state."""
|
||||
payload = _ha_event_envelope(
|
||||
"state_changed",
|
||||
{
|
||||
"entity_id": "sensor.dropped",
|
||||
"old_state": {
|
||||
"state": "online",
|
||||
"attributes": {"friendly_name": "Dropped Sensor"},
|
||||
},
|
||||
"new_state": None,
|
||||
},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.extra["new_state"] == "removed"
|
||||
assert evt.collection_name == "Dropped Sensor"
|
||||
|
||||
|
||||
def test_automation_triggered() -> None:
|
||||
payload = _ha_event_envelope(
|
||||
"automation_triggered",
|
||||
{
|
||||
"name": "Front door notification",
|
||||
"entity_id": "automation.front_door_notify",
|
||||
"source": "state of binary_sensor.front_door",
|
||||
},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.HA_AUTOMATION_TRIGGERED
|
||||
assert evt.collection_name == "Front door notification"
|
||||
assert evt.collection_id == "automation.front_door_notify"
|
||||
assert evt.extra["automation_name"] == "Front door notification"
|
||||
assert evt.extra["trigger_source"] == "state of binary_sensor.front_door"
|
||||
|
||||
|
||||
def test_call_service_with_target() -> None:
|
||||
payload = _ha_event_envelope(
|
||||
"call_service",
|
||||
{
|
||||
"domain": "light",
|
||||
"service": "turn_on",
|
||||
"service_data": {"entity_id": "light.kitchen"},
|
||||
},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.HA_SERVICE_CALLED
|
||||
assert evt.collection_id == "light.turn_on"
|
||||
assert evt.extra["target_entity"] == "light.kitchen"
|
||||
assert evt.extra["service_domain"] == "light"
|
||||
assert evt.extra["service_name"] == "turn_on"
|
||||
|
||||
|
||||
def test_call_service_with_multi_target() -> None:
|
||||
"""When the call hits multiple entities, the parser comma-joins them
|
||||
so templates can render ``{{ target_entity }}`` without iterating."""
|
||||
payload = _ha_event_envelope(
|
||||
"call_service",
|
||||
{
|
||||
"domain": "light",
|
||||
"service": "turn_off",
|
||||
"service_data": {
|
||||
"entity_id": ["light.kitchen", "light.living_room"],
|
||||
},
|
||||
},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.extra["target_entity"] == "light.kitchen, light.living_room"
|
||||
|
||||
|
||||
def test_generic_event_fallback() -> None:
|
||||
"""Any event_type not in the known set becomes ha_event_fired with the
|
||||
raw event_type stashed in extras so loud catch-all subscriptions work."""
|
||||
payload = _ha_event_envelope(
|
||||
"custom_event_xyz",
|
||||
{"foo": "bar"},
|
||||
)
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.event_type is EventType.HA_EVENT_FIRED
|
||||
assert evt.extra["ha_event_type"] == "custom_event_xyz"
|
||||
assert evt.extra["event_data"] == {"foo": "bar"}
|
||||
|
||||
|
||||
def test_malformed_payload_returns_none() -> None:
|
||||
assert parse_event({}, provider_name="HA") is None
|
||||
assert parse_event("not a dict", provider_name="HA") is None # type: ignore[arg-type]
|
||||
# state_changed without entity_id is unrecoverable
|
||||
bad = _ha_event_envelope("state_changed", {"new_state": None})
|
||||
assert parse_event(bad, provider_name="HA") is None
|
||||
# call_service without domain/service is unrecoverable
|
||||
bad2 = _ha_event_envelope("call_service", {"service": "turn_on"})
|
||||
assert parse_event(bad2, provider_name="HA") is None
|
||||
|
||||
|
||||
def test_time_fired_iso_with_z_suffix_parses() -> None:
|
||||
"""HA uses ``Z`` suffix; older Python ``fromisoformat`` rejects it.
|
||||
The parser must handle both forms or we'd lose the timestamp."""
|
||||
from datetime import timezone
|
||||
payload = _ha_event_envelope(
|
||||
"state_changed",
|
||||
{
|
||||
"entity_id": "sensor.temp",
|
||||
"old_state": {"state": "20", "attributes": {}},
|
||||
"new_state": {"state": "21", "attributes": {}},
|
||||
},
|
||||
)
|
||||
payload["time_fired"] = "2026-05-13T12:34:56.789Z"
|
||||
evt = parse_event(payload, provider_name="HA")
|
||||
assert evt is not None
|
||||
assert evt.timestamp.tzinfo is not None
|
||||
assert evt.timestamp.utcoffset() == timezone.utc.utcoffset(None)
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Tests for the HA WS session helper and slice-before-normalize path.
|
||||
|
||||
The reviewer flagged two perf-shaped concerns that we've now addressed:
|
||||
|
||||
1. ``/status`` and ``/areas`` previously opened 3 and 2 separate WS
|
||||
connections respectively. With ``HomeAssistantSession`` they share one
|
||||
socket — these tests pin the contract.
|
||||
2. ``/entities`` used to normalize every matching entity before slicing to
|
||||
``count``. For HA installs with 1000+ entities this materialized 1000+
|
||||
normalized dicts to throw most away. The optimization moves the slice
|
||||
*before* normalize; this test exercises a 200-entity fixture and
|
||||
verifies only the ``count`` survivors get normalized.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from notify_bridge_core.providers.home_assistant.client import HomeAssistantSession
|
||||
from notify_bridge_server.commands import home_assistant_handler as handler
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session class — surface contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_class_has_expected_methods() -> None:
|
||||
"""Anyone consuming ``HomeAssistantSession`` can rely on this surface."""
|
||||
expected = {"send", "get_states", "get_area_registry", "get_entity_registry"}
|
||||
actual = {name for name in dir(HomeAssistantSession) if not name.startswith("_")}
|
||||
assert expected <= actual, f"missing: {expected - actual}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_get_states_routes_through_send() -> None:
|
||||
"""``get_states`` is a thin wrapper around ``send`` with the canonical payload."""
|
||||
sent: list[dict[str, Any]] = []
|
||||
|
||||
class _FakeClient:
|
||||
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
|
||||
sent.append(payload)
|
||||
return 1
|
||||
|
||||
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
|
||||
return [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
|
||||
|
||||
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
|
||||
result = await sess.get_states()
|
||||
assert sent == [{"type": "get_states"}]
|
||||
assert result == [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_methods_use_distinct_payloads() -> None:
|
||||
"""Each session-scoped method sends the right HA command name."""
|
||||
sent: list[dict[str, Any]] = []
|
||||
|
||||
class _FakeClient:
|
||||
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
|
||||
sent.append(payload)
|
||||
return len(sent)
|
||||
|
||||
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
|
||||
return []
|
||||
|
||||
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
|
||||
await sess.get_states()
|
||||
await sess.get_area_registry()
|
||||
await sess.get_entity_registry()
|
||||
assert [p["type"] for p in sent] == [
|
||||
"get_states",
|
||||
"config/area_registry/list",
|
||||
"config/entity_registry/list",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# slice-before-normalize — perf contract for /entities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _FakeAsyncSession:
|
||||
"""A fake HA session that returns a canned state list."""
|
||||
|
||||
def __init__(self, states: list[dict[str, Any]]) -> None:
|
||||
self._states = states
|
||||
|
||||
async def get_states(self) -> list[dict[str, Any]]:
|
||||
return self._states
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
"""A fake client whose ``session()`` yields a ``_FakeAsyncSession``."""
|
||||
|
||||
def __init__(self, states: list[dict[str, Any]]) -> None:
|
||||
self._states = states
|
||||
|
||||
def session(self): # noqa: D401 — mimics real client signature
|
||||
states = self._states
|
||||
class _CM:
|
||||
async def __aenter__(self_inner):
|
||||
return _FakeAsyncSession(states)
|
||||
async def __aexit__(self_inner, *_exc):
|
||||
return False
|
||||
return _CM()
|
||||
|
||||
|
||||
def _state_row(entity_id: str, n_attrs: int = 2) -> dict[str, Any]:
|
||||
return {
|
||||
"entity_id": entity_id,
|
||||
"state": "on",
|
||||
"attributes": {f"attr_{i}": i for i in range(n_attrs)},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cmd_entities_slices_before_normalizing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""200 raw entities, count=10. Normalize must run only 10 times.
|
||||
|
||||
We instrument ``_normalize_state`` with a counter to prove the slice
|
||||
happens before the per-row transform. The total field still reports
|
||||
all 200 so the user knows the result is truncated.
|
||||
"""
|
||||
states = [_state_row(f"light.bulb_{i:03d}") for i in range(200)]
|
||||
fake_client = _FakeClient(states)
|
||||
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
|
||||
|
||||
calls = {"count": 0}
|
||||
real_normalize = handler._normalize_state
|
||||
|
||||
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
|
||||
calls["count"] += 1
|
||||
return real_normalize(row)
|
||||
|
||||
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
|
||||
|
||||
# ``get_http_session`` opens a real aiohttp session in the bg; bypass
|
||||
# it since our fake client never uses the session arg.
|
||||
async def _fake_http_session() -> Any:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
|
||||
|
||||
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
|
||||
result = await handler._cmd_entities(provider, args="", count=10)
|
||||
assert result["total"] == 200
|
||||
assert result["shown"] == 10
|
||||
assert len(result["entities"]) == 10
|
||||
assert calls["count"] == 10, (
|
||||
f"normalize should run once per survivor; ran {calls['count']} times"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cmd_entities_glob_filter_still_normalizes_only_survivors(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""200 raw entities mixed across 2 domains; glob narrows to one.
|
||||
|
||||
Normalize count = min(count, matching_total). Demonstrates the
|
||||
optimization composes with the filter step.
|
||||
"""
|
||||
states = [
|
||||
_state_row(f"light.bulb_{i:03d}") for i in range(100)
|
||||
] + [
|
||||
_state_row(f"sensor.temp_{i:03d}") for i in range(100)
|
||||
]
|
||||
fake_client = _FakeClient(states)
|
||||
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
|
||||
|
||||
calls = {"count": 0}
|
||||
real_normalize = handler._normalize_state
|
||||
|
||||
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
|
||||
calls["count"] += 1
|
||||
return real_normalize(row)
|
||||
|
||||
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
|
||||
|
||||
async def _fake_http_session() -> Any:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
|
||||
|
||||
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
|
||||
result = await handler._cmd_entities(provider, args="light.*", count=5)
|
||||
assert result["total"] == 100 # all light.* entities counted
|
||||
assert result["shown"] == 5 # but only 5 normalized
|
||||
assert calls["count"] == 5
|
||||
assert all(e["entity_id"].startswith("light.") for e in result["entities"])
|
||||
Reference in New Issue
Block a user