feat: Home Assistant provider — WebSocket subscription + bot commands

Adds Home Assistant as a service provider with two coordinated surfaces:

Notifications (subscription):
- Long-lived WebSocket client (aiohttp ws_connect) with auth handshake,
  exponential-backoff reconnect, bounded event queue, and area-registry
  enrichment cached per (re)connect
- ServiceProvider ABC gains an optional `subscribe()` method for push-style
  providers; HomeAssistantServiceProvider uses it via a per-provider
  supervisor task started in the FastAPI lifespan
- 4 event types (state_changed, automation_triggered, call_service,
  event_fired), 4 default Jinja templates (en + ru), HA-specific
  tracker filters (entity_glob, domain_allowlist, exact entity ids)
- Extracted shared dispatch pipeline (api/webhooks.py → services/
  event_dispatch.py) so subscription and webhook ingest share the same
  event_log + deferred-dispatch + quiet-hours code path

Bot commands:
- /status, /entities [glob], /state <entity_id>, /areas
- Multi-command WS session so /status and /areas cost one handshake
- Sensitive-attribute blocklist (camera access_token, entity_picture, etc.)
  and 30-attribute cap to keep /state output safe and within Telegram's
  message size
- Error-message redaction strips URL userinfo before surfacing to chat

Frontend:
- HA descriptor with toggle ConfigField type (new) and tag-input filter
  mode for free-text glob/domain lists (new TagInput component)
- 15 command slots + 4 notification slots wired into the existing
  template-config UI
This commit is contained in:
2026-05-13 14:31:56 +03:00
parent 90f958bdc6
commit 22127e2a59
79 changed files with 4042 additions and 210 deletions
@@ -565,6 +565,63 @@ async def preview_raw(
"count": 2,
# /rate_limited
"wait": 15,
# --- Home Assistant: /status, /entities, /state, /areas ---
"ok": True,
"message": "OK",
"provider_name": "Home Assistant",
"url": "http://homeassistant.local:8123",
"entity_count": 142,
"area_count": 8,
"entities": [
{
"entity_id": "binary_sensor.front_door",
"friendly_name": "Front Door",
"domain": "binary_sensor",
"state": "off",
"attributes": {"device_class": "door", "friendly_name": "Front Door"},
"device_class": "door",
"unit_of_measurement": None,
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
},
{
"entity_id": "sensor.kitchen_temperature",
"friendly_name": "Kitchen Temperature",
"domain": "sensor",
"state": "21.4",
"attributes": {"unit_of_measurement": "°C", "friendly_name": "Kitchen Temperature"},
"device_class": "temperature",
"unit_of_measurement": "°C",
"last_changed": "2026-05-13T12:30:00+00:00",
"last_updated": "2026-05-13T12:30:00+00:00",
},
],
"glob": "binary_sensor.*",
"total": 12,
"shown": 2,
# /state — single entity drill-down. ``found`` controls which branch
# of the template renders.
"found": True,
"entity_id": "light.kitchen",
"friendly_name": "Kitchen Light",
"domain": "light",
"state": "on",
"attributes": {
"brightness": 200,
"color_mode": "brightness",
},
"hidden_attr_count": 0,
"device_class": None,
"unit_of_measurement": None,
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
"reason": "",
"error": "",
# /areas
"areas": [
{"area_id": "kitchen", "name": "Kitchen", "entity_count": 14},
{"area_id": "entrance", "name": "Entrance", "entity_count": 4},
],
}
return render_template_preview(body.template, sample_ctx)
@@ -3,7 +3,7 @@
import logging
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ValidationError
from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any
@@ -103,6 +103,54 @@ class WebhookProviderConfig(BaseModel):
max_stored_payloads: int = 20 # 1-100
class HomeAssistantProviderConfig(BaseModel):
url: str
access_token: str
verify_tls: bool = True
event_types: list[str] | None = None
@field_validator("url")
@classmethod
def _validate_url(cls, raw: str) -> str:
"""Reject malformed URLs early so the user sees a clear error.
``AnyHttpUrl`` accepts the homelab-friendly forms
(``http://homeassistant.local:8123``) while rejecting garbage like
``not-a-url`` or ``ftp://...``. Validation is best-effort; we still
re-derive the WebSocket URL at runtime.
"""
try:
AnyHttpUrl(raw)
except ValueError as err:
raise ValueError(f"url must be a valid http(s) URL: {err}") from err
return raw
@field_validator("event_types")
@classmethod
def _validate_event_types(cls, raw: list[str] | None) -> list[str] | None:
"""Cap list size and per-entry length; reject obvious junk.
We don't whitelist event names — HA has unbounded custom event types
from third-party integrations. Length and count caps are enough to
keep a misconfiguration from blowing up the subscription handshake.
"""
if raw is None:
return None
if len(raw) > 50:
raise ValueError("event_types accepts at most 50 entries")
cleaned: list[str] = []
for entry in raw:
if not isinstance(entry, str):
raise ValueError("event_types entries must be strings")
entry = entry.strip()
if not entry:
continue
if len(entry) > 100:
raise ValueError("event_types entries must be <=100 chars")
cleaned.append(entry)
return cleaned or None
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"immich": ImmichProviderConfig,
"gitea": GiteaProviderConfig,
@@ -111,6 +159,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"nut": NutProviderConfig,
"google_photos": GooglePhotosProviderConfig,
"webhook": WebhookProviderConfig,
"home_assistant": HomeAssistantProviderConfig,
}
@@ -160,6 +209,18 @@ async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]
gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection()
if provider.type == "home_assistant":
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
ha = HomeAssistantServiceProvider(
session=http_session,
url=provider.config.get("url", ""),
access_token=provider.config.get("access_token", ""),
verify_tls=bool(provider.config.get("verify_tls", True)),
event_types=provider.config.get("event_types") or None,
name=provider.name,
)
return await ha.test_connection()
if provider.type in ("scheduler", "webhook"):
return {"ok": True, "message": "Virtual provider — always available"}
@@ -285,6 +285,8 @@ async def get_template_variables(
**_planka_variables(),
# --- NUT (UPS) slots ---
**_nut_variables(),
# --- Home Assistant slots ---
**_home_assistant_variables(),
# --- Scheduler slots ---
"message_scheduled_message": {
"description": "Notification for scheduled message events",
@@ -433,6 +435,58 @@ def _nut_variables() -> dict:
}
def _home_assistant_variables() -> dict:
common = {
"entity_id": "HA entity id (e.g. light.kitchen)",
"friendly_name": "Human-readable entity name from attributes.friendly_name",
"domain": "HA domain prefix (light, sensor, binary_sensor, ...)",
"attributes": "Full attributes dict of the new state",
"device_class": "Device class (motion, door, temperature, ...)",
"unit_of_measurement": "Unit suffix for numeric sensors",
"area": "Area name from the HA area registry (empty when not assigned)",
"ha_event_type": "Raw HA event_type (state_changed, automation_triggered, ...)",
"last_changed": "ISO timestamp of last state change",
"last_updated": "ISO timestamp of last attribute or state update",
}
return {
"message_ha_state_changed": {
"description": "Entity state changed",
"variables": {
**common,
"old_state": "Previous state string",
"new_state": "New state string ('removed' if entity deleted)",
},
},
"message_ha_automation_triggered": {
"description": "Automation triggered",
"variables": {
"entity_id": common["entity_id"],
"automation_name": "Automation name",
"trigger_source": "Why the automation fired",
"ha_event_type": common["ha_event_type"],
},
},
"message_ha_service_called": {
"description": "HA service called",
"variables": {
"service_called": "Qualified service name (e.g. light.turn_on)",
"service_domain": "Service domain",
"service_name": "Service name within domain",
"service_data": "Service payload dict",
"target_entity": "entity_id targeted by the call (comma-joined for multi-target)",
"ha_event_type": common["ha_event_type"],
},
},
"message_ha_event_fired": {
"description": "Other HA event fired (catch-all)",
"variables": {
"ha_event_type": common["ha_event_type"],
"event_data": "Raw event data dict (use {{ event_data | tojson }} to render)",
},
},
}
@router.post("", status_code=status.HTTP_201_CREATED)
async def create_config(
body: TemplateConfigCreate,
@@ -13,7 +13,6 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher, TargetConfig
from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook
from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook
from notify_bridge_core.providers.webhook.event_parser import parse_webhook as parse_generic_webhook
@@ -27,13 +26,7 @@ from ..database.models import (
ServiceProvider,
WebhookPayloadLog,
)
from ..services.dispatch_helpers import (
GateReason,
apply_tracking_display_filters,
evaluate_event_gate,
get_app_timezone,
load_link_data,
)
from ..services.event_dispatch import dispatch_provider_event
_LOGGER = logging.getLogger(__name__)
@@ -131,7 +124,7 @@ def _passes_filters(
# ---------------------------------------------------------------------------
# Shared dispatch helper
# Shared dispatch helper (legacy wrapper — body moved to services/event_dispatch.py)
# ---------------------------------------------------------------------------
async def _dispatch_webhook_event(
@@ -142,185 +135,16 @@ async def _dispatch_webhook_event(
event: ServiceEvent,
detail_keys: tuple[str, ...],
) -> int:
"""Load trackers, filter, create EventLogs, dispatch notifications, and commit.
Parameters
----------
engine:
SQLAlchemy async engine.
provider_id:
ID of the ServiceProvider that received the webhook.
provider_name:
Human-readable name of the provider (for logging).
provider_config:
The provider's ``config`` dict (passed through to target config builder).
event:
Parsed :class:`ServiceEvent` to dispatch.
detail_keys:
Keys from ``event.extra`` to include in the EventLog ``details`` dict.
Returns
-------
int
Number of successfully dispatched notifications.
"""
dispatched = 0
# ``defers_to_schedule`` is collected during the loop and flushed AFTER the
# main session commits — the only side-effect of failing to schedule is a
# delayed delivery (the startup loader / catch-up scan will reschedule),
# so this is best-effort and must not roll back the DB writes.
defers_to_schedule: set[Any] = set()
async with AsyncSession(engine) as session:
# App timezone is identical across trackers within one webhook request;
# pull it once.
app_tz = await get_app_timezone(session)
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = tracker_result.all()
from ..services.deferred_dispatch import defer_event, is_deferrable
for tracker in trackers:
filters = tracker.filters or {}
if not _passes_filters(event, filters):
_LOGGER.debug(
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
)
continue
link_data = await load_link_data(session, tracker.id)
if not link_data:
continue
# Log event
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
event_log_row = EventLog(
user_id=tracker.user_id,
tracker_id=tracker.id,
tracker_name=tracker.name,
provider_id=provider_id,
provider_name=provider_name,
event_type=event.event_type.value,
collection_id=event.collection_id,
collection_name=event.collection_name,
assets_count=0,
details={
"provider_type": event.provider_type.value,
**extra_details,
},
)
session.add(event_log_row)
await session.flush()
event_log_id = event_log_row.id
# Dedupe defers by parent ``link_id``: broadcast links emit one
# ``link_data`` entry per child, all sharing the same parent id —
# the deferred row is one-per-link, so we only call ``defer_event``
# once per distinct id (earliest fire_at wins on ties).
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
defers_for_event: dict[int, Any] = {}
for ld in link_data:
tc = ld["tracking_config"]
if tc is not None:
outcome = evaluate_event_gate(event, tc, app_tz)
if outcome.reason is GateReason.QUIET_HOURS:
if is_deferrable(event.event_type.value) and outcome.quiet_hours_end_at is not None:
link_id = ld.get("link_id")
if link_id is not None:
prior = defers_for_event.get(link_id)
if prior is None or outcome.quiet_hours_end_at < prior:
defers_for_event[link_id] = outcome.quiet_hours_end_at
continue
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
continue
tmpl = ld["template_config"]
target_cfg = TargetConfig(
type=ld["target_type"],
config=ld["target_config"],
template_slots=ld["template_slots"],
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
provider_api_key=provider_config.get("api_token"),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("url", ""),
receivers=ld["receivers"],
)
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Persist defers + stamp event_log dispatch_status in the same
# session that holds the EventLog row, so the "deferred" badge
# only appears if the underlying queue rows actually exist.
if defers_for_event:
earliest = min(defers_for_event.values())
for link_id, fire_at in defers_for_event.items():
await defer_event(
session,
event=event,
user_id=tracker.user_id,
tracker_id=tracker.id,
link_id=link_id,
event_log_id=event_log_id,
fire_at=fire_at,
)
details = dict(event_log_row.details or {})
if not details.get("dispatch_status"):
details["dispatch_status"] = "deferred"
details["deferred_until"] = earliest.isoformat()
event_log_row.details = details
session.add(event_log_row)
defers_to_schedule.update(defers_for_event.values())
# Dispatch to targets. Isolate dispatcher exceptions per group so
# a failed remote call doesn't bubble out, abort the surrounding
# transaction, and roll back the just-written defers/event_log.
from ..services.http_session import get_http_session
dispatcher = NotificationDispatcher(session=await get_http_session())
for tc, target_configs in groups.values():
if not target_configs:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
continue
try:
results = await dispatcher.dispatch(shaped_event, target_configs)
except Exception as err: # noqa: BLE001
_LOGGER.exception(
"Dispatcher raised for tracker %d: %s", tracker.id, err,
)
continue
for r in results:
if r.get("success"):
dispatched += 1
else:
_LOGGER.error(
"Notification failed for tracker %d: %s",
tracker.id, r.get("error", "unknown"),
)
await session.commit()
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
# can't roll back the persisted defer rows.
if defers_to_schedule:
from ..services.scheduler import schedule_deferred_drain
for fire_at in defers_to_schedule:
try:
schedule_deferred_drain(fire_at)
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to schedule deferred drain for %s", fire_at,
)
return dispatched
"""Webhook-flavoured dispatch — thin wrapper over ``dispatch_provider_event``."""
return await dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=provider_config,
event=event,
detail_keys=detail_keys,
filter_fn=_passes_filters,
)
# ---------------------------------------------------------------------------
@@ -34,12 +34,14 @@ def _auto_register() -> None:
from .planka_handler import PlankaCommandHandler
from .nut_handler import NutCommandHandler
from .webhook_handler import WebhookCommandHandler
from .home_assistant_handler import HomeAssistantCommandHandler
register_handler(ImmichCommandHandler())
register_handler(GiteaCommandHandler())
register_handler(PlankaCommandHandler())
register_handler(NutCommandHandler())
register_handler(WebhookCommandHandler())
register_handler(HomeAssistantCommandHandler())
# Auto-register on import
@@ -0,0 +1,375 @@
"""Home Assistant bot command handler.
Phase 2 of the HA integration. Each command opens a fresh WebSocket
connection to HA — same approach used by ``HomeAssistantServiceProvider.
list_collections`` — so the handler does not need to coordinate with the
long-lived subscription supervisor.
Commands:
* ``/status`` — connection health, subscribed area / entity counts.
* ``/entities [glob]`` — list matching entities with their current state.
* ``/state <entity_id>`` — full state + attributes for one entity.
* ``/areas`` — area registry summary with entity counts per area.
"""
from __future__ import annotations
import logging
from fnmatch import fnmatchcase
from typing import Any
import aiohttp
from notify_bridge_core.providers.home_assistant import (
HomeAssistantApiError,
HomeAssistantAuthError,
HomeAssistantWSClient,
redact_ha_message,
)
from ..database.models import (
CommandConfig,
CommandTracker,
CommandTrackerListener,
ServiceProvider,
TelegramBot,
)
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_HA_COMMANDS = {"status", "entities", "state", "areas"}
# HA exposes credentials and tokens through state attributes for some
# integrations — most notably ``camera.*`` entities surface a working
# ``access_token`` for the camera proxy URL, and ``entity_picture`` can
# carry signed URLs. Filtering keys by substring blocklist before rendering
# protects the chat user from seeing those values in /state output.
#
# Match is case-insensitive substring; tokens are intentionally generic so
# custom integrations that follow the obvious naming conventions are also
# covered. Anything not matched still renders.
_SENSITIVE_ATTR_TOKENS: tuple[str, ...] = (
"access_token",
"token",
"secret",
"password",
"passwd",
"api_key",
"apikey",
"private_key",
"session_id",
"authorization",
"bearer",
"cookie",
# ``entity_picture`` is a URL that often embeds a signed token in its
# query string (HA generates these for camera and media_player entities).
# The key itself doesn't match the credential token blocklist, so it
# gets its own explicit entry.
"entity_picture",
)
# Attributes already rendered as top-level fields by the state template; no
# point repeating them in the "Attributes" iteration.
_TOP_LEVEL_ATTRS: frozenset[str] = frozenset({
"friendly_name", "unit_of_measurement", "device_class",
})
# Hard cap on the number of attributes shown in /state to prevent message
# truncation when an entity has dozens (e.g. weather hourly forecasts,
# light supported features). After the cap, an "and N more" line is added
# by the template logic.
_MAX_ATTRIBUTES_RENDERED = 30
def _is_sensitive_attr(key: str) -> bool:
lowered = str(key).lower()
return any(tok in lowered for tok in _SENSITIVE_ATTR_TOKENS)
def _filter_attributes(attrs: dict[str, Any]) -> tuple[dict[str, Any], int]:
"""Drop sensitive keys, cap count, return ``(visible_attrs, hidden_count)``.
Hidden count covers both the security filter (blocklisted keys) and the
size cap (entries beyond ``_MAX_ATTRIBUTES_RENDERED``). The template can
surface "and N more hidden" so users know the view is incomplete.
"""
if not isinstance(attrs, dict):
return {}, 0
safe: dict[str, Any] = {}
redacted = 0
for key, value in attrs.items():
if not isinstance(key, str):
continue
if key in _TOP_LEVEL_ATTRS:
continue
if _is_sensitive_attr(key):
redacted += 1
continue
safe[key] = value
overflow = max(0, len(safe) - _MAX_ATTRIBUTES_RENDERED)
if overflow > 0:
# Stable order — sort by key so the truncation point is deterministic.
capped = dict(sorted(safe.items())[:_MAX_ATTRIBUTES_RENDERED])
return capped, redacted + overflow
return safe, redacted
def _make_ws_client(provider: ServiceProvider, session: aiohttp.ClientSession) -> HomeAssistantWSClient:
"""Build a one-shot WS client from the provider row."""
config = provider.config or {}
return HomeAssistantWSClient(
session=session,
base_url=config.get("url", ""),
access_token=config.get("access_token", ""),
verify_tls=bool(config.get("verify_tls", True)),
)
def _domain_of(entity_id: str) -> str:
return entity_id.split(".", 1)[0] if "." in entity_id else ""
def _normalize_state(state_row: dict[str, Any]) -> dict[str, Any]:
"""Flatten an HA state dict into the shape templates consume.
``attributes`` is filtered through ``_filter_attributes`` to drop
credential-like keys (e.g. ``camera.access_token``) and cap the rendered
count. ``hidden_attr_count`` is exposed so the template can surface
"and N more hidden" if the user wants to see everything they need to
use a different tool (or the HA UI itself).
"""
entity_id = state_row.get("entity_id") or ""
raw_attrs = state_row.get("attributes") or {}
visible_attrs, hidden_count = _filter_attributes(raw_attrs)
return {
"entity_id": entity_id,
"friendly_name": raw_attrs.get("friendly_name") or entity_id,
"domain": _domain_of(entity_id),
"state": state_row.get("state"),
"attributes": visible_attrs,
"hidden_attr_count": hidden_count,
"device_class": raw_attrs.get("device_class"),
"unit_of_measurement": raw_attrs.get("unit_of_measurement"),
"last_changed": state_row.get("last_changed"),
"last_updated": state_row.get("last_updated"),
}
# ---------------------------------------------------------------------------
# Command implementations
# ---------------------------------------------------------------------------
async def _cmd_status(provider: ServiceProvider) -> dict[str, Any]:
"""``/status`` — connection health + counts.
Health is derived from a live connection rather than the supervisor's
in-memory state so the user sees what's happening *right now* if they
just edited the token / URL. Connection + entity-count + area-count run
on a single WS session so a healthy /status costs one TCP + TLS + WS +
auth handshake instead of three.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
ok = True
message = "OK"
entity_count = 0
area_count = 0
try:
async with client.session() as sess:
# Reaching here proves connect + auth succeeded.
try:
entity_count = len(await sess.get_states())
except HomeAssistantApiError as err:
_LOGGER.debug("HA /status get_states failed: %s", err)
try:
area_count = len(await sess.get_area_registry())
except HomeAssistantApiError as err:
_LOGGER.debug("HA /status get_area_registry failed: %s", err)
except HomeAssistantAuthError as err:
ok = False
message = f"Auth failed: {redact_ha_message(str(err))}"
except (aiohttp.ClientError, HomeAssistantApiError) as err:
ok = False
message = redact_ha_message(str(err)) or "Connection failed"
return {
"ok": ok,
"message": message,
"provider_name": provider.name or "",
"url": (provider.config or {}).get("url", ""),
"entity_count": entity_count,
"area_count": area_count,
}
async def _cmd_entities(provider: ServiceProvider, args: str, count: int) -> dict[str, Any]:
"""``/entities [glob]`` — entities filtered by glob, capped at ``count``.
Empty args returns the first ``count`` entities in entity_id order. A
glob pattern is matched against the entity_id (case-insensitive). The
normalization step (which walks the attribute dict to redact secrets)
runs **only on the survivors** — sorting and slicing happen on raw
state rows first, so an HA install with 1000+ entities doesn't
materialize 1000 normalized dicts just to discard most of them.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
states = await sess.get_states()
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /entities failed: %s", redacted)
return {"entities": [], "glob": args.strip(), "total": 0, "error": redacted}
glob = args.strip()
if glob:
lower_glob = glob.lower()
raw_matches = [
s for s in states
if isinstance(s.get("entity_id"), str)
and fnmatchcase(s["entity_id"].lower(), lower_glob)
]
else:
raw_matches = [s for s in states if isinstance(s.get("entity_id"), str)]
total = len(raw_matches)
raw_matches.sort(key=lambda s: s.get("entity_id", ""))
return {
"entities": [_normalize_state(s) for s in raw_matches[:count]],
"glob": glob,
"total": total,
"shown": min(count, total),
}
async def _cmd_state(provider: ServiceProvider, args: str) -> dict[str, Any]:
"""``/state <entity_id>`` — single-entity drill-down.
Returns ``found=False`` when the entity_id is missing or not present.
Templates render the no-results fallback in that case. Uses the session
context manager for consistency with the other commands even though
there's only one underlying WS call today — leaves the door open for
Phase 3 (service calls) to chain a follow-up on the same socket.
"""
target = args.strip()
if not target:
return {"found": False, "entity_id": "", "reason": "missing_arg"}
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
states = await sess.get_states()
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /state failed: %s", redacted)
return {"found": False, "entity_id": target, "reason": "api_error", "error": redacted}
for s in states:
if s.get("entity_id") == target:
normalized = _normalize_state(s)
return {"found": True, **normalized}
return {"found": False, "entity_id": target, "reason": "not_found"}
async def _cmd_areas(provider: ServiceProvider) -> dict[str, Any]:
"""``/areas`` — area registry with per-area entity counts.
Areas without entities are still listed so users can see which areas
exist in HA but haven't been assigned anything. The entity counts come
from the entity registry, not the state list — the registry includes
disabled entities, which matches what users see in the HA UI. Both
registry calls share a single WS session so /areas costs one handshake.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
areas = await sess.get_area_registry()
# Entity registry failure is non-fatal — areas can still be
# listed without per-area counts.
try:
entities = await sess.get_entity_registry()
except HomeAssistantApiError:
entities = []
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /areas failed: %s", redacted)
return {"areas": [], "total": 0, "error": redacted}
counts: dict[str, int] = {}
for ent in entities:
area_id = ent.get("area_id")
if isinstance(area_id, str):
counts[area_id] = counts.get(area_id, 0) + 1
rows: list[dict[str, Any]] = []
for a in areas:
area_id = a.get("area_id")
if not isinstance(area_id, str):
continue
rows.append({
"area_id": area_id,
"name": a.get("name") or area_id,
"entity_count": counts.get(area_id, 0),
})
rows.sort(key=lambda r: r.get("name", "").lower())
return {"areas": rows, "total": len(rows)}
# ---------------------------------------------------------------------------
# Handler class
# ---------------------------------------------------------------------------
class HomeAssistantCommandHandler(ProviderCommandHandler):
"""Routes ``/status``, ``/entities``, ``/state``, ``/areas`` to the WS client."""
provider_type = "home_assistant"
def get_provider_commands(self) -> set[str]:
return _HA_COMMANDS
def get_rate_categories(self) -> dict[str, str]:
# All HA commands hit the WS API and share an "api" rate-limit bucket.
return {cmd: "api" for cmd in _HA_COMMANDS}
async def handle(
self,
cmd: str,
args: str,
count: int,
locale: str,
response_mode: str, # noqa: ARG002 — HA has no media commands; always text
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, # noqa: ARG002
tracker: CommandTracker, # noqa: ARG002
config: CommandConfig, # noqa: ARG002
*,
listener: CommandTrackerListener | None = None, # noqa: ARG002
allowed_album_ids: set[str] | None = None, # noqa: ARG002 — HA has no album scope
page: int = 1, # noqa: ARG002 — no pagination in v1
) -> CommandResponse | None:
if cmd == "status":
ctx = await _cmd_status(provider)
elif cmd == "entities":
ctx = await _cmd_entities(provider, args, count)
elif cmd == "state":
ctx = await _cmd_state(provider, args)
elif cmd == "areas":
ctx = await _cmd_areas(provider)
else:
return None
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
@@ -11,6 +11,8 @@ _RATE_CATEGORY: dict[str, str] = {
"repos": "api", "issues": "api", "prs": "api", "commits": "api",
# Planka (API calls share a category)
"boards": "api", "cards": "api", "lists": "api",
# Home Assistant (WebSocket queries share a category)
"entities": "api", "state": "api", "areas": "api",
}
@@ -292,6 +292,23 @@ async def migrate_schema(engine: AsyncEngine) -> None:
)
logger.info("Added track_webhook_received column to tracking_config table")
# Add Home Assistant tracking flags to tracking_config if missing.
# state_changed defaults ON to match the canonical "watch the state bus"
# use case; the other three are loud and opt-in (defaults 0).
if await _has_table(conn, "tracking_config"):
ha_flags = [
("track_ha_state_changed", "INTEGER DEFAULT 1"),
("track_ha_automation_triggered", "INTEGER DEFAULT 0"),
("track_ha_service_called", "INTEGER DEFAULT 0"),
("track_ha_event_fired", "INTEGER DEFAULT 0"),
]
for col_name, col_type in ha_flags:
if not await _has_column(conn, "tracking_config", col_name):
await conn.execute(
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} {col_type}")
)
logger.info("Added %s column to tracking_config table", col_name)
# Add quiet hours to tracking_config if missing.
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
if await _has_table(conn, "tracking_config"):
@@ -165,6 +165,12 @@ class TrackingConfig(SQLModel, table=True):
# Generic Webhook event tracking
track_webhook_received: bool = Field(default=True)
# Home Assistant event tracking
track_ha_state_changed: bool = Field(default=True)
track_ha_automation_triggered: bool = Field(default=False)
track_ha_service_called: bool = Field(default=False)
track_ha_event_fired: bool = Field(default=False)
# Immich asset display
track_images: bool = Field(default=True)
track_videos: bool = Field(default=True)
@@ -158,6 +158,7 @@ async def _seed_default_templates() -> None:
await _seed_provider_template(session, "nut", "NUT")
await _seed_provider_template(session, "google_photos", "Google Photos")
await _seed_provider_template(session, "webhook", "Generic Webhook")
await _seed_provider_template(session, "home_assistant", "Home Assistant")
await session.commit()
@@ -187,6 +188,10 @@ async def _seed_default_command_templates() -> None:
await _seed_provider_command_template(
session, "webhook", "Default Webhook Commands", "Default Generic Webhook command templates",
)
await _seed_provider_command_template(
session, "home_assistant", "Default Home Assistant Commands",
"Default Home Assistant command templates",
)
await session.commit()
@@ -272,6 +277,14 @@ async def _seed_default_tracking_configs() -> None:
"track_ups_replace_battery": True,
"track_ups_overload": True,
},
{
"provider_type": "home_assistant",
"name": "Default Home Assistant",
"track_ha_state_changed": True,
"track_ha_automation_triggered": False,
"track_ha_service_called": False,
"track_ha_event_fired": False,
},
]
for cfg in defaults:
@@ -139,12 +139,20 @@ async def lifespan(app: FastAPI):
set_webhook_secret(_secret or None)
from .services.scheduler import start_scheduler, get_scheduler
await start_scheduler()
# Phase 1 of the Home Assistant provider: subscription-based ingest runs
# outside the polling scheduler. ``start_all`` spawns one supervisor task
# per enabled HA provider row. No-op when no HA providers are configured.
from .services.ha_subscription import start_all as start_ha_subscriptions
await start_ha_subscriptions()
_READY = True
yield
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish
# before we close their HTTP session. Then close the shared session and
# dispose the DB engine.
# Graceful shutdown — cancel HA supervisors FIRST so they release their
# WS connections before the shared HTTP session is closed. Then stop the
# polling scheduler. Order matters: scheduler.shutdown(wait=True) drains
# in-flight jobs that may also use the shared session.
_READY = False
from .services.ha_subscription import stop_all as stop_ha_subscriptions
await stop_ha_subscriptions()
scheduler = get_scheduler()
if scheduler.running:
scheduler.shutdown(wait=True)
@@ -115,6 +115,16 @@ def _make_collection_provider(
return make_planka_provider(http_session, provider)
if ptype == "google_photos":
return make_google_photos_provider(http_session, provider)
if ptype == "home_assistant":
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
return HomeAssistantServiceProvider(
session=http_session,
url=config.get("url", ""),
access_token=config.get("access_token", ""),
verify_tls=bool(config.get("verify_tls", True)),
event_types=config.get("event_types") or None,
name=provider.name,
)
# NUT provider needs no http_session
if ptype == "nut":
return make_nut_provider(provider) # type: ignore[return-value]
@@ -122,7 +132,7 @@ def _make_collection_provider(
# Set of provider types that need an aiohttp session for collection listing.
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"}
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos", "home_assistant"}
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
@@ -204,6 +204,12 @@ def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
"ups_comms_restored": tc.track_ups_comms_restored,
"ups_replace_battery": tc.track_ups_replace_battery,
"ups_overload": tc.track_ups_overload,
# Home Assistant events — use getattr so legacy DB rows / test mocks
# that pre-date the columns still pass the gate (default to tracked).
"ha_state_changed": getattr(tc, "track_ha_state_changed", True),
"ha_automation_triggered": getattr(tc, "track_ha_automation_triggered", False),
"ha_service_called": getattr(tc, "track_ha_service_called", False),
"ha_event_fired": getattr(tc, "track_ha_event_fired", False),
}
return flag_map.get(event_type, True)
@@ -0,0 +1,239 @@
"""Shared dispatch helper for push-style providers.
Push-style providers (webhook receivers in ``api/webhooks.py`` and the
Home Assistant subscription manager in ``services/ha_subscription.py``)
share the same downstream pipeline: write an :class:`EventLog`, evaluate
quiet hours / event-type gates, defer if needed, otherwise hand off to the
:class:`NotificationDispatcher`.
This module extracts that pipeline so both callers can reuse it without
either side importing from the other (which would create a server/api ->
services -> api cycle).
"""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.notifications.dispatcher import (
NotificationDispatcher,
TargetConfig,
)
from ..database.models import EventLog, NotificationTracker
from .deferred_dispatch import defer_event, is_deferrable
from .dispatch_helpers import (
GateReason,
apply_tracking_display_filters,
evaluate_event_gate,
get_app_timezone,
load_link_data,
)
_LOGGER = logging.getLogger(__name__)
# Filter signature: ``(event, tracker.filters dict) -> bool``. Returning False
# drops the event for that tracker before any DB writes happen. Callers pass
# provider-specific logic (Gitea sender allowlist, HA entity glob, etc.).
FilterFn = Callable[[ServiceEvent, dict[str, Any]], bool]
async def dispatch_provider_event(
engine: Any,
provider_id: int,
provider_name: str,
provider_config: dict[str, Any],
event: ServiceEvent,
detail_keys: tuple[str, ...],
filter_fn: FilterFn,
) -> int:
"""Load matching trackers, log, gate, defer, and dispatch one event.
Parameters
----------
engine:
SQLAlchemy async engine.
provider_id:
ID of the :class:`ServiceProvider` the event came from.
provider_name:
Human-readable name (for logging only).
provider_config:
``ServiceProvider.config`` dict; flowed into :class:`TargetConfig`.
event:
Parsed :class:`ServiceEvent` to dispatch.
detail_keys:
Keys from ``event.extra`` to copy into ``EventLog.details``.
filter_fn:
Per-event tracker-level filter. Returning False drops the event for
that tracker before any DB writes.
Returns
-------
int
Number of successfully dispatched notifications across all trackers.
"""
dispatched = 0
# Drain-scheduling is best-effort: a scheduling failure must not roll
# back the persisted defer rows (startup catch-up re-establishes them).
defers_to_schedule: set[Any] = set()
async with AsyncSession(engine) as session:
# App timezone is identical across trackers in one inbound event;
# pull it once.
app_tz = await get_app_timezone(session)
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = tracker_result.all()
for tracker in trackers:
filters = tracker.filters or {}
if not filter_fn(event, filters):
_LOGGER.debug(
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
)
continue
link_data = await load_link_data(session, tracker.id)
if not link_data:
continue
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
event_log_row = EventLog(
user_id=tracker.user_id,
tracker_id=tracker.id,
tracker_name=tracker.name,
provider_id=provider_id,
provider_name=provider_name,
event_type=event.event_type.value,
collection_id=event.collection_id,
collection_name=event.collection_name,
assets_count=0,
details={
"provider_type": event.provider_type.value,
**extra_details,
},
)
session.add(event_log_row)
await session.flush()
event_log_id = event_log_row.id
# Dedupe defers by parent link_id: broadcast links emit one
# link_data entry per child, sharing the same parent id — the
# deferred row is one-per-link, so we call defer_event only
# once per distinct id (earliest fire_at wins on ties).
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
defers_for_event: dict[int, Any] = {}
for ld in link_data:
tc = ld["tracking_config"]
if tc is not None:
outcome = evaluate_event_gate(event, tc, app_tz)
if outcome.reason is GateReason.QUIET_HOURS:
if (
is_deferrable(event.event_type.value)
and outcome.quiet_hours_end_at is not None
):
link_id = ld.get("link_id")
if link_id is not None:
prior = defers_for_event.get(link_id)
if prior is None or outcome.quiet_hours_end_at < prior:
defers_for_event[link_id] = outcome.quiet_hours_end_at
continue
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
continue
tmpl = ld["template_config"]
target_cfg = TargetConfig(
type=ld["target_type"],
config=ld["target_config"],
template_slots=ld["template_slots"],
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=(
tmpl.date_only_format
if tmpl and tmpl.date_only_format
else "%d.%m.%Y"
),
provider_api_key=provider_config.get("api_token"),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("url", ""),
receivers=ld["receivers"],
)
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Persist defers + stamp event_log dispatch_status in the same
# session that holds the EventLog row, so the "deferred" badge
# only appears if the underlying queue rows actually exist.
if defers_for_event:
earliest = min(defers_for_event.values())
for link_id, fire_at in defers_for_event.items():
await defer_event(
session,
event=event,
user_id=tracker.user_id,
tracker_id=tracker.id,
link_id=link_id,
event_log_id=event_log_id,
fire_at=fire_at,
)
details = dict(event_log_row.details or {})
if not details.get("dispatch_status"):
details["dispatch_status"] = "deferred"
details["deferred_until"] = earliest.isoformat()
event_log_row.details = details
session.add(event_log_row)
defers_to_schedule.update(defers_for_event.values())
# Dispatch to targets. Isolate dispatcher exceptions per group so
# a failed remote call doesn't bubble out, abort the surrounding
# transaction, and roll back the just-written defers / event_log.
from .http_session import get_http_session
dispatcher = NotificationDispatcher(session=await get_http_session())
for tc, target_configs in groups.values():
if not target_configs:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
continue
try:
results = await dispatcher.dispatch(shaped_event, target_configs)
except Exception as err: # noqa: BLE001
_LOGGER.exception(
"Dispatcher raised for tracker %d: %s", tracker.id, err,
)
continue
for r in results:
if r.get("success"):
dispatched += 1
else:
_LOGGER.error(
"Notification failed for tracker %d: %s",
tracker.id, r.get("error", "unknown"),
)
await session.commit()
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
# can't roll back the persisted defer rows.
if defers_to_schedule:
from .scheduler import schedule_deferred_drain
for fire_at in defers_to_schedule:
try:
schedule_deferred_drain(fire_at)
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to schedule deferred drain for %s", fire_at,
)
return dispatched
@@ -0,0 +1,293 @@
"""Home Assistant subscription manager.
Phase 1 of the HA provider lives here. For every enabled ``home_assistant``
:class:`ServiceProvider` row in the DB, this module spawns one long-running
asyncio task that:
1. Builds an :class:`HomeAssistantServiceProvider` from the provider row.
2. Calls ``provider.subscribe(emit)`` which loops forever — connect,
authenticate, subscribe, drain events through ``emit`` — and reconnects
with exponential backoff on any drop.
3. Each ``emit`` call hands the parsed :class:`ServiceEvent` to
:func:`dispatch_provider_event` (the shared dispatch helper that webhook
providers also use), so quiet hours, deferred dispatch, and event-log
writes all behave identically to the rest of the system.
Lifecycle is owned by ``main.py`` via :func:`start_all` and :func:`stop_all`.
Phase 1 does not reconcile against DB changes after boot — adding,
modifying, or removing a HA provider requires a server restart. Phase 1.5
will add a CRUD-triggered :func:`reload_provider` hook.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.home_assistant import (
HomeAssistantAuthError,
HomeAssistantServiceProvider,
)
from ..database.engine import get_engine
from ..database.models import ServiceProvider
from .event_dispatch import dispatch_provider_event
from .http_session import get_http_session
_LOGGER = logging.getLogger(__name__)
# Per-provider running task. Keyed by provider_id so reload_provider() can
# find and replace a single task without disturbing the rest.
_running_tasks: dict[int, asyncio.Task[None]] = {}
# Keys from ``event.extra`` to copy into ``EventLog.details``. Anything not
# in this list is still available to templates via the merged extras, but
# the event-log row stays slim.
_HA_DETAIL_KEYS: tuple[str, ...] = (
"entity_id",
"friendly_name",
"domain",
"old_state",
"new_state",
"device_class",
"unit_of_measurement",
"area",
"ha_event_type",
"automation_name",
"service_called",
"target_entity",
)
def _ha_passes_filters(event: ServiceEvent, filters: dict[str, Any]) -> bool:
"""HA-specific tracker filter.
Three filter keys, all optional, evaluated as a union: if the entity
matches any one of them, the event passes. Empty filters mean "accept
everything" — different from the Gitea filter which is an intersection.
Filter shape:
* ``collections`` — list of exact ``entity_id`` matches.
* ``entity_glob`` — list of glob patterns (``light.*``, ``*_motion``).
* ``domain_allowlist`` — list of HA domain prefixes (``light``).
"""
collections = filters.get("collections") or []
entity_globs = filters.get("entity_glob") or []
domain_allowlist = filters.get("domain_allowlist") or []
# No filters configured = accept everything.
if not collections and not entity_globs and not domain_allowlist:
return True
entity_id = event.collection_id
domain = event.extra.get("domain") or (
entity_id.split(".", 1)[0] if "." in entity_id else ""
)
if collections and entity_id in collections:
return True
if domain_allowlist and domain in domain_allowlist:
return True
if entity_globs:
from fnmatch import fnmatchcase
for pattern in entity_globs:
if isinstance(pattern, str) and fnmatchcase(entity_id, pattern):
return True
return False
async def _run_provider(provider_id: int) -> None:
"""One per-provider supervisor loop.
Reloads the provider row each iteration so config changes (URL, token,
event types) take effect on the next reconnect cycle — no need for a
full restart in the simple case where only credentials changed.
The ``_emit`` closure is rebuilt every iteration. Its lifetime equals
one ``subscribe()`` call: the callback only runs while the HA client's
drain task is alive. ``provider_name`` is snapshotted at the start of
each (re)connect cycle, so renames take effect on the next reconnect —
chatty enough for v1; revisit if longer-lived WS sessions need fresher
names mid-stream.
"""
assert provider_id is not None, "_run_provider requires a real provider id"
engine = get_engine()
while True:
try:
async with AsyncSession(engine) as session:
row = await session.get(ServiceProvider, provider_id)
if row is None or row.type != "home_assistant":
_LOGGER.info(
"HA provider %s removed or retyped, stopping supervisor",
provider_id,
)
return
config = dict(row.config or {})
provider_name = row.name
url = config.get("url", "")
access_token = config.get("access_token", "")
verify_tls = bool(config.get("verify_tls", True))
event_types = config.get("event_types") or None
if not url or not access_token:
_LOGGER.warning(
"HA provider %s missing url or access_token; retrying in 60s",
provider_id,
)
await asyncio.sleep(60)
continue
session_http = await get_http_session()
ha_provider = HomeAssistantServiceProvider(
session=session_http,
url=url,
access_token=access_token,
verify_tls=verify_tls,
event_types=event_types,
name=provider_name,
)
async def _emit(event: ServiceEvent) -> None:
# Shield the DB-writing dispatch from external cancellation
# (shutdown, supervisor restart). The shield ensures that
# once a transaction is mid-flight, it commits or rolls back
# cleanly instead of being torn down with the asyncio task
# at a write boundary. Worst case: shutdown waits up to one
# dispatch latency longer.
#
# Perf note (Phase 2 follow-up): dispatch_provider_event
# opens a fresh AsyncSession per call. For HA's chatty
# state_changed bus this hammers the pool; batch in a
# follow-up.
try:
await asyncio.shield(dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=config,
event=event,
detail_keys=_HA_DETAIL_KEYS,
filter_fn=_ha_passes_filters,
))
except asyncio.CancelledError:
# Shield re-raises CancelledError to the caller; let it
# propagate so the drain task exits cleanly.
raise
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to dispatch HA event for provider %s",
provider_id,
)
_LOGGER.info(
"Starting HA subscription for provider %s (%s)",
provider_id, provider_name,
)
await ha_provider.subscribe(_emit)
except asyncio.CancelledError:
raise
except HomeAssistantAuthError as err:
# Fatal at the provider level — bad token. Sleep long and retry
# so the user has time to fix the token without us hammering HA.
# Error string already redacted by the client before re-raise.
_LOGGER.error(
"HA provider %s auth failed: %s — retrying in 5 minutes",
provider_id, err,
)
await asyncio.sleep(300)
except Exception: # noqa: BLE001
_LOGGER.exception(
"HA supervisor for provider %s crashed; restarting in 30s",
provider_id,
)
await asyncio.sleep(30)
def _make_done_callback(provider_id: int):
"""Return a done-callback that prunes the task from ``_running_tasks``.
Without this, finished supervisors (provider deleted, fatal auth error
after long sleep) leave stale entries in the dict — across many
reload cycles the dict would grow unboundedly. The callback is
registered on every task spawned via ``start_all`` / ``reload_provider``.
"""
def _cb(task: asyncio.Task[None]) -> None:
current = _running_tasks.get(provider_id)
if current is task:
_running_tasks.pop(provider_id, None)
return _cb
async def start_all() -> None:
"""Start a supervisor task for every enabled HA provider."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(ServiceProvider).where(
ServiceProvider.type == "home_assistant",
)
)
providers = result.all()
for prov in providers:
if prov.id in _running_tasks and not _running_tasks[prov.id].done():
continue
task = asyncio.create_task(
_run_provider(prov.id),
name=f"ha-subscription-{prov.id}",
)
task.add_done_callback(_make_done_callback(prov.id))
_running_tasks[prov.id] = task
if providers:
_LOGGER.info(
"Started HA subscription manager: %d provider(s)", len(providers),
)
async def stop_all() -> None:
"""Cancel every HA supervisor task and wait for clean shutdown."""
if not _running_tasks:
return
for task in _running_tasks.values():
task.cancel()
# Wait for all to drain; swallow cancellation errors.
await asyncio.gather(*_running_tasks.values(), return_exceptions=True)
_running_tasks.clear()
_LOGGER.info("Stopped all HA subscription supervisors")
async def reload_provider(provider_id: int) -> None:
"""Stop and restart the supervisor for a single provider id.
Hook for the provider CRUD routes — Phase 1.5 will wire it in. For Phase
1, configure-then-restart-backend is the supported flow.
"""
task = _running_tasks.pop(provider_id, None)
if task is not None:
task.cancel()
try:
await task
except (asyncio.CancelledError, Exception): # noqa: BLE001
pass
engine = get_engine()
async with AsyncSession(engine) as session:
prov = await session.get(ServiceProvider, provider_id)
if prov is None or prov.type != "home_assistant":
return
new_task = asyncio.create_task(
_run_provider(provider_id),
name=f"ha-subscription-{provider_id}",
)
new_task.add_done_callback(_make_done_callback(provider_id))
_running_tasks[provider_id] = new_task
@@ -213,4 +213,25 @@ _SAMPLE_CONTEXT = {
"raw_payload": {"action": "opened", "issue": {"title": "Bug report", "number": 1}, "sender": {"login": "user1"}},
"event_type_raw": "webhook_received",
"source_ip": "192.168.1.100",
# Home Assistant variables (for home_assistant provider templates)
"friendly_name": "Front Door",
"entity_id": "binary_sensor.front_door",
"domain": "binary_sensor",
"old_state": "off",
"new_state": "on",
"attributes": {"friendly_name": "Front Door", "device_class": "door"},
"device_class": "door",
"unit_of_measurement": "",
"area": "Entrance",
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
"automation_name": "Front door notification",
"trigger_source": "state of binary_sensor.front_door",
"service_called": "light.turn_on",
"service_domain": "light",
"service_name": "turn_on",
"service_data": {"entity_id": "light.kitchen"},
"target_entity": "light.kitchen",
"ha_event_type": "state_changed",
"event_data": {"foo": "bar"},
}
@@ -0,0 +1,98 @@
"""Unit tests for HA bot command helpers — Phase 2.
Focus on the security-sensitive bits the reviewer flagged: attribute
filtering, error-message redaction, and the sample-context shape that
flows through Jinja preview rendering.
"""
from __future__ import annotations
from notify_bridge_server.commands.home_assistant_handler import (
_filter_attributes,
_is_sensitive_attr,
_normalize_state,
)
def test_filter_attributes_drops_credential_keys() -> None:
"""HA camera entities expose an ``access_token`` attribute. The handler
MUST NOT surface it to the chat user via /state."""
raw = {
"friendly_name": "Front Camera",
"access_token": "real-camera-proxy-token",
"entity_picture": "/api/camera_proxy/...?token=abc",
"brightness": 200,
}
safe, hidden = _filter_attributes(raw)
assert "access_token" not in safe
# entity_picture contains 'token' substring → blocked.
assert "entity_picture" not in safe
# friendly_name is rendered as a top-level field, not iterated.
assert "friendly_name" not in safe
# brightness is a normal attribute, passes through.
assert safe["brightness"] == 200
assert hidden == 2
def test_filter_attributes_caps_count() -> None:
"""When an entity has dozens of attributes the renderer would overflow
Telegram's 4096-char message limit. Cap at 30 with overflow surfaced."""
raw = {f"attr_{i:03d}": i for i in range(50)}
safe, hidden = _filter_attributes(raw)
assert len(safe) == 30
assert hidden == 20
def test_is_sensitive_attr_case_insensitive() -> None:
"""Match should not depend on key casing — custom integrations are
inconsistent about capitalization."""
assert _is_sensitive_attr("Access_Token") is True
assert _is_sensitive_attr("API_KEY") is True
assert _is_sensitive_attr("password") is True
assert _is_sensitive_attr("brightness") is False
assert _is_sensitive_attr("color_mode") is False
def test_normalize_state_filters_attrs() -> None:
"""End-to-end: feed _normalize_state a malicious state row, verify the
output has redacted attributes + hidden_attr_count surfaced."""
state_row = {
"entity_id": "camera.front_door",
"state": "idle",
"attributes": {
"friendly_name": "Front Door Camera",
"access_token": "leaked",
"brand": "Reolink",
},
"last_changed": "2026-05-13T12:00:00+00:00",
"last_updated": "2026-05-13T12:00:00+00:00",
}
out = _normalize_state(state_row)
assert out["entity_id"] == "camera.front_door"
assert out["friendly_name"] == "Front Door Camera"
assert out["domain"] == "camera"
# Top-level fields preserved.
assert out["state"] == "idle"
# Attributes dict is filtered.
assert "access_token" not in out["attributes"]
assert out["attributes"].get("brand") == "Reolink"
# Hidden count reflects access_token (friendly_name is top-level, not redacted).
assert out["hidden_attr_count"] == 1
def test_normalize_state_handles_missing_attributes() -> None:
"""A state row with no attributes dict should not crash."""
out = _normalize_state({"entity_id": "sensor.x", "state": "1"})
assert out["attributes"] == {}
assert out["hidden_attr_count"] == 0
def test_redact_ha_message_strips_userinfo() -> None:
"""The Phase 1 redact helper is re-exported via the HA package and used
by /entities, /state, /areas before surfacing errors. Make sure the
re-export still works and the contract is what we expect."""
from notify_bridge_core.providers.home_assistant import redact_ha_message
msg = "Cannot connect to https://leak-token@homeassistant.local:8123/api/websocket"
out = redact_ha_message(msg)
assert "leak-token@" not in out
assert "homeassistant.local:8123" in out
@@ -0,0 +1,80 @@
"""Tests for the HA-specific tracker filter (entity_glob, domain_allowlist).
The Gitea filter is an intersection of senders/collections. The HA filter
is intentionally a *union* across the three keys — any match passes — so a
user can mix exact entity ids with glob patterns and domain allowlists
without each one narrowing the others.
"""
from __future__ import annotations
from datetime import datetime, timezone
from notify_bridge_core.models.events import EventType, ServiceEvent
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_server.services.ha_subscription import _ha_passes_filters
def _ha_event(entity_id: str, domain: str | None = None) -> ServiceEvent:
return ServiceEvent(
event_type=EventType.HA_STATE_CHANGED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name="HA",
collection_id=entity_id,
collection_name=entity_id,
timestamp=datetime.now(timezone.utc),
extra={"domain": domain or (entity_id.split(".", 1)[0] if "." in entity_id else "")},
)
def test_empty_filters_accept_everything() -> None:
assert _ha_passes_filters(_ha_event("light.kitchen"), {}) is True
def test_exact_entity_match() -> None:
filters = {"collections": ["light.kitchen", "switch.lamp"]}
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
def test_entity_glob_match() -> None:
filters = {"entity_glob": ["binary_sensor.*_motion", "light.kitchen*"]}
assert _ha_passes_filters(_ha_event("binary_sensor.hallway_motion"), filters) is True
assert _ha_passes_filters(_ha_event("light.kitchen_main"), filters) is True
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
def test_domain_allowlist() -> None:
filters = {"domain_allowlist": ["light", "switch"]}
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
assert _ha_passes_filters(_ha_event("switch.lamp"), filters) is True
assert _ha_passes_filters(_ha_event("sensor.temp"), filters) is False
def test_union_across_keys() -> None:
"""If collections names a specific sensor.* but domain_allowlist names
'light', BOTH should be acceptable — that's the difference from the
Gitea-style intersection filter."""
filters = {
"collections": ["sensor.outdoor_temp"],
"domain_allowlist": ["light"],
}
assert _ha_passes_filters(_ha_event("sensor.outdoor_temp"), filters) is True
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
# Neither matches:
assert _ha_passes_filters(_ha_event("binary_sensor.door"), filters) is False
def test_domain_derived_when_extra_missing() -> None:
"""If the parser didn't populate extra.domain (e.g. malformed event),
the filter must still infer it from the entity_id prefix."""
evt = ServiceEvent(
event_type=EventType.HA_STATE_CHANGED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name="HA",
collection_id="light.kitchen",
collection_name="light.kitchen",
timestamp=datetime.now(timezone.utc),
extra={}, # No 'domain' key.
)
assert _ha_passes_filters(evt, {"domain_allowlist": ["light"]}) is True
@@ -0,0 +1,187 @@
"""Unit tests for the Home Assistant event parser.
These tests don't need a database or HA server — the parser is a pure
function from ``ha_event_dict`` to :class:`ServiceEvent`.
"""
from __future__ import annotations
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.home_assistant.event_parser import parse_event
def _ha_event_envelope(event_type: str, data: dict) -> dict:
return {
"event_type": event_type,
"data": data,
"time_fired": "2026-05-13T12:34:56.789Z",
}
def test_state_changed_basic() -> None:
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "binary_sensor.front_door",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {
"friendly_name": "Front Door",
"device_class": "door",
},
},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_STATE_CHANGED
assert evt.provider_type is ServiceProviderType.HOME_ASSISTANT
assert evt.collection_id == "binary_sensor.front_door"
assert evt.collection_name == "Front Door"
assert evt.extra["old_state"] == "off"
assert evt.extra["new_state"] == "on"
assert evt.extra["domain"] == "binary_sensor"
assert evt.extra["device_class"] == "door"
# Area was not provided in lookup -> None.
assert evt.extra["area"] is None
def test_state_changed_with_area_lookup() -> None:
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "light.kitchen",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {"friendly_name": "Kitchen Light"},
},
},
)
evt = parse_event(
payload,
provider_name="HA",
area_lookup={"light.kitchen": "Kitchen"},
)
assert evt is not None
assert evt.extra["area"] == "Kitchen"
def test_state_changed_entity_removed() -> None:
"""new_state=None means HA removed the entity. Surface as 'removed' so
templates can branch on it; collection_name falls back to old_state."""
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "sensor.dropped",
"old_state": {
"state": "online",
"attributes": {"friendly_name": "Dropped Sensor"},
},
"new_state": None,
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.extra["new_state"] == "removed"
assert evt.collection_name == "Dropped Sensor"
def test_automation_triggered() -> None:
payload = _ha_event_envelope(
"automation_triggered",
{
"name": "Front door notification",
"entity_id": "automation.front_door_notify",
"source": "state of binary_sensor.front_door",
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_AUTOMATION_TRIGGERED
assert evt.collection_name == "Front door notification"
assert evt.collection_id == "automation.front_door_notify"
assert evt.extra["automation_name"] == "Front door notification"
assert evt.extra["trigger_source"] == "state of binary_sensor.front_door"
def test_call_service_with_target() -> None:
payload = _ha_event_envelope(
"call_service",
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.kitchen"},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_SERVICE_CALLED
assert evt.collection_id == "light.turn_on"
assert evt.extra["target_entity"] == "light.kitchen"
assert evt.extra["service_domain"] == "light"
assert evt.extra["service_name"] == "turn_on"
def test_call_service_with_multi_target() -> None:
"""When the call hits multiple entities, the parser comma-joins them
so templates can render ``{{ target_entity }}`` without iterating."""
payload = _ha_event_envelope(
"call_service",
{
"domain": "light",
"service": "turn_off",
"service_data": {
"entity_id": ["light.kitchen", "light.living_room"],
},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.extra["target_entity"] == "light.kitchen, light.living_room"
def test_generic_event_fallback() -> None:
"""Any event_type not in the known set becomes ha_event_fired with the
raw event_type stashed in extras so loud catch-all subscriptions work."""
payload = _ha_event_envelope(
"custom_event_xyz",
{"foo": "bar"},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_EVENT_FIRED
assert evt.extra["ha_event_type"] == "custom_event_xyz"
assert evt.extra["event_data"] == {"foo": "bar"}
def test_malformed_payload_returns_none() -> None:
assert parse_event({}, provider_name="HA") is None
assert parse_event("not a dict", provider_name="HA") is None # type: ignore[arg-type]
# state_changed without entity_id is unrecoverable
bad = _ha_event_envelope("state_changed", {"new_state": None})
assert parse_event(bad, provider_name="HA") is None
# call_service without domain/service is unrecoverable
bad2 = _ha_event_envelope("call_service", {"service": "turn_on"})
assert parse_event(bad2, provider_name="HA") is None
def test_time_fired_iso_with_z_suffix_parses() -> None:
"""HA uses ``Z`` suffix; older Python ``fromisoformat`` rejects it.
The parser must handle both forms or we'd lose the timestamp."""
from datetime import timezone
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "sensor.temp",
"old_state": {"state": "20", "attributes": {}},
"new_state": {"state": "21", "attributes": {}},
},
)
payload["time_fired"] = "2026-05-13T12:34:56.789Z"
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.timestamp.tzinfo is not None
assert evt.timestamp.utcoffset() == timezone.utc.utcoffset(None)
@@ -0,0 +1,193 @@
"""Tests for the HA WS session helper and slice-before-normalize path.
The reviewer flagged two perf-shaped concerns that we've now addressed:
1. ``/status`` and ``/areas`` previously opened 3 and 2 separate WS
connections respectively. With ``HomeAssistantSession`` they share one
socket — these tests pin the contract.
2. ``/entities`` used to normalize every matching entity before slicing to
``count``. For HA installs with 1000+ entities this materialized 1000+
normalized dicts to throw most away. The optimization moves the slice
*before* normalize; this test exercises a 200-entity fixture and
verifies only the ``count`` survivors get normalized.
"""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import patch
import pytest
from notify_bridge_core.providers.home_assistant.client import HomeAssistantSession
from notify_bridge_server.commands import home_assistant_handler as handler
# ---------------------------------------------------------------------------
# Session class — surface contract
# ---------------------------------------------------------------------------
def test_session_class_has_expected_methods() -> None:
"""Anyone consuming ``HomeAssistantSession`` can rely on this surface."""
expected = {"send", "get_states", "get_area_registry", "get_entity_registry"}
actual = {name for name in dir(HomeAssistantSession) if not name.startswith("_")}
assert expected <= actual, f"missing: {expected - actual}"
@pytest.mark.asyncio
async def test_session_get_states_routes_through_send() -> None:
"""``get_states`` is a thin wrapper around ``send`` with the canonical payload."""
sent: list[dict[str, Any]] = []
class _FakeClient:
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
sent.append(payload)
return 1
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
return [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
result = await sess.get_states()
assert sent == [{"type": "get_states"}]
assert result == [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
@pytest.mark.asyncio
async def test_session_methods_use_distinct_payloads() -> None:
"""Each session-scoped method sends the right HA command name."""
sent: list[dict[str, Any]] = []
class _FakeClient:
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
sent.append(payload)
return len(sent)
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
return []
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
await sess.get_states()
await sess.get_area_registry()
await sess.get_entity_registry()
assert [p["type"] for p in sent] == [
"get_states",
"config/area_registry/list",
"config/entity_registry/list",
]
# ---------------------------------------------------------------------------
# slice-before-normalize — perf contract for /entities
# ---------------------------------------------------------------------------
class _FakeAsyncSession:
"""A fake HA session that returns a canned state list."""
def __init__(self, states: list[dict[str, Any]]) -> None:
self._states = states
async def get_states(self) -> list[dict[str, Any]]:
return self._states
class _FakeClient:
"""A fake client whose ``session()`` yields a ``_FakeAsyncSession``."""
def __init__(self, states: list[dict[str, Any]]) -> None:
self._states = states
def session(self): # noqa: D401 — mimics real client signature
states = self._states
class _CM:
async def __aenter__(self_inner):
return _FakeAsyncSession(states)
async def __aexit__(self_inner, *_exc):
return False
return _CM()
def _state_row(entity_id: str, n_attrs: int = 2) -> dict[str, Any]:
return {
"entity_id": entity_id,
"state": "on",
"attributes": {f"attr_{i}": i for i in range(n_attrs)},
}
@pytest.mark.asyncio
async def test_cmd_entities_slices_before_normalizing(monkeypatch: pytest.MonkeyPatch) -> None:
"""200 raw entities, count=10. Normalize must run only 10 times.
We instrument ``_normalize_state`` with a counter to prove the slice
happens before the per-row transform. The total field still reports
all 200 so the user knows the result is truncated.
"""
states = [_state_row(f"light.bulb_{i:03d}") for i in range(200)]
fake_client = _FakeClient(states)
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
calls = {"count": 0}
real_normalize = handler._normalize_state
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
calls["count"] += 1
return real_normalize(row)
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
# ``get_http_session`` opens a real aiohttp session in the bg; bypass
# it since our fake client never uses the session arg.
async def _fake_http_session() -> Any:
return None
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
result = await handler._cmd_entities(provider, args="", count=10)
assert result["total"] == 200
assert result["shown"] == 10
assert len(result["entities"]) == 10
assert calls["count"] == 10, (
f"normalize should run once per survivor; ran {calls['count']} times"
)
@pytest.mark.asyncio
async def test_cmd_entities_glob_filter_still_normalizes_only_survivors(monkeypatch: pytest.MonkeyPatch) -> None:
"""200 raw entities mixed across 2 domains; glob narrows to one.
Normalize count = min(count, matching_total). Demonstrates the
optimization composes with the filter step.
"""
states = [
_state_row(f"light.bulb_{i:03d}") for i in range(100)
] + [
_state_row(f"sensor.temp_{i:03d}") for i in range(100)
]
fake_client = _FakeClient(states)
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
calls = {"count": 0}
real_normalize = handler._normalize_state
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
calls["count"] += 1
return real_normalize(row)
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
async def _fake_http_session() -> Any:
return None
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
result = await handler._cmd_entities(provider, args="light.*", count=5)
assert result["total"] == 100 # all light.* entities counted
assert result["shown"] == 5 # but only 5 normalized
assert calls["count"] == 5
assert all(e["entity_id"].startswith("light.") for e in result["entities"])