feat: Home Assistant provider — WebSocket subscription + bot commands

Adds Home Assistant as a service provider with two coordinated surfaces:

Notifications (subscription):
- Long-lived WebSocket client (aiohttp ws_connect) with auth handshake,
  exponential-backoff reconnect, bounded event queue, and area-registry
  enrichment cached per (re)connect
- ServiceProvider ABC gains an optional `subscribe()` method for push-style
  providers; HomeAssistantServiceProvider uses it via a per-provider
  supervisor task started in the FastAPI lifespan
- 4 event types (state_changed, automation_triggered, call_service,
  event_fired), 4 default Jinja templates (en + ru), HA-specific
  tracker filters (entity_glob, domain_allowlist, exact entity ids)
- Extracted shared dispatch pipeline (api/webhooks.py → services/
  event_dispatch.py) so subscription and webhook ingest share the same
  event_log + deferred-dispatch + quiet-hours code path

Bot commands:
- /status, /entities [glob], /state <entity_id>, /areas
- Multi-command WS session so /status and /areas cost one handshake
- Sensitive-attribute blocklist (camera access_token, entity_picture, etc.)
  and 30-attribute cap to keep /state output safe and within Telegram's
  message size
- Error-message redaction strips URL userinfo before surfacing to chat

Frontend:
- HA descriptor with toggle ConfigField type (new) and tag-input filter
  mode for free-text glob/domain lists (new TagInput component)
- 15 command slots + 4 notification slots wired into the existing
  template-config UI
This commit is contained in:
2026-05-13 14:31:56 +03:00
parent 90f958bdc6
commit 22127e2a59
79 changed files with 4042 additions and 210 deletions
+177
View File
@@ -0,0 +1,177 @@
# Feature Backlog
Curated feature ideas, narrowed from a brainstorming pass on 2026-05-13.
Order is **rough sequencing preference**, not strict priority — adjust as we go.
---
## 1. Quiet Hours — close the gaps in the existing system
**Reality check (verified 2026-05-13).** Quiet hours are already shipped under
the "deferred dispatch" name in v0.8.0. The pipeline lives at
`packages/server/src/notify_bridge_server/services/deferred_dispatch.py` with
helpers in `dispatch_helpers.py` and tests in
`tests/test_deferred_dispatch.py`. What exists:
- Per-tracking-config window: `tracking_config.quiet_hours_enabled`,
`quiet_hours_start`, `quiet_hours_end`.
- Per-link override: `notification_tracker_target.quiet_hours_start`,
`quiet_hours_end`.
- Smart coalescing (asset add + asset remove during a window cancels each
other out, set-union merge for repeated adds).
- Post-window drain via APScheduler one-shot date jobs.
- Wall-clock event types (`scheduled_message`) drop instead of deferring.
- Frontend status surface: `deferred`, `deferred_then_dropped`,
`deferred_then_failed`, with `deferred_until` and `deferred_for_seconds`
fields exposed in the event log.
**What's NOT there (the actual gaps):**
| Gap | Sketch |
| --- | --- |
| **Target-level windows** | Today, hours bind to the *watcher* (tracking config / link). Users naturally think of DND at the *destination* ("don't ping my phone at night, regardless of source"). New column on `notification_target` + dispatcher gate. |
| **Multiple windows per row** | Today is a single HH:MM range. Real schedules want weekday-evening + weekend-all-day. JSON list of windows. |
| **Days-of-week** | Same window every day. Need `days: ["mon", "tue", ...]` filter per window. |
| **Per-window timezone** | Uses the global app TZ. Multi-traveller / multi-target setups want per-window TZ. |
| **Silent mode** | Modes today are defer-or-drop. Telegram `disable_notification=true` ("send but don't ring") is a third useful mode. |
| **Per-receiver windows** | One bot → multiple chats, each potentially with its own DND. Today it's all-or-nothing per target. |
**Recommended cut for v1 of "extend quiet hours":**
- Add target-level quiet hours (new column `notification_target.quiet_hours_json`
= list of `{days, start, end, mode, tz}`).
- Modes: `drop`, `defer`, `silent`. `defer` reuses the existing
deferred-dispatch pipeline (just changes who decides). `silent` maps to
`disable_notification=true` for Telegram; other targets fall through to
normal send (or we treat `silent` as `defer` for non-Telegram targets — TBD).
- Dispatcher precedence: target window wins over link/tracking-config window
when both are configured. Document this explicitly.
- Frontend: new "Quiet hours" fieldset in the target editor (Aurora cassette
style). Reuses Timezone picker; new day-picker chip row.
- Skip days-of-week + multi-window in v1 if scope grows — ship the target-level
cut first, then iterate.
**Open questions.**
- How exactly do target / link / tracking-config windows combine? Proposal:
any window covering "now" wins (drop > defer > silent precedence).
- Should `silent` for non-Telegram targets degrade to normal send or to
defer? Defer is the safer default.
- Does the event log need a new status (`silenced` / `dropped_by_target_qh`)
to make precedence visible?
---
## 2. Immich Smart Actions (expand beyond Auto-Organize)
**What.** Extend the existing Smart Actions pattern (currently:
**Immich Auto-Organize**) with more rule-driven actions against the Immich API.
**Why.** Auto-Organize already proves the descriptor → rule editor → executor
pipeline. Adding actions is mostly authoring new executors + small UI rule
shapes, not new infra.
**Candidates (pick in this order).**
1. **Auto-favorite by person** — when an asset is detected containing person
X (or any of a set), mark it favorite.
2. **Auto-archive by age / album** — assets older than N days in a given
album get archived. Pair with a "dry-run shows count" UX like
Auto-Organize already has.
3. **Duplicate cluster nudge** — periodically run Immich's duplicate API and
send a digest notification with inline buttons ("review", "ignore for 30d").
Depends on inline-button work (see backlog item 4 dependencies).
4. **Share-link rotation** — for an album, regenerate the share link every N
days; notify with the new URL.
5. **Pending-delete review** — push a weekly digest of trash contents before
Immich's auto-purge fires.
**Shape.**
- Reuse the existing **action descriptor** layer
(`packages/core/src/notify_bridge_core/providers/actions.py`,
`action_executor.py`) and the frontend rule editor used by Auto-Organize.
- Each new action = (a) executor in core, (b) rule schema in the descriptor,
(c) frontend descriptor extension for the rule editor fields.
- Persist as `provider_actions` rows (already exists for Auto-Organize) with
a discriminator + JSON config.
**Open questions.**
- Does "auto-favorite by person" need a confirmation queue or run silently?
Default to silent + event_log entry.
- How do we surface "this action moved/changed X assets" in the dashboard?
Probably a per-action stat tile on the provider detail page.
---
## 3. Home Assistant Provider
**Full plan:** [feature-home-assistant.md](feature-home-assistant.md).
**One-line summary.** New WebSocket-based service provider with a 3-phase
ship: subscribe + dispatch (Phase 1), bot commands (Phase 2), HA service
calls as Smart Actions (Phase 3). Chosen over webhook ingest because
Phases 2 + 3 force a long-lived API connection anyway; consolidating on WS
avoids a refactor.
**Status:** planned, not started.
---
## 4. Block-Based Template Builder
**What.** A visual, drag-and-drop builder for notification and command
templates that compiles down to Jinja2. Lives alongside (not instead of) the
current `JinjaEditor`. Author can flip between views.
**Why.** The current Jinja editor is powerful but unforgiving. A block UI
lowers the floor for new users and provides a discovery surface for the
variables documented in `template_configs.py`.
**Shape.**
- Frontend-only feature for v1 — compiles to the same Jinja strings the
backend already accepts.
- Blocks: `Text`, `Variable`, `If`, `For`, `Link`, `Image`, `Icon`, `Caption`,
`Group` (HTML span/group). Each block knows its serialized Jinja
representation.
- Round-trip: variables, simple `{% if %}` / `{% for %}` blocks, and string
literals parse back to blocks; arbitrary Jinja stays in a "Raw" block that
the user can edit as text.
- Variable picker reads `get_template_variables(provider_type, slot)`. This is
the same data already shown in the template-help panel.
- Preview pane unchanged — reuses `services/sample_context.py` server
rendering.
- Toggle in the template editor: **Visual / Code**.
**Open questions.**
- Round-tripping arbitrary Jinja is hard. v1: parseable subset → blocks,
anything else → single Raw block. Show a banner explaining.
- Locale handling: same compiled Jinja, just authored per locale tab.
- Do we want a marketplace of pre-built block compositions? Out of scope for
v1 — bundle import/export is a separate backlog item.
---
## Recommended Sequencing
1. **Quiet Hours per Target** — small, isolated, immediate user value.
2. **Immich Smart Actions** — incremental on existing pattern; ship one
action at a time (start with auto-favorite by person).
3. **Home Assistant Provider** — multi-file, follows new-provider checklist;
biggest user-base expansion.
4. **Block-Based Template Builder** — largest frontend lift; benefits from
the variable-doc work that the other features will exercise.
Dependencies are loose — 1 and 2 are independent of 3 and 4. The block
builder pairs nicely with Home Assistant because HA's rich context surfaces
the value of an easier authoring UX.
---
## Decision log
- **2026-05-13** — Backlog seeded with these four items selected from a
broader brainstorm. Not started.
+284
View File
@@ -0,0 +1,284 @@
# Home Assistant Provider — Implementation Plan
> Status: **planned, not started**. Sequencing: third item on the backlog
> (see [feature-backlog.md](feature-backlog.md)).
> Last updated: 2026-05-13.
## Decision: WebSocket subscription, not webhook
We considered three ingest modes (webhook automation, WebSocket subscription,
hybrid). The WebSocket route is chosen as the architectural foundation because
the medium-term roadmap forces it anyway:
| Phase | Capability | Needs API access? |
| --- | --- | --- |
| 1 | Subscribe to events, emit notifications | Read (event stream) |
| 2 | Bot commands (`/state`, `/entities`, `/areas`) | Read (REST or WS get_states) |
| 3 | Smart Actions (`light.turn_on`, scene activation) | Write (call_service) |
A webhook-only Phase 1 would still need a REST client by Phase 2 and a write
path by Phase 3 — net result is two client implementations + one event
pipeline refactor. WebSocket consolidates all three phases on one connection.
**Tradeoff (be honest):** WebSocket introduces a long-lived-connection pattern
this codebase does not have yet. Reconnect logic, missed-events-on-restart
gap, and a new shape on the `ServiceProvider` ABC are real costs. Phase 1 is
**not** shippable in one short session — plan for a multi-session build.
## Provider abstraction extension
The current `ServiceProvider` ABC
([packages/core/src/notify_bridge_core/providers/base.py](../../packages/core/src/notify_bridge_core/providers/base.py))
is poll-oriented: every provider implements `poll(collection_ids, state) →
(events, new_state)`. Webhook providers (Gitea, Planka, Webhook) satisfy this
by no-op'ing `poll` and shoving events in via `api/webhooks.py` instead.
Home Assistant fits neither cleanly. The plan:
1. Add an **optional** `async subscribe(emit) → None` method on the base ABC.
Default implementation raises `NotImplementedError`. Polling providers do
not override it. The scheduler / lifecycle layer (currently `services/watcher.py`)
gains a "subscription manager" branch that, for any provider whose class
overrides `subscribe`, starts a long-lived task instead of registering
a polling job.
2. `emit` is a callback `(event: ServiceEvent) → None` provided by the
subscription manager — it routes events to the dispatcher exactly like the
webhook handler does today. Keeping the dispatch path unchanged is the
point of this design.
3. Reconnect lives **inside** `subscribe`: the method is expected to be a
`while not cancelled: try connect; on drop, sleep with backoff, retry`
loop. The manager cancels the task on shutdown via the cooperative cancel
token used elsewhere.
This is a small, additive change to one ABC. No existing provider is
modified.
## Phase 1 — Subscribe + Dispatch (MVP)
### Scope
- Long-lived WebSocket connection to HA, authenticated with a long-lived
access token.
- Subscribe to the event bus with optional `event_type` filter (defaults to
`state_changed`).
- Translate HA events into `ServiceEvent` and dispatch via the existing
pipeline. Notifications go out exactly as they do today for any other
provider.
- Filter UI: entity-id glob list, domain allowlist (e.g. `light.*`,
`binary_sensor.*`), event-type allowlist. **Hard-required** to avoid the HA
firehose drowning the bridge.
- Connection test + entity listing via WS `get_states` (no REST client yet —
WS gives us both subscribe and read).
### Out of scope for Phase 1
- Bot commands → Phase 2.
- Service calls → Phase 3.
- Replay of events missed during disconnect (HA does not support this; we
document the gap and surface "reconnected after N seconds" in the event
log).
- Webhook-style ingestion (path-embedded token webhook receiver). If a user
prefers webhooks, we add it later as a second ingestion mode on the same
provider — out of scope for v1.
### Event types (v1)
| HA event | ServiceEvent type | Notification slot |
| --- | --- | --- |
| `state_changed` | `ha_state_changed` | `message_state_changed` |
| `automation_triggered` | `ha_automation_triggered` | `message_automation_triggered` |
| `call_service` | `ha_service_called` | `message_service_called` |
| (custom event types) | `ha_event_fired` | `message_event_fired` |
Default tracking config enables `state_changed` only — the others are loud
and opt-in.
### Context variables exposed to templates
Pulled directly from HA's `state_changed` payload, normalized:
- `entity_id``light.kitchen`
- `friendly_name``attributes.friendly_name` or fallback to `entity_id`
- `domain` — derived from `entity_id` before the dot
- `old_state``from_state.state`
- `new_state``to_state.state`
- `attributes` — dict of new-state attributes (raw)
- `device_class``attributes.device_class` if present
- `area``attributes.area_id` if present (best effort; only set if HA
exposes it via the area registry, which costs a `get_registry` WS call —
see "Open questions")
- `last_changed`, `last_updated` — ISO timestamps
- For non-`state_changed` events: `event_type`, `event_data` (full dict)
### File touch map (Phase 1)
**Core** (`packages/core/src/notify_bridge_core/`)
| Path | Action | Notes |
| --- | --- | --- |
| `providers/base.py` | Modify | Add optional `subscribe(emit)` ABC method (default `NotImplementedError`); add `HOME_ASSISTANT = "home_assistant"` to `ServiceProviderType` |
| `providers/capabilities.py` | Modify | Add `HOME_ASSISTANT_CAPABILITIES` + register |
| `providers/home_assistant/__init__.py` | Create | Export + register template variables |
| `providers/home_assistant/client.py` | Create | WebSocket client (auth, subscribe, get_states, call_service stub) |
| `providers/home_assistant/event_parser.py` | Create | HA event dict → `ServiceEvent` |
| `providers/home_assistant/provider.py` | Create | Class with `connect`, `disconnect`, `subscribe`, `list_collections` (entity list), `get_available_variables`, `get_provider_config_schema`, `test_connection`. `poll` raises NotImplementedError. |
| `templates/defaults/en/home_assistant_*.jinja2` | Create | 4 slot templates |
| `templates/defaults/ru/home_assistant_*.jinja2` | Create | 4 slot templates |
| `templates/defaults/loader.py` | Modify | Add to `PROVIDER_SLOT_FILE_MAP` |
| `templates/command_defaults/loader.py` | Modify | Stub entry — empty commands list for now |
| `templates/context.py` | Modify | HA context builder |
| `templates/validator.py` | Modify | Whitelist HA variable names |
**Server** (`packages/server/src/notify_bridge_server/`)
| Path | Action | Notes |
| --- | --- | --- |
| `services/watcher.py` *(or scheduler / lifecycle module that hosts polling)* | Modify | Add subscription-manager branch — for providers whose class overrides `subscribe`, start/stop long-running task instead of polling |
| `services/scheduler.py` | Verify | Confirm we cancel HA subscription on shutdown (graceful_shutdown_seconds path) |
| `api/template_configs.py` | Modify | `get_template_variables()` entry |
| `api/command_template_configs.py` | Modify | Sample ctx (minimal for Phase 1 — no commands) |
| `services/sample_context.py` | Modify | `_SAMPLE_CONTEXT["home_assistant"]` |
| `database/seeds.py` | Modify | Seed notification templates + default tracking config |
**Frontend** (`frontend/src/`)
| Path | Action | Notes |
| --- | --- | --- |
| `lib/providers/home-assistant.ts` | Create | Descriptor per CLAUDE.md rule 11 |
| `lib/providers/index.ts` | Modify | Register descriptor |
| `lib/locales/en.json` | Modify | `providers.typeHomeAssistant`, `gridDesc.providerHomeAssistant` |
| `lib/locales/ru.json` | Modify | Same |
**Tests**
| Path | Action |
| --- | --- |
| `packages/core/tests/providers/test_home_assistant_parser.py` | Create — HA payload → `ServiceEvent` |
| `packages/core/tests/providers/test_home_assistant_client.py` | Create — WS auth, subscribe, reconnect (use a fake server) |
| `packages/server/tests/test_home_assistant_subscription.py` | Create — subscription manager lifecycle, event flows through dispatcher |
### Frontend descriptor essentials
```text
type: "home_assistant"
defaultName: "Home Assistant"
icon: "home" (consider Lucide icon; HA logo if a custom asset exists)
hasUrl: true // base URL of HA (used to derive WS URL)
configFields:
- url: http(s)://homeassistant.local:8123
- access_token: long-lived access token (required)
- allowed_event_types: comma-separated, defaults to "state_changed"
eventFields: 4 checkboxes (state_changed, automation_triggered,
call_service, event_fired)
extraTrackingFields:
- entity_glob: tag input ("light.*", "binary_sensor.*_motion")
- domain_allowlist: tag input
collectionMeta: { label: "Entities", icon: "..." }
webhookBased: false // we are NOT webhook based
```
WS URL is derived: `wss://{host}/api/websocket` (or `ws://` for plain http
HA). Document this in the UI hint.
### Auth model
- **Long-lived access token** from HA (Profile → Long-Lived Access Tokens).
- Stored encrypted at rest via the same path the other providers use for
secrets (the bridge already has a secret-encryption helper — verify the
exact module name during implementation).
- WS auth handshake: connect → server sends `auth_required` → client sends
`{type: "auth", access_token: "..."}` → server replies `auth_ok` or
`auth_invalid`.
### Risks / open questions (Phase 1)
1. **Reconnect strategy.** Exponential backoff capped at 60s, jittered.
On reconnect, log a `connection_restored_after` event so the UI can
surface the gap. Document that HA does not support event replay.
2. **Area registry.** Pulling `area_id` for entities requires a separate
`config/area_registry/list` WS call. Decision needed: fetch once on
connect and cache, refetch on `area_registry_updated` event, or skip
`area` from the context entirely in v1. Recommendation: fetch on
connect, refetch on `area_registry_updated`, skip if it fails (best-effort).
3. **TLS verification for self-signed HA.** Homelab users often have
self-signed certs. Need a `verify_tls: bool` config field (default true)
and a clear warning when disabled. Same pattern as
`NOTIFY_BRIDGE_ALLOW_PRIVATE_URLS` for the SSRF case.
4. **Backpressure.** HA's `state_changed` can fire hundreds of events per
minute in a busy install. The subscription manager must drop or coalesce
if the dispatcher backlog grows beyond a threshold. Cheapest cut: a
bounded `asyncio.Queue` between WS receiver and dispatch — `put_nowait`
with overflow counter visible in the event log.
5. **Entity filter precedence.** Tracking-config has `collection_ids`
(entity_id list) and we want `entity_glob` + `domain_allowlist`. Decision:
if both `collection_ids` and globs are set, union them (any match passes).
Documented prominently in the tracker UI.
6. **Library choice.** `hass-client` is a Python WS client maintained by the
HA community; alternative is rolling our own with `websockets`. The
latter is ~150 LOC and has no external dependency surface. Recommendation:
roll our own. Re-evaluate if Phase 3 needs registry-aware service calls.
## Phase 2 — Bot Commands
Adds Telegram bot commands for HA tracking configs.
- `/status` — connection status, subscribed event count
- `/entities <glob>` — list matching entities + current state
- `/state <entity_id>` — full state + attributes for one entity
- `/areas` — area registry summary
- `/help`
These use the existing WS connection (no new client) via WS commands
`get_states`, `config/area_registry/list`. Template slots and command
template configs follow the same pattern as Gitea/Planka — see
[CLAUDE.md](../../CLAUDE.md) rule 7 / rule 11 for the full set of locations
that must be updated.
Out-of-scope for Phase 2: any command that mutates HA state.
## Phase 3 — Smart Actions (Service Calls)
A new action descriptor in the existing Smart Actions framework
([packages/core/src/notify_bridge_core/providers/actions.py](../../packages/core/src/notify_bridge_core/providers/actions.py)).
- Action type: `ha_call_service`
- Rule: trigger event → service call (e.g. "on motion event in
`binary_sensor.front_door` → call `light.turn_on` on `light.porch`")
- Executor uses the existing WS connection to send `call_service`.
This phase is gated behind explicit per-target authorization in the UI — HA
service calls can do anything the access token allows, including unlocking
doors. Default state: **disabled**, with a clear consent flow when enabling.
## Rough effort estimates
These are rough — sub-task discovery during Phase 1 will refine them.
| Phase | Estimate (focused work) |
| --- | --- |
| Phase 1 (subscribe + dispatch) | 23 sessions |
| Phase 2 (bot commands) | 1 session |
| Phase 3 (smart actions) | 12 sessions |
## When to start
Phase 1 work order, once you green-light it:
1. ABC extension (`base.py`) + tests for the new `subscribe` shape on a fake
provider.
2. WS client + parser + unit tests against recorded HA fixtures (no live HA
needed for these).
3. Subscription manager in `services/watcher.py` — integration test with the
fake provider from step 1.
4. Templates (en + ru), capabilities entry, validator whitelist.
5. Server: seeds, sample context, template_configs entry.
6. Frontend: descriptor, locale keys, i18n.
7. End-to-end smoke test against a real HA instance (homelab).
Backend restart cadence per the project rule: after **every** change in
`packages/server/` or `packages/core/`.
## Decision log
- **2026-05-13** — Plan drafted. Ingest mode = WebSocket (chosen over
webhook for future-proofing toward Phases 2 + 3). Not started.
+49
View File
@@ -505,3 +505,52 @@ button:focus-visible, a:focus-visible {
scroll-behavior: auto !important; scroll-behavior: auto !important;
} }
} }
/* Shared toggle switch — used by provider config forms, tracking-config
extraTrackingFields, and anywhere else we render a boolean field.
Kept global so adding a new ConfigField type='toggle' caller doesn't
need to copy the CSS into its scoped <style>. */
.toggle-switch {
position: relative;
display: inline-flex;
align-items: center;
cursor: pointer;
height: 1.75rem;
}
.toggle-switch input {
position: absolute;
opacity: 0;
width: 0;
height: 0;
}
.toggle-switch .toggle-track {
position: relative;
width: 2.5rem;
height: 1.375rem;
background: var(--color-border);
border-radius: 9999px;
transition: background 0.2s ease;
}
.toggle-switch .toggle-track::after {
content: '';
position: absolute;
top: 0.1875rem;
left: 0.1875rem;
width: 1rem;
height: 1rem;
background: var(--color-foreground);
border-radius: 50%;
transition: transform 0.2s ease;
}
.toggle-switch input:checked + .toggle-track {
background: var(--color-primary);
}
.toggle-switch input:checked + .toggle-track::after {
transform: translateX(1.125rem);
background: var(--color-primary-foreground);
}
+154
View File
@@ -0,0 +1,154 @@
<script lang="ts">
/**
* Free-text chip input. Bind a string[] of values; commit a new chip on
* Enter, comma, or blur. Backspace on empty input deletes the last chip
* for parity with native chip-input UX.
*
* Used by ProviderDescriptor.userFilters with inputMode === 'tags' for
* free-text filter keys like Home Assistant's entity_glob and
* domain_allowlist. Distinct from MultiEntitySelect, which renders a
* picker dropdown sourced from an enumerable list.
*/
import MdiIcon from './MdiIcon.svelte';
interface Props {
values: string[];
onchange: (values: string[]) => void;
placeholder?: string;
icon?: string;
/** Strip / reject anything matching this regex on each entry. */
sanitize?: (raw: string) => string | null;
}
let { values, onchange, placeholder = '', icon, sanitize }: Props = $props();
let draft = $state('');
function addRaw(raw: string): void {
const trimmed = raw.trim();
if (!trimmed) return;
const cleaned = sanitize ? sanitize(trimmed) : trimmed;
if (!cleaned) return;
if (values.includes(cleaned)) return;
onchange([...values, cleaned]);
}
function commitDraft(): void {
if (!draft.trim()) return;
// Allow comma-separated paste — split on commas and add each.
for (const piece of draft.split(',')) {
addRaw(piece);
}
draft = '';
}
function removeAt(index: number): void {
onchange(values.filter((_, i) => i !== index));
}
function onKey(e: KeyboardEvent): void {
if (e.key === 'Enter' || e.key === ',') {
e.preventDefault();
commitDraft();
} else if (e.key === 'Backspace' && draft === '' && values.length > 0) {
e.preventDefault();
removeAt(values.length - 1);
}
}
</script>
<div class="tag-input">
{#each values as value, i (`${i}-${value}`)}
<span class="tag-chip">
{#if icon}<MdiIcon name={icon} size={12} />{/if}
<span class="tag-text">{value}</span>
<button
type="button"
aria-label="Remove"
class="tag-remove"
onclick={() => removeAt(i)}
>×</button>
</span>
{/each}
<input
type="text"
bind:value={draft}
onkeydown={onKey}
onblur={commitDraft}
placeholder={values.length === 0 ? placeholder : ''}
class="tag-draft"
autocomplete="off"
spellcheck="false"
/>
</div>
<style>
.tag-input {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 0.375rem;
padding: 0.375rem 0.5rem;
min-height: 2.5rem;
border: 1px solid var(--color-border);
border-radius: 0.375rem;
background: var(--color-background);
cursor: text;
}
.tag-input:focus-within {
border-color: var(--color-primary);
}
.tag-chip {
display: inline-flex;
align-items: center;
gap: 0.25rem;
padding: 0.1875rem 0.5rem;
background: var(--color-muted);
border-radius: 9999px;
font-size: 0.75rem;
line-height: 1;
color: var(--color-foreground);
}
.tag-text {
font-family: var(--font-mono, monospace);
}
.tag-remove {
display: inline-flex;
align-items: center;
justify-content: center;
width: 1rem;
height: 1rem;
padding: 0;
font-size: 0.875rem;
line-height: 1;
background: transparent;
border: none;
color: var(--color-muted-foreground);
cursor: pointer;
border-radius: 9999px;
}
.tag-remove:hover {
background: var(--color-border);
color: var(--color-foreground);
}
.tag-draft {
flex: 1;
min-width: 8rem;
border: none;
outline: none;
background: transparent;
font-size: 0.875rem;
color: var(--color-foreground);
padding: 0.125rem 0;
}
.tag-draft::placeholder {
color: var(--color-muted-foreground);
}
</style>
+21 -1
View File
@@ -235,6 +235,13 @@
"typeNut": "NUT (UPS)", "typeNut": "NUT (UPS)",
"typeGooglePhotos": "Google Photos", "typeGooglePhotos": "Google Photos",
"typeWebhook": "Generic Webhook", "typeWebhook": "Generic Webhook",
"typeHomeAssistant": "Home Assistant",
"haAccessToken": "Long-Lived Access Token",
"haAccessTokenKeep": "Long-Lived Access Token (leave empty to keep current)",
"haAccessTokenHint": "Create one in HA → Profile → Long-Lived Access Tokens. Required for WebSocket subscription.",
"haAccessTokenRequired": "Home Assistant access token is required.",
"haVerifyTls": "Verify TLS certificate",
"haVerifyTlsHint": "Disable only for self-signed HA on a trusted LAN. Keep enabled for any internet-reachable instance.",
"loadError": "Failed to load providers.", "loadError": "Failed to load providers.",
"externalDomain": "External Domain", "externalDomain": "External Domain",
"optional": "optional", "optional": "optional",
@@ -320,6 +327,13 @@
"selectBoards": "Select boards...", "selectBoards": "Select boards...",
"upsDevices": "UPS Devices", "upsDevices": "UPS Devices",
"selectUpsDevices": "Select UPS devices...", "selectUpsDevices": "Select UPS devices...",
"entities": "Entities",
"selectEntities": "Select entities...",
"entities_count": "entity(ies)",
"haEntityGlob": "Entity glob filter",
"haEntityGlobPlaceholder": "light.*, binary_sensor.*_motion",
"haDomainAllowlist": "Domain allowlist",
"haDomainAllowlistPlaceholder": "light, switch, binary_sensor",
"eventTypes": "Event Types", "eventTypes": "Event Types",
"notificationTargets": "Notification Targets", "notificationTargets": "Notification Targets",
"scanInterval": "Scan Interval (seconds)", "scanInterval": "Scan Interval (seconds)",
@@ -644,6 +658,11 @@
"upsOverload": "UPS overloaded", "upsOverload": "UPS overloaded",
"scheduledMessage": "Scheduled message", "scheduledMessage": "Scheduled message",
"webhookReceived": "Webhook received", "webhookReceived": "Webhook received",
"haStateChanged": "Entity state changed",
"haAutomationTriggered": "Automation triggered",
"haServiceCalled": "Service called",
"haEventFired": "Other HA event (catch-all)",
"haEventFiredHint": "Fires for any HA event type not covered by the boxes above. Useful for custom integrations; expect high volume.",
"trackImages": "Track images", "trackImages": "Track images",
"trackVideos": "Track videos", "trackVideos": "Track videos",
"favoritesOnly": "Favorites only", "favoritesOnly": "Favorites only",
@@ -1345,7 +1364,8 @@
"providerScheduler": "Time-based scheduled messages", "providerScheduler": "Time-based scheduled messages",
"providerNut": "Network UPS monitoring", "providerNut": "Network UPS monitoring",
"providerGooglePhotos": "Google Photos albums & shared libraries", "providerGooglePhotos": "Google Photos albums & shared libraries",
"providerWebhook": "Receive events via HTTP POST" "providerWebhook": "Receive events via HTTP POST",
"providerHomeAssistant": "Home Assistant event bus over WebSocket"
}, },
"webhookLogs": { "webhookLogs": {
"title": "Recent Payloads", "title": "Recent Payloads",
+21 -1
View File
@@ -235,6 +235,13 @@
"typeNut": "NUT (ИБП)", "typeNut": "NUT (ИБП)",
"typeGooglePhotos": "Google Фото", "typeGooglePhotos": "Google Фото",
"typeWebhook": "Универсальный вебхук", "typeWebhook": "Универсальный вебхук",
"typeHomeAssistant": "Home Assistant",
"haAccessToken": "Долгоживущий токен доступа",
"haAccessTokenKeep": "Долгоживущий токен (оставьте пустым для сохранения)",
"haAccessTokenHint": "Создайте в HA → Профиль → Long-Lived Access Tokens. Нужен для WebSocket-подписки.",
"haAccessTokenRequired": "Токен доступа Home Assistant обязателен.",
"haVerifyTls": "Проверять TLS-сертификат",
"haVerifyTlsHint": "Отключайте только для самоподписанного HA в доверенной локальной сети. Оставляйте включённым для любого экземпляра, доступного из интернета.",
"loadError": "Не удалось загрузить провайдеры.", "loadError": "Не удалось загрузить провайдеры.",
"externalDomain": "Внешний домен", "externalDomain": "Внешний домен",
"optional": "необязательно", "optional": "необязательно",
@@ -320,6 +327,13 @@
"selectBoards": "Выберите доски...", "selectBoards": "Выберите доски...",
"upsDevices": "ИБП устройства", "upsDevices": "ИБП устройства",
"selectUpsDevices": "Выберите ИБП...", "selectUpsDevices": "Выберите ИБП...",
"entities": "Сущности",
"selectEntities": "Выберите сущности...",
"entities_count": "сущность(ей)",
"haEntityGlob": "Фильтр по entity (glob)",
"haEntityGlobPlaceholder": "light.*, binary_sensor.*_motion",
"haDomainAllowlist": "Разрешённые домены",
"haDomainAllowlistPlaceholder": "light, switch, binary_sensor",
"eventTypes": "Типы событий", "eventTypes": "Типы событий",
"notificationTargets": "Получатели уведомлений", "notificationTargets": "Получатели уведомлений",
"scanInterval": "Интервал проверки (секунды)", "scanInterval": "Интервал проверки (секунды)",
@@ -644,6 +658,11 @@
"upsOverload": "Перегрузка ИБП", "upsOverload": "Перегрузка ИБП",
"scheduledMessage": "Запланированное сообщение", "scheduledMessage": "Запланированное сообщение",
"webhookReceived": "Вебхук получен", "webhookReceived": "Вебхук получен",
"haStateChanged": "Состояние сущности изменилось",
"haAutomationTriggered": "Сработала автоматизация",
"haServiceCalled": "Вызвана служба",
"haEventFired": "Прочее событие HA (catch-all)",
"haEventFiredHint": "Срабатывает на любые типы событий HA, не охваченные чекбоксами выше. Полезно для пользовательских интеграций; ожидайте большой объём.",
"trackImages": "Фото", "trackImages": "Фото",
"trackVideos": "Видео", "trackVideos": "Видео",
"favoritesOnly": "Только избранные", "favoritesOnly": "Только избранные",
@@ -1345,7 +1364,8 @@
"providerScheduler": "Запланированные сообщения по расписанию", "providerScheduler": "Запланированные сообщения по расписанию",
"providerNut": "Мониторинг ИБП через NUT", "providerNut": "Мониторинг ИБП через NUT",
"providerGooglePhotos": "Альбомы и общие библиотеки Google Фото", "providerGooglePhotos": "Альбомы и общие библиотеки Google Фото",
"providerWebhook": "Приём событий через HTTP POST" "providerWebhook": "Приём событий через HTTP POST",
"providerHomeAssistant": "Шина событий Home Assistant по WebSocket"
}, },
"webhookLogs": { "webhookLogs": {
"title": "Последние запросы", "title": "Последние запросы",
@@ -0,0 +1,96 @@
import type { ProviderDescriptor } from './types';
export const homeAssistantDescriptor: ProviderDescriptor = {
type: 'home_assistant',
defaultName: 'Home Assistant',
icon: 'mdiHomeAssistant',
hasUrl: true,
urlPlaceholder: 'http://homeassistant.local:8123',
configFields: [
{
key: 'access_token', configKey: 'access_token',
label: 'providers.haAccessToken', editLabel: 'providers.haAccessTokenKeep',
type: 'password', required: 'create-only', hint: 'providers.haAccessTokenHint',
},
{
key: 'verify_tls', configKey: 'verify_tls',
label: 'providers.haVerifyTls',
type: 'toggle', optional: true, hint: 'providers.haVerifyTlsHint',
defaultValue: true,
},
],
buildConfig(form, editing) {
const config: Record<string, unknown> = { url: form.url };
if (form.access_token) config.access_token = form.access_token;
// Coerce truthy/falsy form values to a real boolean. The toggle
// control binds to `checked`, so this is normally already a bool,
// but legacy form state may carry the string defaults.
config.verify_tls = form.verify_tls === false || form.verify_tls === 'false' ? false : true;
if (!editing && !form.access_token) {
return { config, error: 'providers.haAccessTokenRequired' };
}
return { config };
},
hasConfigChanged(form, existing) {
const existingVerify = existing.verify_tls !== false;
const formVerify = !(form.verify_tls === false || form.verify_tls === 'false');
return (
form.url !== (existing.url || '') ||
!!form.access_token ||
existingVerify !== formVerify
);
},
eventFields: [
{ key: 'track_ha_state_changed', label: 'trackingConfig.haStateChanged', default: true },
{ key: 'track_ha_automation_triggered', label: 'trackingConfig.haAutomationTriggered', default: false },
{ key: 'track_ha_service_called', label: 'trackingConfig.haServiceCalled', default: false },
{
key: 'track_ha_event_fired',
label: 'trackingConfig.haEventFired',
default: false,
hint: 'trackingConfig.haEventFiredHint',
},
],
// entity_glob / domain_allowlist tag-style filters. Stored on the
// tracker's `filters` JSON column (not the flat form root) — the
// TrackerForm reads `inputMode: 'tags'` to render a chip input rather
// than a picker, and `filterKey` routes the value into
// `tracker.filters[filterKey]` at save time.
userFilters: [
{
key: 'entity_glob',
filterKey: 'entity_glob',
inputMode: 'tags',
label: 'notificationTracker.haEntityGlob',
placeholder: 'notificationTracker.haEntityGlobPlaceholder',
icon: 'mdiAsterisk',
},
{
key: 'domain_allowlist',
filterKey: 'domain_allowlist',
inputMode: 'tags',
label: 'notificationTracker.haDomainAllowlist',
placeholder: 'notificationTracker.haDomainAllowlistPlaceholder',
icon: 'mdiTagOutline',
},
],
collectionMeta: {
label: 'notificationTracker.entities',
icon: 'mdiViewList',
placeholder: 'notificationTracker.selectEntities',
countLabel: 'notificationTracker.entities_count',
desc: (col: { state?: string; domain?: string; entity_id?: string; id?: string }) => {
const parts: string[] = [];
if (col.domain) parts.push(col.domain);
if (col.state) parts.push(col.state);
if (parts.length === 0) return col.entity_id || col.id || '';
return parts.join(' · ');
},
},
};
+2
View File
@@ -13,6 +13,7 @@ import { schedulerDescriptor } from './scheduler';
import { nutDescriptor } from './nut'; import { nutDescriptor } from './nut';
import { googlePhotosDescriptor } from './google-photos'; import { googlePhotosDescriptor } from './google-photos';
import { webhookDescriptor } from './webhook'; import { webhookDescriptor } from './webhook';
import { homeAssistantDescriptor } from './home-assistant';
const REGISTRY: ReadonlyMap<string, ProviderDescriptor> = new Map([ const REGISTRY: ReadonlyMap<string, ProviderDescriptor> = new Map([
['immich', immichDescriptor], ['immich', immichDescriptor],
@@ -22,6 +23,7 @@ const REGISTRY: ReadonlyMap<string, ProviderDescriptor> = new Map([
['nut', nutDescriptor], ['nut', nutDescriptor],
['google_photos', googlePhotosDescriptor], ['google_photos', googlePhotosDescriptor],
['webhook', webhookDescriptor], ['webhook', webhookDescriptor],
['home_assistant', homeAssistantDescriptor],
]); ]);
/** Look up a provider descriptor by type. Returns null for unknown types. */ /** Look up a provider descriptor by type. Returns null for unknown types. */
+21 -8
View File
@@ -20,7 +20,7 @@ export interface ConfigField {
configKey?: string; configKey?: string;
/** i18n key for the field label. */ /** i18n key for the field label. */
label: string; label: string;
type: 'text' | 'password' | 'number' | 'grid-select'; type: 'text' | 'password' | 'number' | 'grid-select' | 'toggle';
/** Grid-select item source function name from grid-items.ts. */ /** Grid-select item source function name from grid-items.ts. */
gridItems?: string; gridItems?: string;
gridColumns?: number; gridColumns?: number;
@@ -123,17 +123,30 @@ export interface CollectionMeta {
// ── User-identity filters (TrackerForm) ────────────────────────────── // ── User-identity filters (TrackerForm) ──────────────────────────────
/** /**
* Declares a filter that picks user identities from the provider's known * Declares a filter rendered on the tracker form. Two input modes:
* senders. Rendered as a MultiEntitySelect populated from the provider's *
* `/users` endpoint. The picked values are stored as `string[]` under * * ``picker`` (default) — populated from the provider's ``/users``
* `tracker.filters[key]`. * endpoint, rendered as a ``MultiEntitySelect``. Used for sender
* allowlists / blocklists where the valid values are known.
* * ``tags`` — free-text chip input. Used for glob patterns and other
* filter values that aren't enumerable in advance.
*
* Either way the picked values are stored as ``string[]`` under
* ``tracker.filters[filterKey ?? key]``.
*/ */
export interface UserFilterMeta { export interface UserFilterMeta {
/** Filter key inside `tracker.filters` (e.g. "senders", "exclude_senders"). */ /** Form field key — used internally for binding. */
key: string; key: string;
/** i18n key for the label rendered above the picker. */ /**
* Filter key inside ``tracker.filters``. Defaults to ``key`` when
* omitted (backward compat with the original sender allowlist usage).
*/
filterKey?: string;
/** ``picker`` (default) or ``tags`` for free-text chip input. */
inputMode?: 'picker' | 'tags';
/** i18n key for the label rendered above the input. */
label: string; label: string;
/** i18n key for the picker placeholder. */ /** i18n key for the placeholder (picker dropdown or chip input). */
placeholder: string; placeholder: string;
/** MDI icon shown on chips and dropdown rows. */ /** MDI icon shown on chips and dropdown rows. */
icon: string; icon: string;
@@ -7,6 +7,7 @@
import MdiIcon from '$lib/components/MdiIcon.svelte'; import MdiIcon from '$lib/components/MdiIcon.svelte';
import EntitySelect from '$lib/components/EntitySelect.svelte'; import EntitySelect from '$lib/components/EntitySelect.svelte';
import MultiEntitySelect from '$lib/components/MultiEntitySelect.svelte'; import MultiEntitySelect from '$lib/components/MultiEntitySelect.svelte';
import TagInput from '$lib/components/TagInput.svelte';
import { getDescriptor } from '$lib/providers'; import { getDescriptor } from '$lib/providers';
interface Props { interface Props {
@@ -123,14 +124,24 @@
{#if descriptor?.userFilters && descriptor.userFilters.length > 0} {#if descriptor?.userFilters && descriptor.userFilters.length > 0}
{@const userItems = users.map(u => ({ value: u.id, label: u.name }))} {@const userItems = users.map(u => ({ value: u.id, label: u.name }))}
{#each descriptor.userFilters as uf (uf.key)} {#each descriptor.userFilters as uf (uf.key)}
{@const filterKey = uf.filterKey ?? uf.key}
<div> <div>
<div class="block text-sm font-medium mb-1">{t(uf.label)}</div> <div class="block text-sm font-medium mb-1">{t(uf.label)}</div>
{#if uf.inputMode === 'tags'}
<TagInput
values={form.filters[filterKey] || []}
onchange={(vals) => form.filters = { ...form.filters, [filterKey]: vals }}
placeholder={t(uf.placeholder)}
icon={uf.icon}
/>
{:else}
<MultiEntitySelect <MultiEntitySelect
items={userItems.map(i => ({ ...i, icon: uf.icon }))} items={userItems.map(i => ({ ...i, icon: uf.icon }))}
values={form.filters[uf.key] || []} values={form.filters[filterKey] || []}
onchange={(vals) => form.filters = { ...form.filters, [uf.key]: vals }} onchange={(vals) => form.filters = { ...form.filters, [filterKey]: vals }}
placeholder={t(uf.placeholder)} placeholder={t(uf.placeholder)}
/> />
{/if}
</div> </div>
{/each} {/each}
{/if} {/if}
@@ -321,6 +321,11 @@
<input id="prv-{field.key}" type="number" bind:value={form[field.key]} <input id="prv-{field.key}" type="number" bind:value={form[field.key]}
min={field.min} max={field.max} min={field.min} max={field.max}
class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" /> class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" />
{:else if field.type === 'toggle'}
<label class="toggle-switch">
<input id="prv-{field.key}" type="checkbox" bind:checked={form[field.key]} />
<span class="toggle-track"></span>
</label>
{:else} {:else}
<input id="prv-{field.key}" type={field.type} bind:value={form[field.key]} <input id="prv-{field.key}" type={field.type} bind:value={form[field.key]}
required={field.required === true || (field.required === 'create-only' && !editing)} required={field.required === true || (field.required === 'create-only' && !editing)}
@@ -107,6 +107,11 @@
{:else if field.type === 'number'} {:else if field.type === 'number'}
<input id="prv-{field.key}" type="number" bind:value={form[field.key]} min={field.min} max={field.max} <input id="prv-{field.key}" type="number" bind:value={form[field.key]} min={field.min} max={field.max}
class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" /> class="w-full px-3 py-2 border border-[var(--color-border)] rounded-md text-sm bg-[var(--color-background)]" />
{:else if field.type === 'toggle'}
<label class="toggle-switch">
<input id="prv-{field.key}" type="checkbox" bind:checked={form[field.key]} />
<span class="toggle-track"></span>
</label>
{:else} {:else}
<input id="prv-{field.key}" type={field.type} bind:value={form[field.key]} <input id="prv-{field.key}" type={field.type} bind:value={form[field.key]}
required={field.required === true || field.required === 'create-only'} required={field.required === true || field.required === 'create-only'}
@@ -65,6 +65,12 @@ class EventType(str, Enum):
UPS_REPLACE_BATTERY = "ups_replace_battery" UPS_REPLACE_BATTERY = "ups_replace_battery"
UPS_OVERLOAD = "ups_overload" UPS_OVERLOAD = "ups_overload"
# Home Assistant events
HA_STATE_CHANGED = "ha_state_changed"
HA_AUTOMATION_TRIGGERED = "ha_automation_triggered"
HA_SERVICE_CALLED = "ha_service_called"
HA_EVENT_FIRED = "ha_event_fired"
@dataclass @dataclass
class ServiceEvent: class ServiceEvent:
@@ -4,7 +4,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Awaitable, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
from notify_bridge_core.models.events import ServiceEvent from notify_bridge_core.models.events import ServiceEvent
@@ -21,6 +21,13 @@ class ServiceProviderType(str, Enum):
NUT = "nut" NUT = "nut"
GOOGLE_PHOTOS = "google_photos" GOOGLE_PHOTOS = "google_photos"
WEBHOOK = "webhook" WEBHOOK = "webhook"
HOME_ASSISTANT = "home_assistant"
# Callback signature for push-style providers: a coroutine that accepts a
# parsed ServiceEvent and is expected to enqueue it for dispatch. Returning
# None keeps the contract narrow — error handling stays inside the callback.
EventEmitCallback = Callable[["ServiceEvent"], Awaitable[None]]
class ServiceProvider(ABC): class ServiceProvider(ABC):
@@ -28,10 +35,27 @@ class ServiceProvider(ABC):
A service provider connects to an external service (e.g., Immich photo server) A service provider connects to an external service (e.g., Immich photo server)
and can poll for changes, producing generic ServiceEvent objects. and can poll for changes, producing generic ServiceEvent objects.
Two ingest modes coexist on this base class:
* Polling providers (Immich, NUT, Google Photos, Scheduler) implement
:meth:`poll` and leave :attr:`supports_subscription` False.
* Webhook providers (Gitea, Planka, generic Webhook) no-op :meth:`poll`
and receive events out-of-band via ``api/webhooks.py``.
* Subscription providers (Home Assistant) flip
:attr:`supports_subscription` to True and implement :meth:`subscribe`
to run a long-lived task that pushes events through an
``emit`` callback. They typically no-op :meth:`poll`.
""" """
provider_type: ServiceProviderType provider_type: ServiceProviderType
# When True, the lifecycle layer (server-side subscription manager) starts
# a long-running task that calls :meth:`subscribe` instead of registering
# this provider with the polling scheduler. Default False keeps the
# legacy poll/webhook flow intact for every existing provider.
supports_subscription: bool = False
@abstractmethod @abstractmethod
async def connect(self) -> bool: async def connect(self) -> bool:
"""Connect to the service and verify connectivity. """Connect to the service and verify connectivity.
@@ -59,6 +83,27 @@ class ServiceProvider(ABC):
Tuple of (list of events detected, updated state dict). Tuple of (list of events detected, updated state dict).
""" """
async def subscribe(self, emit: EventEmitCallback) -> None:
"""Run a long-lived subscription that calls ``emit`` for each event.
Override on providers with :attr:`supports_subscription` = True. The
implementation is expected to:
* Loop until cancelled (the subscription manager uses
:func:`asyncio.Task.cancel` on shutdown).
* Handle its own reconnect with exponential backoff — never propagate
transient network errors to the caller.
* Pass parsed :class:`ServiceEvent` instances to ``emit`` for
enqueueing/dispatch. The callback is responsible for routing.
The default implementation raises :class:`NotImplementedError` so
accidental wiring of a polling provider into the subscription manager
fails loudly rather than silently doing nothing.
"""
raise NotImplementedError(
f"{type(self).__name__} does not support subscription-based ingest"
)
@abstractmethod @abstractmethod
def get_available_variables(self) -> list[TemplateVariableDefinition]: def get_available_variables(self) -> list[TemplateVariableDefinition]:
"""Return the template variables this provider makes available.""" """Return the template variables this provider makes available."""
@@ -444,6 +444,76 @@ WEBHOOK_CAPABILITIES = ProviderCapabilities(
], ],
) )
# ---------------------------------------------------------------------------
# Home Assistant provider capabilities
# ---------------------------------------------------------------------------
HOME_ASSISTANT_CAPABILITIES = ProviderCapabilities(
provider_type="home_assistant",
display_name="Home Assistant",
webhook_based=False,
supported_filters=[
{
"key": "collections",
"label": "Entities",
"type": "tags",
"placeholder": "light.kitchen",
},
{
"key": "entity_glob",
"label": "Entity glob",
"type": "tags",
"placeholder": "light.*",
},
{
"key": "domain_allowlist",
"label": "Domains",
"type": "tags",
"placeholder": "light, binary_sensor",
},
],
notification_slots=[
{"name": "message_ha_state_changed", "description": "Entity state changed"},
{"name": "message_ha_automation_triggered", "description": "Automation triggered"},
{"name": "message_ha_service_called", "description": "HA service called"},
{"name": "message_ha_event_fired", "description": "Other HA event fired"},
],
events=[
{"name": "ha_state_changed", "description": "Entity state changed"},
{"name": "ha_automation_triggered", "description": "Automation triggered"},
{"name": "ha_service_called", "description": "HA service called"},
{"name": "ha_event_fired", "description": "Other HA event fired (catch-all)"},
],
command_slots=[
# Response templates
{"name": "start", "description": "/start greeting message"},
{"name": "help", "description": "/help command listing"},
{"name": "status", "description": "/status connection summary"},
{"name": "entities", "description": "/entities matching glob"},
{"name": "state", "description": "/state single-entity drill-down"},
{"name": "areas", "description": "/areas with entity counts"},
{"name": "rate_limited", "description": "Rate limit warning message"},
{"name": "no_results", "description": "Empty results fallback"},
# Description slots
{"name": "desc_help", "description": "Menu description for /help"},
{"name": "desc_status", "description": "Menu description for /status"},
{"name": "desc_entities", "description": "Menu description for /entities"},
{"name": "desc_state", "description": "Menu description for /state"},
{"name": "desc_areas", "description": "Menu description for /areas"},
# Usage examples
{"name": "usage_entities", "description": "Usage example for /entities"},
{"name": "usage_state", "description": "Usage example for /state"},
],
commands=[
{"name": "status", "description": "Show connection status"},
{"name": "entities", "description": "List entities (optional glob)"},
{"name": "state", "description": "Show state for one entity"},
{"name": "areas", "description": "List HA areas with entity counts"},
{"name": "help", "description": "Show commands"},
],
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Registry # Registry
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -456,6 +526,7 @@ _REGISTRY: dict[str, ProviderCapabilities] = {
"nut": NUT_CAPABILITIES, "nut": NUT_CAPABILITIES,
"google_photos": GOOGLE_PHOTOS_CAPABILITIES, "google_photos": GOOGLE_PHOTOS_CAPABILITIES,
"webhook": WEBHOOK_CAPABILITIES, "webhook": WEBHOOK_CAPABILITIES,
"home_assistant": HOME_ASSISTANT_CAPABILITIES,
} }
@@ -0,0 +1,34 @@
"""Home Assistant service provider implementation."""
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.templates.variables import registry
from .client import (
HomeAssistantApiError,
HomeAssistantAuthError,
HomeAssistantWSClient,
_redact as redact_ha_message,
)
from .event_parser import parse_event
from .provider import (
DEFAULT_HA_EVENT_TYPES,
HOME_ASSISTANT_VARIABLES,
HomeAssistantServiceProvider,
)
# Register HA variables in the global registry — same pattern as the other
# providers in this package.
registry.register_provider_variables(
ServiceProviderType.HOME_ASSISTANT, HOME_ASSISTANT_VARIABLES,
)
__all__ = [
"DEFAULT_HA_EVENT_TYPES",
"HOME_ASSISTANT_VARIABLES",
"HomeAssistantApiError",
"HomeAssistantAuthError",
"HomeAssistantServiceProvider",
"HomeAssistantWSClient",
"parse_event",
"redact_ha_message",
]
@@ -0,0 +1,506 @@
"""Home Assistant WebSocket client.
Implements the slice of the HA WebSocket API we need for Phase 1:
* Authenticate with a long-lived access token.
* Subscribe to events (optionally filtered by ``event_type``).
* Fetch the state list (``get_states``) for entity picker UI.
* Fetch the entity and area registries to build an ``entity_id -> area_id``
lookup that the parser uses to enrich ``state_changed`` events with the
area name.
* Run an indefinite subscription loop with exponential backoff reconnect.
The HA protocol reference is at
https://developers.home-assistant.io/docs/api/websocket/ — message ids are
ascending integers, server replies use the same id, and authentication must
complete before any other command is accepted.
"""
from __future__ import annotations
import asyncio
import itertools
import logging
import random
import time
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Callable
from urllib.parse import urlparse, urlunparse
import aiohttp
_LOGGER = logging.getLogger(__name__)
class HomeAssistantAuthError(Exception):
"""Raised when HA rejects our access token. Fatal — no point retrying."""
class HomeAssistantApiError(Exception):
"""Raised when an HA WS command returns ``success: false``."""
# Default reconnect backoff: 2s, 4s, 8s, ..., capped at 60s with jitter.
_RECONNECT_BASE_SECONDS = 2.0
_RECONNECT_MAX_SECONDS = 60.0
_RECONNECT_JITTER_RATIO = 0.2
# Bounded queue between the WS receive loop and the emit consumer. Overflow
# drops the oldest event (FIFO) and logs at WARNING — better to lose one
# state_changed than fall behind the firehose indefinitely.
_EMIT_QUEUE_SIZE = 1000
def _ws_url_from_base(base_url: str) -> str:
"""Derive the HA WebSocket URL from the user-provided HTTP(S) base URL.
``http://homeassistant.local:8123`` -> ``ws://homeassistant.local:8123/api/websocket``.
The user enters their normal HA URL; we transform the scheme + append
the API path. This keeps the UI single-field and avoids confusion about
which URL form to use.
Userinfo (``user:pass@host``) is **stripped** — credentials embedded in
the URL would otherwise flow into log lines and exception strings via
``aiohttp`` error messages. The HA WS protocol uses an access-token
handshake; HTTP basic auth in the URL is never the intended path.
"""
parsed = urlparse(base_url.rstrip("/"))
if parsed.scheme in ("ws", "wss"):
scheme = parsed.scheme
elif parsed.scheme == "https":
scheme = "wss"
else:
scheme = "ws"
# ``netloc`` may contain ``user:pass@host:port``; ``hostname`` + ``port``
# rebuild it without the credential prefix.
host = parsed.hostname or ""
if parsed.port is not None:
netloc = f"{host}:{parsed.port}"
else:
netloc = host
return urlunparse(
(scheme, netloc, "/api/websocket", "", "", "")
)
def _redact(text: str) -> str:
"""Strip embedded credentials from text before logging.
``aiohttp`` exception strings include the URL, so a malformed
``https://token@host`` would otherwise expose the token. This is a
defense-in-depth measure — ``_ws_url_from_base`` already strips
userinfo from the connect URL, but third-party libs may quote the
user-supplied input separately.
"""
if not text:
return text
# Match ``scheme://[user[:pass]@]host`` and drop the userinfo segment.
import re
return re.sub(
r"(?P<scheme>\w+://)(?:[^/@\s]+@)",
r"\g<scheme>",
text,
)
class HomeAssistantWSClient:
"""Single-instance WebSocket client for one HA server."""
def __init__(
self,
session: aiohttp.ClientSession,
base_url: str,
access_token: str,
verify_tls: bool = True,
) -> None:
self._session = session
self._ws_url = _ws_url_from_base(base_url)
self._access_token = access_token
self._verify_tls = verify_tls
self._id_counter = itertools.count(1)
# ------------------------------------------------------------------
# Connection primitives
# ------------------------------------------------------------------
@asynccontextmanager
async def _connect(self) -> AsyncIterator[aiohttp.ClientWebSocketResponse]:
"""Open a fresh WS, complete the auth handshake, and yield the socket.
Raises :class:`HomeAssistantAuthError` on invalid token (fatal) and
:class:`HomeAssistantApiError` on other handshake failures (caller
decides whether to retry).
"""
ws = await self._session.ws_connect(
self._ws_url,
ssl=None if self._verify_tls else False,
heartbeat=30,
autoping=True,
)
try:
await self._authenticate(ws)
yield ws
finally:
await ws.close()
async def _authenticate(self, ws: aiohttp.ClientWebSocketResponse) -> None:
"""Run the HA auth handshake on a freshly-opened socket."""
greeting = await ws.receive_json(timeout=10)
if greeting.get("type") != "auth_required":
raise HomeAssistantApiError(
f"Expected auth_required, got {greeting.get('type')!r}"
)
await ws.send_json({"type": "auth", "access_token": self._access_token})
result = await ws.receive_json(timeout=10)
msg_type = result.get("type")
if msg_type == "auth_ok":
return
if msg_type == "auth_invalid":
raise HomeAssistantAuthError(
result.get("message") or "Home Assistant rejected the access token"
)
raise HomeAssistantApiError(
f"Unexpected auth response: {msg_type!r}"
)
async def _send_command(
self,
ws: aiohttp.ClientWebSocketResponse,
payload: dict[str, Any],
) -> int:
"""Send a command with an auto-assigned id; return that id."""
msg_id = next(self._id_counter)
await ws.send_json({"id": msg_id, **payload})
return msg_id
async def _await_result(
self,
ws: aiohttp.ClientWebSocketResponse,
msg_id: int,
timeout: float = 15.0,
) -> Any:
"""Wait for a ``result`` message matching ``msg_id`` and return its payload.
``time.monotonic`` is the right clock here — wall-clock deadlines
would jump on NTP sync, and ``asyncio.get_event_loop().time()``
is deprecated when called outside a running-loop context.
"""
deadline = time.monotonic() + timeout
while True:
remaining = deadline - time.monotonic()
if remaining <= 0:
raise HomeAssistantApiError(
f"Timed out waiting for result of command id={msg_id}"
)
msg = await ws.receive_json(timeout=remaining)
if msg.get("id") != msg_id:
# Ignore unsolicited events that arrive between sending a
# request-style command and its result.
continue
if msg.get("type") != "result":
continue
if not msg.get("success", False):
err = msg.get("error", {})
raise HomeAssistantApiError(
f"HA command failed: {err.get('code')} {err.get('message')}"
)
return msg.get("result")
# ------------------------------------------------------------------
# Multi-command session
# ------------------------------------------------------------------
@asynccontextmanager
async def session(self) -> AsyncIterator["HomeAssistantSession"]:
"""Open one authenticated WS and let the caller run multiple commands.
Each one-shot method (``get_states``, ``get_area_registry``, ...)
opens a brand-new connection with a full TCP + WS + auth handshake.
For callers that need to chain several queries (e.g. /status: connection
check + entity list + area count) that overhead adds up — 3 separate
TLS handshakes and 3 auth round-trips for what is really one logical
request.
Usage:
async with client.session() as sess:
states = await sess.get_states()
areas = await sess.get_area_registry()
The session shares the same id counter as the client, so message ids
are unique across both one-shot calls and session-scoped calls if
they happen to run concurrently against the same client instance.
"""
async with self._connect() as ws:
yield HomeAssistantSession(self, ws)
# ------------------------------------------------------------------
# One-shot commands
# ------------------------------------------------------------------
async def test_connection(self) -> tuple[bool, str]:
"""Connect, authenticate, and immediately close. Returns ``(ok, message)``."""
try:
async with self._connect() as _ws:
return True, "OK"
except HomeAssistantAuthError as err:
return False, f"Auth failed: {err}"
except (aiohttp.ClientError, asyncio.TimeoutError) as err:
return False, f"Connection failed: {err}"
except HomeAssistantApiError as err:
return False, str(err)
async def get_states(self) -> list[dict[str, Any]]:
"""Fetch the current state of every entity HA knows about."""
async with self._connect() as ws:
msg_id = await self._send_command(ws, {"type": "get_states"})
result = await self._await_result(ws, msg_id)
return list(result or [])
async def get_area_registry(self) -> list[dict[str, Any]]:
"""Fetch the area registry (``area_id`` -> name + metadata)."""
async with self._connect() as ws:
msg_id = await self._send_command(
ws, {"type": "config/area_registry/list"}
)
result = await self._await_result(ws, msg_id)
return list(result or [])
async def get_entity_registry(self) -> list[dict[str, Any]]:
"""Fetch the entity registry (entity_id -> area_id + metadata)."""
async with self._connect() as ws:
msg_id = await self._send_command(
ws, {"type": "config/entity_registry/list"}
)
result = await self._await_result(ws, msg_id)
return list(result or [])
async def get_entity_to_area_lookup(self) -> dict[str, str]:
"""Build ``{entity_id: area_name}`` using the entity + area registries.
Best-effort: returns an empty dict on any failure so the parser still
works without area enrichment.
"""
try:
entities = await self.get_entity_registry()
areas = await self.get_area_registry()
except (HomeAssistantApiError, aiohttp.ClientError, asyncio.TimeoutError) as err:
_LOGGER.warning("Could not fetch HA registry, areas disabled: %s", err)
return {}
area_names = {a.get("area_id"): a.get("name") for a in areas if a.get("area_id")}
lookup: dict[str, str] = {}
for entry in entities:
entity_id = entry.get("entity_id")
area_id = entry.get("area_id")
if not isinstance(entity_id, str) or not area_id:
continue
name = area_names.get(area_id)
if name:
lookup[entity_id] = str(name)
return lookup
# ------------------------------------------------------------------
# Subscription loop with reconnect
# ------------------------------------------------------------------
async def run_subscription(
self,
on_event: Callable[[dict[str, Any]], Awaitable[None]],
event_types: list[str] | None = None,
on_status_change: Callable[[str, str | None], None] | None = None,
refresh_areas: Callable[[], Awaitable[dict[str, str]]] | None = None,
) -> None:
"""Run an indefinite subscription loop, reconnecting on drop.
Parameters
----------
on_event:
Coroutine called with the inner ``event`` dict (the WS envelope is
stripped). Slow callbacks apply TCP backpressure naturally; the
internal queue prevents unbounded memory growth if the callback
stalls.
event_types:
Restrict the subscription to these HA event types. ``None`` or
empty subscribes to everything (very loud — only use for debug).
on_status_change:
Callback invoked with ``("connected", None)`` after a successful
handshake and ``("disconnected", reason)`` when a connection drops.
Useful for surfacing connection state in the event log.
refresh_areas:
Optional coroutine called on each (re)connect to refresh the
area lookup. The result is not used by ``run_subscription``
itself — the caller stores it where its ``on_event`` can read.
"""
attempt = 0
queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=_EMIT_QUEUE_SIZE)
overflow_count = 0
async def _drain() -> None:
while True:
evt = await queue.get()
try:
await on_event(evt)
except Exception: # noqa: BLE001
_LOGGER.exception("on_event callback raised; continuing")
finally:
queue.task_done()
drain_task = asyncio.create_task(_drain(), name="ha-emit-drain")
try:
while True:
try:
async with self._connect() as ws:
attempt = 0
if on_status_change is not None:
on_status_change("connected", None)
if refresh_areas is not None:
try:
# Note: refresh_areas opens its own WS in our
# current design (each one-shot command does).
# Fine for v1 — a few hundred ms once per
# (re)connect.
await refresh_areas()
except Exception: # noqa: BLE001
_LOGGER.exception("Area refresh failed; continuing without")
# Subscribe. Passing per-event-type subscriptions is
# cheaper than subscribing to everything and filtering
# in Python — HA does the filtering.
if event_types:
for evt_type in event_types:
sub_id = await self._send_command(
ws,
{"type": "subscribe_events", "event_type": evt_type},
)
await self._await_result(ws, sub_id)
else:
sub_id = await self._send_command(
ws, {"type": "subscribe_events"}
)
await self._await_result(ws, sub_id)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
payload = msg.json()
if payload.get("type") != "event":
continue
event_obj = payload.get("event")
if not isinstance(event_obj, dict):
continue
try:
queue.put_nowait(event_obj)
except asyncio.QueueFull:
overflow_count += 1
if overflow_count % 50 == 1:
_LOGGER.warning(
"HA event queue full, dropped %d events so far "
"(consumer is slower than HA event rate)",
overflow_count,
)
# Drop oldest, retry put. This keeps the
# most recent state visible at the cost
# of older transient changes.
try:
queue.get_nowait()
queue.task_done()
except asyncio.QueueEmpty:
pass
try:
queue.put_nowait(event_obj)
except asyncio.QueueFull:
pass
elif msg.type in (
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSING,
aiohttp.WSMsgType.ERROR,
):
raise aiohttp.ClientConnectionError(
f"WS closed: {msg.type.name}"
)
else:
# PING/PONG handled by aiohttp autoping=True;
# BINARY/CONTINUATION are not used by HA today.
# Log at debug so a future protocol change is
# visible without spamming production logs.
_LOGGER.debug(
"Ignored WS message of type %s", msg.type.name,
)
except HomeAssistantAuthError as err:
# Fatal — caller must fix the access token. Reraise so
# the provider can mark itself unhealthy.
if on_status_change is not None:
on_status_change("disconnected", _redact(f"auth: {err}"))
raise
except asyncio.CancelledError:
if on_status_change is not None:
on_status_change("disconnected", "cancelled")
raise
except Exception as err: # noqa: BLE001
redacted = _redact(str(err))
if on_status_change is not None:
on_status_change("disconnected", redacted)
delay = min(
_RECONNECT_BASE_SECONDS * (2 ** attempt),
_RECONNECT_MAX_SECONDS,
)
delay *= 1 + random.uniform(-_RECONNECT_JITTER_RATIO, _RECONNECT_JITTER_RATIO)
_LOGGER.warning(
"HA WS connection lost (%s); reconnecting in %.1fs",
redacted, delay,
)
attempt = min(attempt + 1, 10)
await asyncio.sleep(delay)
finally:
drain_task.cancel()
# Drain task may finish via CancelledError (normal) or via an
# unhandled exception thrown by on_event. Either way is fine here
# — we're tearing down. Split the two cases for clarity rather
# than catching `Exception` + `CancelledError` in one clause.
try:
await drain_task
except asyncio.CancelledError:
pass
except Exception: # noqa: BLE001
_LOGGER.exception("HA drain task raised during shutdown")
# ---------------------------------------------------------------------------
# Multi-command session
# ---------------------------------------------------------------------------
class HomeAssistantSession:
"""A multi-command HA WS session bound to a single authenticated socket.
Created via :meth:`HomeAssistantWSClient.session`. Use when you need to
issue several commands in a row — sharing the connection saves the TCP
+ WS + auth round trips for every command after the first.
The session forwards id assignment to the parent client's monotonic
counter so ids stay unique across all sessions sharing the same client.
"""
def __init__(
self,
client: HomeAssistantWSClient,
ws: aiohttp.ClientWebSocketResponse,
) -> None:
self._client = client
self._ws = ws
async def send(self, payload: dict[str, Any], timeout: float = 15.0) -> Any:
"""Send one command and wait for its ``result`` envelope."""
msg_id = await self._client._send_command(self._ws, payload)
return await self._client._await_result(self._ws, msg_id, timeout=timeout)
async def get_states(self) -> list[dict[str, Any]]:
result = await self.send({"type": "get_states"})
return list(result or [])
async def get_area_registry(self) -> list[dict[str, Any]]:
result = await self.send({"type": "config/area_registry/list"})
return list(result or [])
async def get_entity_registry(self) -> list[dict[str, Any]]:
result = await self.send({"type": "config/entity_registry/list"})
return list(result or [])
@@ -0,0 +1,267 @@
"""Home Assistant event parser — HA WebSocket event dict -> ServiceEvent.
The HA event bus delivers events with this envelope:
.. code-block:: json
{
"id": 7,
"type": "event",
"event": {
"event_type": "state_changed",
"data": { ... event-type-specific ... },
"origin": "LOCAL",
"time_fired": "2026-05-13T12:34:56.789Z",
"context": { ... }
}
}
The parser accepts the inner ``event`` dict (the WS client strips the outer
envelope before calling us) and emits a :class:`ServiceEvent` ready for the
existing dispatch path. Areas are looked up via an optional ``area_lookup``
mapping so the parser stays pure — the WS client maintains the registry
cache and passes its current snapshot on each call.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from typing import Any
from notify_bridge_core.models.events import EventType, ServiceEvent
from notify_bridge_core.providers.base import ServiceProviderType
_LOGGER = logging.getLogger(__name__)
# Defensive caps for fields that get persisted to the event_log row. Home
# Assistant's own constraints keep entity ids well under 70 chars, but a
# misbehaving custom integration could emit kilobyte-sized strings that
# would bloat the JSON details column.
_MAX_ENTITY_ID_LEN = 255
_MAX_EVENT_DATA_BYTES = 4096
def _parse_time_fired(raw: Any) -> datetime:
"""Parse HA's ``time_fired`` ISO string, falling back to now() on garbage.
HA always sends UTC with a ``Z`` suffix or explicit ``+00:00``. Datetime
parsing is wrapped because a malformed payload should not break the
pipeline — better to dispatch with a slightly-off timestamp than drop.
"""
if isinstance(raw, str):
try:
# ``datetime.fromisoformat`` accepts ``+00:00`` natively; rewrite
# the trailing ``Z`` since pre-3.11 stdlib rejects it.
cleaned = raw[:-1] + "+00:00" if raw.endswith("Z") else raw
return datetime.fromisoformat(cleaned)
except ValueError:
_LOGGER.debug("Unparseable HA time_fired %r, using now()", raw)
return datetime.now(timezone.utc)
def _domain_of(entity_id: str) -> str:
"""Return the HA domain prefix (``light.kitchen`` -> ``light``)."""
if "." in entity_id:
return entity_id.split(".", 1)[0]
return ""
def _friendly_name(state_obj: dict[str, Any] | None, entity_id: str) -> str:
"""Pull ``friendly_name`` from attributes or fall back to entity_id."""
if not state_obj:
return entity_id
attrs = state_obj.get("attributes") or {}
name = attrs.get("friendly_name")
return str(name) if name else entity_id
def parse_event(
ha_event: dict[str, Any],
provider_name: str,
area_lookup: dict[str, str] | None = None,
) -> ServiceEvent | None:
"""Parse one HA event dict into a :class:`ServiceEvent`.
Returns None for malformed payloads (missing ``event_type`` etc.) so the
caller can drop without raising. Genuine network/parsing exceptions
bubble up — only known-bad payload shapes return None.
"""
if not isinstance(ha_event, dict):
return None
event_type_raw = ha_event.get("event_type")
if not isinstance(event_type_raw, str):
return None
data = ha_event.get("data") or {}
timestamp = _parse_time_fired(ha_event.get("time_fired"))
area_lookup = area_lookup or {}
if event_type_raw == "state_changed":
return _parse_state_changed(data, timestamp, provider_name, area_lookup)
if event_type_raw == "automation_triggered":
return _parse_automation_triggered(data, timestamp, provider_name)
if event_type_raw == "call_service":
return _parse_call_service(data, timestamp, provider_name)
# Everything else maps to the generic "event_fired" slot. Tracking
# configs decide whether to enable this loud catch-all.
return _parse_generic_event(event_type_raw, data, timestamp, provider_name)
def _parse_state_changed(
data: dict[str, Any],
timestamp: datetime,
provider_name: str,
area_lookup: dict[str, str],
) -> ServiceEvent | None:
entity_id = data.get("entity_id")
if not isinstance(entity_id, str):
return None
entity_id = entity_id[:_MAX_ENTITY_ID_LEN]
old_state_obj = data.get("old_state") if isinstance(data.get("old_state"), dict) else None
new_state_obj = data.get("new_state") if isinstance(data.get("new_state"), dict) else None
# ``new_state`` is None when an entity is removed — surface it as a
# transition to the literal string "removed" so templates can branch.
old_state_val = old_state_obj.get("state") if old_state_obj else None
new_state_val = new_state_obj.get("state") if new_state_obj else "removed"
attributes = (new_state_obj or {}).get("attributes") or {}
friendly_name = _friendly_name(new_state_obj or old_state_obj, entity_id)
domain = _domain_of(entity_id)
extra: dict[str, Any] = {
"entity_id": entity_id,
"friendly_name": friendly_name,
"domain": domain,
"old_state": old_state_val,
"new_state": new_state_val,
"attributes": attributes,
"device_class": attributes.get("device_class"),
"unit_of_measurement": attributes.get("unit_of_measurement"),
"area": area_lookup.get(entity_id),
"ha_event_type": "state_changed",
}
if new_state_obj and "last_changed" in new_state_obj:
extra["last_changed"] = new_state_obj["last_changed"]
if new_state_obj and "last_updated" in new_state_obj:
extra["last_updated"] = new_state_obj["last_updated"]
return ServiceEvent(
event_type=EventType.HA_STATE_CHANGED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name=provider_name,
collection_id=entity_id,
collection_name=friendly_name,
timestamp=timestamp,
extra=extra,
)
def _parse_automation_triggered(
data: dict[str, Any],
timestamp: datetime,
provider_name: str,
) -> ServiceEvent | None:
entity_id = data.get("entity_id")
if isinstance(entity_id, str):
entity_id = entity_id[:_MAX_ENTITY_ID_LEN]
automation_name = data.get("name") or (entity_id if isinstance(entity_id, str) else "automation")
source = data.get("source") or ""
collection_id = entity_id if isinstance(entity_id, str) else f"automation.{automation_name}"
collection_id = collection_id[:_MAX_ENTITY_ID_LEN]
return ServiceEvent(
event_type=EventType.HA_AUTOMATION_TRIGGERED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name=provider_name,
collection_id=collection_id,
collection_name=str(automation_name),
timestamp=timestamp,
extra={
"entity_id": entity_id,
"automation_name": str(automation_name),
"trigger_source": str(source),
"ha_event_type": "automation_triggered",
},
)
def _parse_call_service(
data: dict[str, Any],
timestamp: datetime,
provider_name: str,
) -> ServiceEvent | None:
domain = data.get("domain")
service = data.get("service")
if not isinstance(domain, str) or not isinstance(service, str):
return None
domain = domain[:_MAX_ENTITY_ID_LEN]
service = service[:_MAX_ENTITY_ID_LEN]
service_data = data.get("service_data") if isinstance(data.get("service_data"), dict) else {}
qualified = f"{domain}.{service}"
target_entity = None
if isinstance(service_data, dict):
raw_target = service_data.get("entity_id")
if isinstance(raw_target, str):
target_entity = raw_target
elif isinstance(raw_target, list) and raw_target:
target_entity = ", ".join(str(x) for x in raw_target)
return ServiceEvent(
event_type=EventType.HA_SERVICE_CALLED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name=provider_name,
collection_id=qualified,
collection_name=qualified,
timestamp=timestamp,
extra={
"service_domain": domain,
"service_name": service,
"service_called": qualified,
"service_data": service_data,
"target_entity": target_entity,
"ha_event_type": "call_service",
},
)
def _parse_generic_event(
event_type_raw: str,
data: dict[str, Any],
timestamp: datetime,
provider_name: str,
) -> ServiceEvent | None:
event_type_raw = event_type_raw[:_MAX_ENTITY_ID_LEN]
# Cap the serialized payload so a custom HA integration that emits
# a multi-megabyte event_data dict doesn't blow up the event_log JSON
# column. Templates can still reference fields up to the cap; beyond it
# the dict is replaced with a marker so the limit is visible to authors.
capped_data: Any = data
try:
serialized = json.dumps(data, default=str)
except (TypeError, ValueError):
# Unserializable payload — keep the dict in-memory so templates can
# still read scalar fields, but flag the size as 0 to avoid surprises.
serialized = ""
if len(serialized.encode("utf-8")) > _MAX_EVENT_DATA_BYTES:
capped_data = {
"_truncated": True,
"_original_size_bytes": len(serialized.encode("utf-8")),
"_note": f"event_data exceeded {_MAX_EVENT_DATA_BYTES}B and was dropped",
}
return ServiceEvent(
event_type=EventType.HA_EVENT_FIRED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name=provider_name,
collection_id=event_type_raw,
collection_name=event_type_raw,
timestamp=timestamp,
extra={
"ha_event_type": event_type_raw,
"event_data": capped_data,
},
)
@@ -0,0 +1,312 @@
"""Home Assistant service provider — WebSocket subscription based.
Unlike polling providers (Immich, NUT, Google Photos) and webhook providers
(Gitea, Planka), the HA provider maintains a long-lived WebSocket connection
to the HA server and pushes events into the dispatch pipeline as they
arrive. The lifecycle is owned by the server-side subscription manager
(see ``services/ha_subscription.py``).
"""
from __future__ import annotations
import logging
from typing import Any
import aiohttp
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.base import (
EventEmitCallback,
ServiceProvider,
ServiceProviderType,
)
from notify_bridge_core.templates.variables import TemplateVariableDefinition
from .client import HomeAssistantWSClient
from .event_parser import parse_event
_LOGGER = logging.getLogger(__name__)
# Home Assistant template variables exposed to Jinja2.
HOME_ASSISTANT_VARIABLES: list[TemplateVariableDefinition] = [
TemplateVariableDefinition(
name="entity_id",
type="string",
description="HA entity id (e.g. light.kitchen)",
example="light.kitchen",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="friendly_name",
type="string",
description="Human-readable entity name from attributes.friendly_name",
example="Kitchen Light",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="domain",
type="string",
description="HA domain prefix of the entity (e.g. light, sensor, binary_sensor)",
example="light",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="old_state",
type="string",
description="Previous state string before the change",
example="off",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="new_state",
type="string",
description="New state string (literal 'removed' when entity was deleted)",
example="on",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="attributes",
type="dict",
description="Full attributes dict of the new state",
example='{"brightness": 255, "color_mode": "brightness"}',
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="device_class",
type="string",
description="Device class from attributes (motion, door, temperature, ...)",
example="motion",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="unit_of_measurement",
type="string",
description="Unit suffix for numeric sensors",
example="°C",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="area",
type="string",
description="Area name from the HA area registry (empty when not assigned)",
example="Kitchen",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="last_changed",
type="string",
description="ISO timestamp of last state change",
example="2026-05-13T12:34:56.789+00:00",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="last_updated",
type="string",
description="ISO timestamp of last attribute or state update",
example="2026-05-13T12:34:56.789+00:00",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="automation_name",
type="string",
description="Automation name (automation_triggered events)",
example="Front Door Notification",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="trigger_source",
type="string",
description="Why an automation fired (automation_triggered events)",
example="state of binary_sensor.front_door",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="service_called",
type="string",
description="Qualified service name (call_service events)",
example="light.turn_on",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="service_domain",
type="string",
description="Service domain (call_service events)",
example="light",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="service_name",
type="string",
description="Service name within domain (call_service events)",
example="turn_on",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="service_data",
type="dict",
description="Service payload (call_service events)",
example='{"entity_id": "light.kitchen"}',
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="target_entity",
type="string",
description="entity_id targeted by a service call (comma-joined for multi-target)",
example="light.kitchen",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="ha_event_type",
type="string",
description="Raw HA event_type (state_changed, automation_triggered, ...)",
example="state_changed",
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
TemplateVariableDefinition(
name="event_data",
type="dict",
description="Raw event data (generic event_fired events)",
example='{"key": "value"}',
provider_type=ServiceProviderType.HOME_ASSISTANT,
),
]
# Default event types subscribed to when the user does not override. Only
# state_changed is on by default — the others are loud and opt-in via the
# tracking-config event checkboxes.
DEFAULT_HA_EVENT_TYPES: tuple[str, ...] = ("state_changed",)
class HomeAssistantServiceProvider(ServiceProvider):
"""Home Assistant WebSocket subscription provider."""
provider_type = ServiceProviderType.HOME_ASSISTANT
supports_subscription = True
def __init__(
self,
session: aiohttp.ClientSession,
url: str,
access_token: str,
verify_tls: bool = True,
event_types: list[str] | None = None,
name: str = "Home Assistant",
) -> None:
self._client = HomeAssistantWSClient(
session=session,
base_url=url,
access_token=access_token,
verify_tls=verify_tls,
)
self._name = name
self._event_types = list(event_types) if event_types else list(DEFAULT_HA_EVENT_TYPES)
# ``_area_lookup`` is refreshed on every (re)connect by run_subscription's
# ``refresh_areas`` hook so the parser can enrich state_changed events
# with the current area name.
self._area_lookup: dict[str, str] = {}
@property
def client(self) -> HomeAssistantWSClient:
return self._client
async def connect(self) -> bool:
ok, _ = await self._client.test_connection()
return ok
async def disconnect(self) -> None:
# Session lifecycle is managed by the caller; the WS connection is
# owned by run_subscription which exits on cancel.
return None
async def poll(
self,
collection_ids: list[str],
tracker_state: dict[str, Any],
) -> tuple[list[ServiceEvent], dict[str, Any]]:
# Subscription-based ingest. The polling scheduler MUST NOT call us
# — the subscription manager owns this provider's lifecycle instead.
return [], tracker_state
async def subscribe(self, emit: EventEmitCallback) -> None:
async def _on_event(ha_event: dict[str, Any]) -> None:
event = parse_event(
ha_event,
provider_name=self._name,
area_lookup=self._area_lookup,
)
if event is None:
return
await emit(event)
async def _refresh_areas() -> dict[str, str]:
try:
self._area_lookup = await self._client.get_entity_to_area_lookup()
except Exception: # noqa: BLE001
# Best-effort: keep the previous lookup on failure.
_LOGGER.exception("Failed to refresh HA area lookup")
return self._area_lookup
await self._client.run_subscription(
on_event=_on_event,
event_types=self._event_types,
refresh_areas=_refresh_areas,
)
def get_available_variables(self) -> list[TemplateVariableDefinition]:
return list(HOME_ASSISTANT_VARIABLES)
def get_provider_config_schema(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "Home Assistant base URL (http://homeassistant.local:8123)",
"example": "http://homeassistant.local:8123",
},
"access_token": {
"type": "string",
"description": "Long-lived access token (HA Profile -> Long-Lived Access Tokens)",
"secret": True,
},
"verify_tls": {
"type": "boolean",
"description": "Validate TLS certificate. Disable only for self-signed HA setups on trusted networks.",
"default": True,
},
"event_types": {
"type": "array",
"items": {"type": "string"},
"description": "HA event types to subscribe to. Defaults to ['state_changed'].",
"default": list(DEFAULT_HA_EVENT_TYPES),
},
},
"required": ["url", "access_token"],
}
async def list_collections(self) -> list[dict[str, Any]]:
"""Return the current entity list for the entity-picker UI."""
try:
states = await self._client.get_states()
except Exception as err: # noqa: BLE001
_LOGGER.warning("Could not fetch HA states: %s", err)
return []
out: list[dict[str, Any]] = []
for state in states:
entity_id = state.get("entity_id")
if not isinstance(entity_id, str):
continue
attrs = state.get("attributes") or {}
out.append({
"id": entity_id,
"name": attrs.get("friendly_name") or entity_id,
"state": state.get("state"),
"domain": entity_id.split(".", 1)[0] if "." in entity_id else "",
})
return out
async def test_connection(self) -> dict[str, Any]:
ok, message = await self._client.test_connection()
return {"ok": ok, "message": message}
@@ -0,0 +1,9 @@
🗺️ <b>Areas</b>
{%- if areas %}
{%- for a in areas %}
<b>{{ a.name }}</b> — {{ a.entity_count }} entity(ies)
{%- endfor %}
<i>Total: {{ total }}</i>
{%- else %}
No areas configured in Home Assistant.
{%- endif %}
@@ -0,0 +1 @@
List HA areas with entity counts
@@ -0,0 +1 @@
List entities (optional glob)
@@ -0,0 +1 @@
Show full state for one entity
@@ -0,0 +1 @@
Show Home Assistant connection status
@@ -0,0 +1,11 @@
🔍 <b>Entities</b>{% if glob %} matching <code>{{ glob }}</code>{% endif %}
{%- if entities %}
{%- for e in entities %}
<code>{{ e.entity_id }}</code> — <b>{{ e.state }}</b>{% if e.unit_of_measurement %} {{ e.unit_of_measurement }}{% endif %}{% if e.friendly_name and e.friendly_name != e.entity_id %} · <i>{{ e.friendly_name }}</i>{% endif %}
{%- endfor %}
{%- if total > shown %}
<i>Showing {{ shown }} of {{ total }} — refine the glob to narrow further.</i>
{%- endif %}
{%- else %}
No entities matched.
{%- endif %}
@@ -0,0 +1,4 @@
🏠 <b>Home Assistant commands</b>
{%- for cmd in commands %}
/{{ cmd.name }} — {{ cmd.description }}
{%- endfor %}
@@ -0,0 +1 @@
⏳ Too many requests. Please wait a moment and try again.
@@ -0,0 +1,3 @@
🏠 <b>Home Assistant bot</b>
Send /help to see what I can do.
@@ -0,0 +1,27 @@
{%- if found %}
🏠 <b>{{ friendly_name }}</b>
<code>{{ entity_id }}</code>
State: <b>{{ state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- if device_class %}
Class: <i>{{ device_class }}</i>
{%- endif %}
{%- if last_changed %}
Last changed: <i>{{ last_changed }}</i>
{%- endif %}
{%- if attributes %}
<b>Attributes</b>
{%- for key, value in attributes.items() %}
• {{ key }}: <code>{{ (value if value is string else value | tojson) | string | truncate(120) }}</code>
{%- endfor %}
{%- if hidden_attr_count and hidden_attr_count > 0 %}
<i>… and {{ hidden_attr_count }} more attribute(s) hidden (sensitive or truncated for length)</i>
{%- endif %}
{%- endif %}
{%- elif reason == 'missing_arg' %}
Usage: <code>/state &lt;entity_id&gt;</code>
{%- elif reason == 'not_found' %}
Entity <code>{{ entity_id }}</code> not found.
{%- else %}
Could not load state for <code>{{ entity_id }}</code>: {{ error }}
{%- endif %}
@@ -0,0 +1,8 @@
🏠 <b>{{ provider_name }}</b>
{%- if ok %}
<i>Connected</i> · {{ url }}
Entities: <b>{{ entity_count }}</b> · Areas: <b>{{ area_count }}</b>
{%- else %}
<i>Disconnected</i>
<code>{{ message }}</code>
{%- endif %}
@@ -0,0 +1 @@
/entities [glob] e.g. /entities light.*
@@ -0,0 +1 @@
/state &lt;entity_id&gt; e.g. /state light.kitchen
@@ -64,6 +64,15 @@ PROVIDER_COMMAND_SLOTS: dict[str, list[str]] = {
# Usage example slots # Usage example slots
"usage_latest", "usage_search", "usage_random", "usage_latest", "usage_search", "usage_random",
], ],
"home_assistant": [
# Response templates
"start", "help", "status", "entities", "state", "areas",
"rate_limited", "no_results",
# Description slots
"desc_help", "desc_status", "desc_entities", "desc_state", "desc_areas",
# Usage examples
"usage_entities", "usage_state",
],
} }
# Backward-compatible aliases # Backward-compatible aliases
@@ -0,0 +1,9 @@
🗺️ <b>Зоны</b>
{%- if areas %}
{%- for a in areas %}
<b>{{ a.name }}</b> — {{ a.entity_count }} сущность(ей)
{%- endfor %}
<i>Всего: {{ total }}</i>
{%- else %}
В Home Assistant не настроено ни одной зоны.
{%- endif %}
@@ -0,0 +1 @@
Список зон HA с количеством сущностей
@@ -0,0 +1 @@
Список сущностей (можно указать glob)
@@ -0,0 +1 @@
Показать список команд
@@ -0,0 +1 @@
Полное состояние одной сущности
@@ -0,0 +1 @@
Статус подключения к Home Assistant
@@ -0,0 +1,11 @@
🔍 <b>Сущности</b>{% if glob %} по шаблону <code>{{ glob }}</code>{% endif %}
{%- if entities %}
{%- for e in entities %}
<code>{{ e.entity_id }}</code> — <b>{{ e.state }}</b>{% if e.unit_of_measurement %} {{ e.unit_of_measurement }}{% endif %}{% if e.friendly_name and e.friendly_name != e.entity_id %} · <i>{{ e.friendly_name }}</i>{% endif %}
{%- endfor %}
{%- if total > shown %}
<i>Показано {{ shown }} из {{ total }} — уточните шаблон, чтобы сузить.</i>
{%- endif %}
{%- else %}
Совпадений не найдено.
{%- endif %}
@@ -0,0 +1,4 @@
🏠 <b>Команды Home Assistant</b>
{%- for cmd in commands %}
/{{ cmd.name }} — {{ cmd.description }}
{%- endfor %}
@@ -0,0 +1 @@
Нет результатов.
@@ -0,0 +1 @@
⏳ Слишком много запросов. Попробуйте снова чуть позже.
@@ -0,0 +1,3 @@
🏠 <b>Бот Home Assistant</b>
Отправьте /help, чтобы посмотреть, что я умею.
@@ -0,0 +1,27 @@
{%- if found %}
🏠 <b>{{ friendly_name }}</b>
<code>{{ entity_id }}</code>
Состояние: <b>{{ state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- if device_class %}
Класс: <i>{{ device_class }}</i>
{%- endif %}
{%- if last_changed %}
Последнее изменение: <i>{{ last_changed }}</i>
{%- endif %}
{%- if attributes %}
<b>Атрибуты</b>
{%- for key, value in attributes.items() %}
• {{ key }}: <code>{{ (value if value is string else value | tojson) | string | truncate(120) }}</code>
{%- endfor %}
{%- if hidden_attr_count and hidden_attr_count > 0 %}
<i>… и ещё {{ hidden_attr_count }} атрибут(ов) скрыты (содержат секреты или обрезаны по длине)</i>
{%- endif %}
{%- endif %}
{%- elif reason == 'missing_arg' %}
Использование: <code>/state &lt;entity_id&gt;</code>
{%- elif reason == 'not_found' %}
Сущность <code>{{ entity_id }}</code> не найдена.
{%- else %}
Не удалось загрузить состояние <code>{{ entity_id }}</code>: {{ error }}
{%- endif %}
@@ -0,0 +1,8 @@
🏠 <b>{{ provider_name }}</b>
{%- if ok %}
<i>Подключено</i> · {{ url }}
Сущностей: <b>{{ entity_count }}</b> · Зон: <b>{{ area_count }}</b>
{%- else %}
<i>Отключено</i>
<code>{{ message }}</code>
{%- endif %}
@@ -0,0 +1 @@
/entities [glob] например /entities light.*
@@ -0,0 +1 @@
/state &lt;entity_id&gt; например /state light.kitchen
@@ -0,0 +1,7 @@
⚙️ Automation triggered: <b>{{ automation_name }}</b>
{%- if trigger_source %}
<i>Source:</i> {{ trigger_source }}
{%- endif %}
{%- if entity_id %}
<code>{{ entity_id }}</code>
{%- endif %}
@@ -0,0 +1,4 @@
📡 HA event: <b>{{ ha_event_type }}</b>
{%- if event_data %}
<pre>{{ event_data | tojson(indent=2) }}</pre>
{%- endif %}
@@ -0,0 +1,4 @@
🔧 Service called: <b>{{ service_called }}</b>
{%- if target_entity %}
<i>Target:</i> <code>{{ target_entity }}</code>
{%- endif %}
@@ -0,0 +1,11 @@
🏠 <b>{{ friendly_name }}</b>{% if area %} <i>({{ area }})</i>{% endif %}
{%- if old_state %}
{{ old_state }} → <b>{{ new_state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- else %}
<b>{{ new_state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- endif %}
{%- if device_class %}
<i>{{ device_class }}</i> · <code>{{ entity_id }}</code>
{%- else %}
<code>{{ entity_id }}</code>
{%- endif %}
@@ -73,6 +73,12 @@ PROVIDER_SLOT_FILE_MAP: dict[str, dict[str, str]] = {
"message_ups_replace_battery": "nut_ups_replace_battery.jinja2", "message_ups_replace_battery": "nut_ups_replace_battery.jinja2",
"message_ups_overload": "nut_ups_overload.jinja2", "message_ups_overload": "nut_ups_overload.jinja2",
}, },
"home_assistant": {
"message_ha_state_changed": "ha_state_changed.jinja2",
"message_ha_automation_triggered": "ha_automation_triggered.jinja2",
"message_ha_service_called": "ha_service_called.jinja2",
"message_ha_event_fired": "ha_event_fired.jinja2",
},
} }
# Backward-compatible alias # Backward-compatible alias
@@ -0,0 +1,7 @@
⚙️ Автоматизация сработала: <b>{{ automation_name }}</b>
{%- if trigger_source %}
<i>Триггер:</i> {{ trigger_source }}
{%- endif %}
{%- if entity_id %}
<code>{{ entity_id }}</code>
{%- endif %}
@@ -0,0 +1,4 @@
📡 Событие HA: <b>{{ ha_event_type }}</b>
{%- if event_data %}
<pre>{{ event_data | tojson(indent=2) }}</pre>
{%- endif %}
@@ -0,0 +1,4 @@
🔧 Вызвана служба: <b>{{ service_called }}</b>
{%- if target_entity %}
<i>Цель:</i> <code>{{ target_entity }}</code>
{%- endif %}
@@ -0,0 +1,11 @@
🏠 <b>{{ friendly_name }}</b>{% if area %} <i>({{ area }})</i>{% endif %}
{%- if old_state %}
{{ old_state }} → <b>{{ new_state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- else %}
<b>{{ new_state }}</b>{% if unit_of_measurement %} {{ unit_of_measurement }}{% endif %}
{%- endif %}
{%- if device_class %}
<i>{{ device_class }}</i> · <code>{{ entity_id }}</code>
{%- else %}
<code>{{ entity_id }}</code>
{%- endif %}
@@ -565,6 +565,63 @@ async def preview_raw(
"count": 2, "count": 2,
# /rate_limited # /rate_limited
"wait": 15, "wait": 15,
# --- Home Assistant: /status, /entities, /state, /areas ---
"ok": True,
"message": "OK",
"provider_name": "Home Assistant",
"url": "http://homeassistant.local:8123",
"entity_count": 142,
"area_count": 8,
"entities": [
{
"entity_id": "binary_sensor.front_door",
"friendly_name": "Front Door",
"domain": "binary_sensor",
"state": "off",
"attributes": {"device_class": "door", "friendly_name": "Front Door"},
"device_class": "door",
"unit_of_measurement": None,
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
},
{
"entity_id": "sensor.kitchen_temperature",
"friendly_name": "Kitchen Temperature",
"domain": "sensor",
"state": "21.4",
"attributes": {"unit_of_measurement": "°C", "friendly_name": "Kitchen Temperature"},
"device_class": "temperature",
"unit_of_measurement": "°C",
"last_changed": "2026-05-13T12:30:00+00:00",
"last_updated": "2026-05-13T12:30:00+00:00",
},
],
"glob": "binary_sensor.*",
"total": 12,
"shown": 2,
# /state — single entity drill-down. ``found`` controls which branch
# of the template renders.
"found": True,
"entity_id": "light.kitchen",
"friendly_name": "Kitchen Light",
"domain": "light",
"state": "on",
"attributes": {
"brightness": 200,
"color_mode": "brightness",
},
"hidden_attr_count": 0,
"device_class": None,
"unit_of_measurement": None,
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
"reason": "",
"error": "",
# /areas
"areas": [
{"area_id": "kitchen", "name": "Kitchen", "entity_count": 14},
{"area_id": "entrance", "name": "Entrance", "entity_count": 4},
],
} }
return render_template_preview(body.template, sample_ctx) return render_template_preview(body.template, sample_ctx)
@@ -3,7 +3,7 @@
import logging import logging
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ValidationError from pydantic import AnyHttpUrl, BaseModel, ValidationError, field_validator
from sqlmodel import select from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Any from typing import Any
@@ -103,6 +103,54 @@ class WebhookProviderConfig(BaseModel):
max_stored_payloads: int = 20 # 1-100 max_stored_payloads: int = 20 # 1-100
class HomeAssistantProviderConfig(BaseModel):
url: str
access_token: str
verify_tls: bool = True
event_types: list[str] | None = None
@field_validator("url")
@classmethod
def _validate_url(cls, raw: str) -> str:
"""Reject malformed URLs early so the user sees a clear error.
``AnyHttpUrl`` accepts the homelab-friendly forms
(``http://homeassistant.local:8123``) while rejecting garbage like
``not-a-url`` or ``ftp://...``. Validation is best-effort; we still
re-derive the WebSocket URL at runtime.
"""
try:
AnyHttpUrl(raw)
except ValueError as err:
raise ValueError(f"url must be a valid http(s) URL: {err}") from err
return raw
@field_validator("event_types")
@classmethod
def _validate_event_types(cls, raw: list[str] | None) -> list[str] | None:
"""Cap list size and per-entry length; reject obvious junk.
We don't whitelist event names — HA has unbounded custom event types
from third-party integrations. Length and count caps are enough to
keep a misconfiguration from blowing up the subscription handshake.
"""
if raw is None:
return None
if len(raw) > 50:
raise ValueError("event_types accepts at most 50 entries")
cleaned: list[str] = []
for entry in raw:
if not isinstance(entry, str):
raise ValueError("event_types entries must be strings")
entry = entry.strip()
if not entry:
continue
if len(entry) > 100:
raise ValueError("event_types entries must be <=100 chars")
cleaned.append(entry)
return cleaned or None
_PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = { _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"immich": ImmichProviderConfig, "immich": ImmichProviderConfig,
"gitea": GiteaProviderConfig, "gitea": GiteaProviderConfig,
@@ -111,6 +159,7 @@ _PROVIDER_CONFIG_MODELS: dict[str, type[BaseModel]] = {
"nut": NutProviderConfig, "nut": NutProviderConfig,
"google_photos": GooglePhotosProviderConfig, "google_photos": GooglePhotosProviderConfig,
"webhook": WebhookProviderConfig, "webhook": WebhookProviderConfig,
"home_assistant": HomeAssistantProviderConfig,
} }
@@ -160,6 +209,18 @@ async def _test_provider_connection(provider: ServiceProvider) -> dict[str, Any]
gp = make_google_photos_provider(http_session, provider) gp = make_google_photos_provider(http_session, provider)
return await gp.test_connection() return await gp.test_connection()
if provider.type == "home_assistant":
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
ha = HomeAssistantServiceProvider(
session=http_session,
url=provider.config.get("url", ""),
access_token=provider.config.get("access_token", ""),
verify_tls=bool(provider.config.get("verify_tls", True)),
event_types=provider.config.get("event_types") or None,
name=provider.name,
)
return await ha.test_connection()
if provider.type in ("scheduler", "webhook"): if provider.type in ("scheduler", "webhook"):
return {"ok": True, "message": "Virtual provider — always available"} return {"ok": True, "message": "Virtual provider — always available"}
@@ -285,6 +285,8 @@ async def get_template_variables(
**_planka_variables(), **_planka_variables(),
# --- NUT (UPS) slots --- # --- NUT (UPS) slots ---
**_nut_variables(), **_nut_variables(),
# --- Home Assistant slots ---
**_home_assistant_variables(),
# --- Scheduler slots --- # --- Scheduler slots ---
"message_scheduled_message": { "message_scheduled_message": {
"description": "Notification for scheduled message events", "description": "Notification for scheduled message events",
@@ -433,6 +435,58 @@ def _nut_variables() -> dict:
} }
def _home_assistant_variables() -> dict:
common = {
"entity_id": "HA entity id (e.g. light.kitchen)",
"friendly_name": "Human-readable entity name from attributes.friendly_name",
"domain": "HA domain prefix (light, sensor, binary_sensor, ...)",
"attributes": "Full attributes dict of the new state",
"device_class": "Device class (motion, door, temperature, ...)",
"unit_of_measurement": "Unit suffix for numeric sensors",
"area": "Area name from the HA area registry (empty when not assigned)",
"ha_event_type": "Raw HA event_type (state_changed, automation_triggered, ...)",
"last_changed": "ISO timestamp of last state change",
"last_updated": "ISO timestamp of last attribute or state update",
}
return {
"message_ha_state_changed": {
"description": "Entity state changed",
"variables": {
**common,
"old_state": "Previous state string",
"new_state": "New state string ('removed' if entity deleted)",
},
},
"message_ha_automation_triggered": {
"description": "Automation triggered",
"variables": {
"entity_id": common["entity_id"],
"automation_name": "Automation name",
"trigger_source": "Why the automation fired",
"ha_event_type": common["ha_event_type"],
},
},
"message_ha_service_called": {
"description": "HA service called",
"variables": {
"service_called": "Qualified service name (e.g. light.turn_on)",
"service_domain": "Service domain",
"service_name": "Service name within domain",
"service_data": "Service payload dict",
"target_entity": "entity_id targeted by the call (comma-joined for multi-target)",
"ha_event_type": common["ha_event_type"],
},
},
"message_ha_event_fired": {
"description": "Other HA event fired (catch-all)",
"variables": {
"ha_event_type": common["ha_event_type"],
"event_data": "Raw event data dict (use {{ event_data | tojson }} to render)",
},
},
}
@router.post("", status_code=status.HTTP_201_CREATED) @router.post("", status_code=status.HTTP_201_CREATED)
async def create_config( async def create_config(
body: TemplateConfigCreate, body: TemplateConfigCreate,
@@ -13,7 +13,6 @@ from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.notifications.dispatcher import NotificationDispatcher, TargetConfig
from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook from notify_bridge_core.providers.gitea.event_parser import parse_webhook as parse_gitea_webhook
from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook from notify_bridge_core.providers.planka.event_parser import parse_webhook as parse_planka_webhook
from notify_bridge_core.providers.webhook.event_parser import parse_webhook as parse_generic_webhook from notify_bridge_core.providers.webhook.event_parser import parse_webhook as parse_generic_webhook
@@ -27,13 +26,7 @@ from ..database.models import (
ServiceProvider, ServiceProvider,
WebhookPayloadLog, WebhookPayloadLog,
) )
from ..services.dispatch_helpers import ( from ..services.event_dispatch import dispatch_provider_event
GateReason,
apply_tracking_display_filters,
evaluate_event_gate,
get_app_timezone,
load_link_data,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -131,7 +124,7 @@ def _passes_filters(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Shared dispatch helper # Shared dispatch helper (legacy wrapper — body moved to services/event_dispatch.py)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _dispatch_webhook_event( async def _dispatch_webhook_event(
@@ -142,185 +135,16 @@ async def _dispatch_webhook_event(
event: ServiceEvent, event: ServiceEvent,
detail_keys: tuple[str, ...], detail_keys: tuple[str, ...],
) -> int: ) -> int:
"""Load trackers, filter, create EventLogs, dispatch notifications, and commit. """Webhook-flavoured dispatch — thin wrapper over ``dispatch_provider_event``."""
return await dispatch_provider_event(
Parameters engine=engine,
----------
engine:
SQLAlchemy async engine.
provider_id:
ID of the ServiceProvider that received the webhook.
provider_name:
Human-readable name of the provider (for logging).
provider_config:
The provider's ``config`` dict (passed through to target config builder).
event:
Parsed :class:`ServiceEvent` to dispatch.
detail_keys:
Keys from ``event.extra`` to include in the EventLog ``details`` dict.
Returns
-------
int
Number of successfully dispatched notifications.
"""
dispatched = 0
# ``defers_to_schedule`` is collected during the loop and flushed AFTER the
# main session commits — the only side-effect of failing to schedule is a
# delayed delivery (the startup loader / catch-up scan will reschedule),
# so this is best-effort and must not roll back the DB writes.
defers_to_schedule: set[Any] = set()
async with AsyncSession(engine) as session:
# App timezone is identical across trackers within one webhook request;
# pull it once.
app_tz = await get_app_timezone(session)
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = tracker_result.all()
from ..services.deferred_dispatch import defer_event, is_deferrable
for tracker in trackers:
filters = tracker.filters or {}
if not _passes_filters(event, filters):
_LOGGER.debug(
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
)
continue
link_data = await load_link_data(session, tracker.id)
if not link_data:
continue
# Log event
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
event_log_row = EventLog(
user_id=tracker.user_id,
tracker_id=tracker.id,
tracker_name=tracker.name,
provider_id=provider_id, provider_id=provider_id,
provider_name=provider_name, provider_name=provider_name,
event_type=event.event_type.value, provider_config=provider_config,
collection_id=event.collection_id,
collection_name=event.collection_name,
assets_count=0,
details={
"provider_type": event.provider_type.value,
**extra_details,
},
)
session.add(event_log_row)
await session.flush()
event_log_id = event_log_row.id
# Dedupe defers by parent ``link_id``: broadcast links emit one
# ``link_data`` entry per child, all sharing the same parent id —
# the deferred row is one-per-link, so we only call ``defer_event``
# once per distinct id (earliest fire_at wins on ties).
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
defers_for_event: dict[int, Any] = {}
for ld in link_data:
tc = ld["tracking_config"]
if tc is not None:
outcome = evaluate_event_gate(event, tc, app_tz)
if outcome.reason is GateReason.QUIET_HOURS:
if is_deferrable(event.event_type.value) and outcome.quiet_hours_end_at is not None:
link_id = ld.get("link_id")
if link_id is not None:
prior = defers_for_event.get(link_id)
if prior is None or outcome.quiet_hours_end_at < prior:
defers_for_event[link_id] = outcome.quiet_hours_end_at
continue
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
continue
tmpl = ld["template_config"]
target_cfg = TargetConfig(
type=ld["target_type"],
config=ld["target_config"],
template_slots=ld["template_slots"],
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=tmpl.date_only_format if tmpl and tmpl.date_only_format else "%d.%m.%Y",
provider_api_key=provider_config.get("api_token"),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("url", ""),
receivers=ld["receivers"],
)
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Persist defers + stamp event_log dispatch_status in the same
# session that holds the EventLog row, so the "deferred" badge
# only appears if the underlying queue rows actually exist.
if defers_for_event:
earliest = min(defers_for_event.values())
for link_id, fire_at in defers_for_event.items():
await defer_event(
session,
event=event, event=event,
user_id=tracker.user_id, detail_keys=detail_keys,
tracker_id=tracker.id, filter_fn=_passes_filters,
link_id=link_id,
event_log_id=event_log_id,
fire_at=fire_at,
) )
details = dict(event_log_row.details or {})
if not details.get("dispatch_status"):
details["dispatch_status"] = "deferred"
details["deferred_until"] = earliest.isoformat()
event_log_row.details = details
session.add(event_log_row)
defers_to_schedule.update(defers_for_event.values())
# Dispatch to targets. Isolate dispatcher exceptions per group so
# a failed remote call doesn't bubble out, abort the surrounding
# transaction, and roll back the just-written defers/event_log.
from ..services.http_session import get_http_session
dispatcher = NotificationDispatcher(session=await get_http_session())
for tc, target_configs in groups.values():
if not target_configs:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
continue
try:
results = await dispatcher.dispatch(shaped_event, target_configs)
except Exception as err: # noqa: BLE001
_LOGGER.exception(
"Dispatcher raised for tracker %d: %s", tracker.id, err,
)
continue
for r in results:
if r.get("success"):
dispatched += 1
else:
_LOGGER.error(
"Notification failed for tracker %d: %s",
tracker.id, r.get("error", "unknown"),
)
await session.commit()
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
# can't roll back the persisted defer rows.
if defers_to_schedule:
from ..services.scheduler import schedule_deferred_drain
for fire_at in defers_to_schedule:
try:
schedule_deferred_drain(fire_at)
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to schedule deferred drain for %s", fire_at,
)
return dispatched
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -34,12 +34,14 @@ def _auto_register() -> None:
from .planka_handler import PlankaCommandHandler from .planka_handler import PlankaCommandHandler
from .nut_handler import NutCommandHandler from .nut_handler import NutCommandHandler
from .webhook_handler import WebhookCommandHandler from .webhook_handler import WebhookCommandHandler
from .home_assistant_handler import HomeAssistantCommandHandler
register_handler(ImmichCommandHandler()) register_handler(ImmichCommandHandler())
register_handler(GiteaCommandHandler()) register_handler(GiteaCommandHandler())
register_handler(PlankaCommandHandler()) register_handler(PlankaCommandHandler())
register_handler(NutCommandHandler()) register_handler(NutCommandHandler())
register_handler(WebhookCommandHandler()) register_handler(WebhookCommandHandler())
register_handler(HomeAssistantCommandHandler())
# Auto-register on import # Auto-register on import
@@ -0,0 +1,375 @@
"""Home Assistant bot command handler.
Phase 2 of the HA integration. Each command opens a fresh WebSocket
connection to HA same approach used by ``HomeAssistantServiceProvider.
list_collections`` so the handler does not need to coordinate with the
long-lived subscription supervisor.
Commands:
* ``/status`` connection health, subscribed area / entity counts.
* ``/entities [glob]`` list matching entities with their current state.
* ``/state <entity_id>`` full state + attributes for one entity.
* ``/areas`` area registry summary with entity counts per area.
"""
from __future__ import annotations
import logging
from fnmatch import fnmatchcase
from typing import Any
import aiohttp
from notify_bridge_core.providers.home_assistant import (
HomeAssistantApiError,
HomeAssistantAuthError,
HomeAssistantWSClient,
redact_ha_message,
)
from ..database.models import (
CommandConfig,
CommandTracker,
CommandTrackerListener,
ServiceProvider,
TelegramBot,
)
from ..services.http_session import get_http_session
from .base import CommandResponse, ProviderCommandHandler
from .handler import _render_cmd_template
_LOGGER = logging.getLogger(__name__)
_HA_COMMANDS = {"status", "entities", "state", "areas"}
# HA exposes credentials and tokens through state attributes for some
# integrations — most notably ``camera.*`` entities surface a working
# ``access_token`` for the camera proxy URL, and ``entity_picture`` can
# carry signed URLs. Filtering keys by substring blocklist before rendering
# protects the chat user from seeing those values in /state output.
#
# Match is case-insensitive substring; tokens are intentionally generic so
# custom integrations that follow the obvious naming conventions are also
# covered. Anything not matched still renders.
_SENSITIVE_ATTR_TOKENS: tuple[str, ...] = (
"access_token",
"token",
"secret",
"password",
"passwd",
"api_key",
"apikey",
"private_key",
"session_id",
"authorization",
"bearer",
"cookie",
# ``entity_picture`` is a URL that often embeds a signed token in its
# query string (HA generates these for camera and media_player entities).
# The key itself doesn't match the credential token blocklist, so it
# gets its own explicit entry.
"entity_picture",
)
# Attributes already rendered as top-level fields by the state template; no
# point repeating them in the "Attributes" iteration.
_TOP_LEVEL_ATTRS: frozenset[str] = frozenset({
"friendly_name", "unit_of_measurement", "device_class",
})
# Hard cap on the number of attributes shown in /state to prevent message
# truncation when an entity has dozens (e.g. weather hourly forecasts,
# light supported features). After the cap, an "and N more" line is added
# by the template logic.
_MAX_ATTRIBUTES_RENDERED = 30
def _is_sensitive_attr(key: str) -> bool:
lowered = str(key).lower()
return any(tok in lowered for tok in _SENSITIVE_ATTR_TOKENS)
def _filter_attributes(attrs: dict[str, Any]) -> tuple[dict[str, Any], int]:
"""Drop sensitive keys, cap count, return ``(visible_attrs, hidden_count)``.
Hidden count covers both the security filter (blocklisted keys) and the
size cap (entries beyond ``_MAX_ATTRIBUTES_RENDERED``). The template can
surface "and N more hidden" so users know the view is incomplete.
"""
if not isinstance(attrs, dict):
return {}, 0
safe: dict[str, Any] = {}
redacted = 0
for key, value in attrs.items():
if not isinstance(key, str):
continue
if key in _TOP_LEVEL_ATTRS:
continue
if _is_sensitive_attr(key):
redacted += 1
continue
safe[key] = value
overflow = max(0, len(safe) - _MAX_ATTRIBUTES_RENDERED)
if overflow > 0:
# Stable order — sort by key so the truncation point is deterministic.
capped = dict(sorted(safe.items())[:_MAX_ATTRIBUTES_RENDERED])
return capped, redacted + overflow
return safe, redacted
def _make_ws_client(provider: ServiceProvider, session: aiohttp.ClientSession) -> HomeAssistantWSClient:
"""Build a one-shot WS client from the provider row."""
config = provider.config or {}
return HomeAssistantWSClient(
session=session,
base_url=config.get("url", ""),
access_token=config.get("access_token", ""),
verify_tls=bool(config.get("verify_tls", True)),
)
def _domain_of(entity_id: str) -> str:
return entity_id.split(".", 1)[0] if "." in entity_id else ""
def _normalize_state(state_row: dict[str, Any]) -> dict[str, Any]:
"""Flatten an HA state dict into the shape templates consume.
``attributes`` is filtered through ``_filter_attributes`` to drop
credential-like keys (e.g. ``camera.access_token``) and cap the rendered
count. ``hidden_attr_count`` is exposed so the template can surface
"and N more hidden" if the user wants to see everything they need to
use a different tool (or the HA UI itself).
"""
entity_id = state_row.get("entity_id") or ""
raw_attrs = state_row.get("attributes") or {}
visible_attrs, hidden_count = _filter_attributes(raw_attrs)
return {
"entity_id": entity_id,
"friendly_name": raw_attrs.get("friendly_name") or entity_id,
"domain": _domain_of(entity_id),
"state": state_row.get("state"),
"attributes": visible_attrs,
"hidden_attr_count": hidden_count,
"device_class": raw_attrs.get("device_class"),
"unit_of_measurement": raw_attrs.get("unit_of_measurement"),
"last_changed": state_row.get("last_changed"),
"last_updated": state_row.get("last_updated"),
}
# ---------------------------------------------------------------------------
# Command implementations
# ---------------------------------------------------------------------------
async def _cmd_status(provider: ServiceProvider) -> dict[str, Any]:
"""``/status`` — connection health + counts.
Health is derived from a live connection rather than the supervisor's
in-memory state so the user sees what's happening *right now* if they
just edited the token / URL. Connection + entity-count + area-count run
on a single WS session so a healthy /status costs one TCP + TLS + WS +
auth handshake instead of three.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
ok = True
message = "OK"
entity_count = 0
area_count = 0
try:
async with client.session() as sess:
# Reaching here proves connect + auth succeeded.
try:
entity_count = len(await sess.get_states())
except HomeAssistantApiError as err:
_LOGGER.debug("HA /status get_states failed: %s", err)
try:
area_count = len(await sess.get_area_registry())
except HomeAssistantApiError as err:
_LOGGER.debug("HA /status get_area_registry failed: %s", err)
except HomeAssistantAuthError as err:
ok = False
message = f"Auth failed: {redact_ha_message(str(err))}"
except (aiohttp.ClientError, HomeAssistantApiError) as err:
ok = False
message = redact_ha_message(str(err)) or "Connection failed"
return {
"ok": ok,
"message": message,
"provider_name": provider.name or "",
"url": (provider.config or {}).get("url", ""),
"entity_count": entity_count,
"area_count": area_count,
}
async def _cmd_entities(provider: ServiceProvider, args: str, count: int) -> dict[str, Any]:
"""``/entities [glob]`` — entities filtered by glob, capped at ``count``.
Empty args returns the first ``count`` entities in entity_id order. A
glob pattern is matched against the entity_id (case-insensitive). The
normalization step (which walks the attribute dict to redact secrets)
runs **only on the survivors** sorting and slicing happen on raw
state rows first, so an HA install with 1000+ entities doesn't
materialize 1000 normalized dicts just to discard most of them.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
states = await sess.get_states()
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /entities failed: %s", redacted)
return {"entities": [], "glob": args.strip(), "total": 0, "error": redacted}
glob = args.strip()
if glob:
lower_glob = glob.lower()
raw_matches = [
s for s in states
if isinstance(s.get("entity_id"), str)
and fnmatchcase(s["entity_id"].lower(), lower_glob)
]
else:
raw_matches = [s for s in states if isinstance(s.get("entity_id"), str)]
total = len(raw_matches)
raw_matches.sort(key=lambda s: s.get("entity_id", ""))
return {
"entities": [_normalize_state(s) for s in raw_matches[:count]],
"glob": glob,
"total": total,
"shown": min(count, total),
}
async def _cmd_state(provider: ServiceProvider, args: str) -> dict[str, Any]:
"""``/state <entity_id>`` — single-entity drill-down.
Returns ``found=False`` when the entity_id is missing or not present.
Templates render the no-results fallback in that case. Uses the session
context manager for consistency with the other commands even though
there's only one underlying WS call today — leaves the door open for
Phase 3 (service calls) to chain a follow-up on the same socket.
"""
target = args.strip()
if not target:
return {"found": False, "entity_id": "", "reason": "missing_arg"}
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
states = await sess.get_states()
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /state failed: %s", redacted)
return {"found": False, "entity_id": target, "reason": "api_error", "error": redacted}
for s in states:
if s.get("entity_id") == target:
normalized = _normalize_state(s)
return {"found": True, **normalized}
return {"found": False, "entity_id": target, "reason": "not_found"}
async def _cmd_areas(provider: ServiceProvider) -> dict[str, Any]:
"""``/areas`` — area registry with per-area entity counts.
Areas without entities are still listed so users can see which areas
exist in HA but haven't been assigned anything. The entity counts come
from the entity registry, not the state list the registry includes
disabled entities, which matches what users see in the HA UI. Both
registry calls share a single WS session so /areas costs one handshake.
"""
session = await get_http_session()
client = _make_ws_client(provider, session)
try:
async with client.session() as sess:
areas = await sess.get_area_registry()
# Entity registry failure is non-fatal — areas can still be
# listed without per-area counts.
try:
entities = await sess.get_entity_registry()
except HomeAssistantApiError:
entities = []
except (HomeAssistantApiError, aiohttp.ClientError, HomeAssistantAuthError) as err:
redacted = redact_ha_message(str(err))
_LOGGER.warning("HA /areas failed: %s", redacted)
return {"areas": [], "total": 0, "error": redacted}
counts: dict[str, int] = {}
for ent in entities:
area_id = ent.get("area_id")
if isinstance(area_id, str):
counts[area_id] = counts.get(area_id, 0) + 1
rows: list[dict[str, Any]] = []
for a in areas:
area_id = a.get("area_id")
if not isinstance(area_id, str):
continue
rows.append({
"area_id": area_id,
"name": a.get("name") or area_id,
"entity_count": counts.get(area_id, 0),
})
rows.sort(key=lambda r: r.get("name", "").lower())
return {"areas": rows, "total": len(rows)}
# ---------------------------------------------------------------------------
# Handler class
# ---------------------------------------------------------------------------
class HomeAssistantCommandHandler(ProviderCommandHandler):
"""Routes ``/status``, ``/entities``, ``/state``, ``/areas`` to the WS client."""
provider_type = "home_assistant"
def get_provider_commands(self) -> set[str]:
return _HA_COMMANDS
def get_rate_categories(self) -> dict[str, str]:
# All HA commands hit the WS API and share an "api" rate-limit bucket.
return {cmd: "api" for cmd in _HA_COMMANDS}
async def handle(
self,
cmd: str,
args: str,
count: int,
locale: str,
response_mode: str, # noqa: ARG002 — HA has no media commands; always text
provider: ServiceProvider,
cmd_templates: dict[str, dict[str, str]],
bot: TelegramBot, # noqa: ARG002
tracker: CommandTracker, # noqa: ARG002
config: CommandConfig, # noqa: ARG002
*,
listener: CommandTrackerListener | None = None, # noqa: ARG002
allowed_album_ids: set[str] | None = None, # noqa: ARG002 — HA has no album scope
page: int = 1, # noqa: ARG002 — no pagination in v1
) -> CommandResponse | None:
if cmd == "status":
ctx = await _cmd_status(provider)
elif cmd == "entities":
ctx = await _cmd_entities(provider, args, count)
elif cmd == "state":
ctx = await _cmd_state(provider, args)
elif cmd == "areas":
ctx = await _cmd_areas(provider)
else:
return None
return CommandResponse(text=_render_cmd_template(cmd_templates, cmd, locale, ctx))
@@ -11,6 +11,8 @@ _RATE_CATEGORY: dict[str, str] = {
"repos": "api", "issues": "api", "prs": "api", "commits": "api", "repos": "api", "issues": "api", "prs": "api", "commits": "api",
# Planka (API calls share a category) # Planka (API calls share a category)
"boards": "api", "cards": "api", "lists": "api", "boards": "api", "cards": "api", "lists": "api",
# Home Assistant (WebSocket queries share a category)
"entities": "api", "state": "api", "areas": "api",
} }
@@ -292,6 +292,23 @@ async def migrate_schema(engine: AsyncEngine) -> None:
) )
logger.info("Added track_webhook_received column to tracking_config table") logger.info("Added track_webhook_received column to tracking_config table")
# Add Home Assistant tracking flags to tracking_config if missing.
# state_changed defaults ON to match the canonical "watch the state bus"
# use case; the other three are loud and opt-in (defaults 0).
if await _has_table(conn, "tracking_config"):
ha_flags = [
("track_ha_state_changed", "INTEGER DEFAULT 1"),
("track_ha_automation_triggered", "INTEGER DEFAULT 0"),
("track_ha_service_called", "INTEGER DEFAULT 0"),
("track_ha_event_fired", "INTEGER DEFAULT 0"),
]
for col_name, col_type in ha_flags:
if not await _has_column(conn, "tracking_config", col_name):
await conn.execute(
text(f"ALTER TABLE tracking_config ADD COLUMN {col_name} {col_type}")
)
logger.info("Added %s column to tracking_config table", col_name)
# Add quiet hours to tracking_config if missing. # Add quiet hours to tracking_config if missing.
# Start/end are nullable HH:MM strings; quiet_hours_enabled gates them. # Start/end are nullable HH:MM strings; quiet_hours_enabled gates them.
if await _has_table(conn, "tracking_config"): if await _has_table(conn, "tracking_config"):
@@ -165,6 +165,12 @@ class TrackingConfig(SQLModel, table=True):
# Generic Webhook event tracking # Generic Webhook event tracking
track_webhook_received: bool = Field(default=True) track_webhook_received: bool = Field(default=True)
# Home Assistant event tracking
track_ha_state_changed: bool = Field(default=True)
track_ha_automation_triggered: bool = Field(default=False)
track_ha_service_called: bool = Field(default=False)
track_ha_event_fired: bool = Field(default=False)
# Immich asset display # Immich asset display
track_images: bool = Field(default=True) track_images: bool = Field(default=True)
track_videos: bool = Field(default=True) track_videos: bool = Field(default=True)
@@ -158,6 +158,7 @@ async def _seed_default_templates() -> None:
await _seed_provider_template(session, "nut", "NUT") await _seed_provider_template(session, "nut", "NUT")
await _seed_provider_template(session, "google_photos", "Google Photos") await _seed_provider_template(session, "google_photos", "Google Photos")
await _seed_provider_template(session, "webhook", "Generic Webhook") await _seed_provider_template(session, "webhook", "Generic Webhook")
await _seed_provider_template(session, "home_assistant", "Home Assistant")
await session.commit() await session.commit()
@@ -187,6 +188,10 @@ async def _seed_default_command_templates() -> None:
await _seed_provider_command_template( await _seed_provider_command_template(
session, "webhook", "Default Webhook Commands", "Default Generic Webhook command templates", session, "webhook", "Default Webhook Commands", "Default Generic Webhook command templates",
) )
await _seed_provider_command_template(
session, "home_assistant", "Default Home Assistant Commands",
"Default Home Assistant command templates",
)
await session.commit() await session.commit()
@@ -272,6 +277,14 @@ async def _seed_default_tracking_configs() -> None:
"track_ups_replace_battery": True, "track_ups_replace_battery": True,
"track_ups_overload": True, "track_ups_overload": True,
}, },
{
"provider_type": "home_assistant",
"name": "Default Home Assistant",
"track_ha_state_changed": True,
"track_ha_automation_triggered": False,
"track_ha_service_called": False,
"track_ha_event_fired": False,
},
] ]
for cfg in defaults: for cfg in defaults:
@@ -139,12 +139,20 @@ async def lifespan(app: FastAPI):
set_webhook_secret(_secret or None) set_webhook_secret(_secret or None)
from .services.scheduler import start_scheduler, get_scheduler from .services.scheduler import start_scheduler, get_scheduler
await start_scheduler() await start_scheduler()
# Phase 1 of the Home Assistant provider: subscription-based ingest runs
# outside the polling scheduler. ``start_all`` spawns one supervisor task
# per enabled HA provider row. No-op when no HA providers are configured.
from .services.ha_subscription import start_all as start_ha_subscriptions
await start_ha_subscriptions()
_READY = True _READY = True
yield yield
# Graceful shutdown — stop the scheduler FIRST so in-flight jobs finish # Graceful shutdown — cancel HA supervisors FIRST so they release their
# before we close their HTTP session. Then close the shared session and # WS connections before the shared HTTP session is closed. Then stop the
# dispose the DB engine. # polling scheduler. Order matters: scheduler.shutdown(wait=True) drains
# in-flight jobs that may also use the shared session.
_READY = False _READY = False
from .services.ha_subscription import stop_all as stop_ha_subscriptions
await stop_ha_subscriptions()
scheduler = get_scheduler() scheduler = get_scheduler()
if scheduler.running: if scheduler.running:
scheduler.shutdown(wait=True) scheduler.shutdown(wait=True)
@@ -115,6 +115,16 @@ def _make_collection_provider(
return make_planka_provider(http_session, provider) return make_planka_provider(http_session, provider)
if ptype == "google_photos": if ptype == "google_photos":
return make_google_photos_provider(http_session, provider) return make_google_photos_provider(http_session, provider)
if ptype == "home_assistant":
from notify_bridge_core.providers.home_assistant import HomeAssistantServiceProvider
return HomeAssistantServiceProvider(
session=http_session,
url=config.get("url", ""),
access_token=config.get("access_token", ""),
verify_tls=bool(config.get("verify_tls", True)),
event_types=config.get("event_types") or None,
name=provider.name,
)
# NUT provider needs no http_session # NUT provider needs no http_session
if ptype == "nut": if ptype == "nut":
return make_nut_provider(provider) # type: ignore[return-value] return make_nut_provider(provider) # type: ignore[return-value]
@@ -122,7 +132,7 @@ def _make_collection_provider(
# Set of provider types that need an aiohttp session for collection listing. # Set of provider types that need an aiohttp session for collection listing.
_HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos"} _HTTP_COLLECTION_PROVIDERS = {"immich", "gitea", "planka", "google_photos", "home_assistant"}
async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]: async def list_provider_collections(provider: ServiceProvider) -> list[dict[str, Any]]:
@@ -204,6 +204,12 @@ def _event_type_enabled(event: ServiceEvent, tc: TrackingConfig) -> bool:
"ups_comms_restored": tc.track_ups_comms_restored, "ups_comms_restored": tc.track_ups_comms_restored,
"ups_replace_battery": tc.track_ups_replace_battery, "ups_replace_battery": tc.track_ups_replace_battery,
"ups_overload": tc.track_ups_overload, "ups_overload": tc.track_ups_overload,
# Home Assistant events — use getattr so legacy DB rows / test mocks
# that pre-date the columns still pass the gate (default to tracked).
"ha_state_changed": getattr(tc, "track_ha_state_changed", True),
"ha_automation_triggered": getattr(tc, "track_ha_automation_triggered", False),
"ha_service_called": getattr(tc, "track_ha_service_called", False),
"ha_event_fired": getattr(tc, "track_ha_event_fired", False),
} }
return flag_map.get(event_type, True) return flag_map.get(event_type, True)
@@ -0,0 +1,239 @@
"""Shared dispatch helper for push-style providers.
Push-style providers (webhook receivers in ``api/webhooks.py`` and the
Home Assistant subscription manager in ``services/ha_subscription.py``)
share the same downstream pipeline: write an :class:`EventLog`, evaluate
quiet hours / event-type gates, defer if needed, otherwise hand off to the
:class:`NotificationDispatcher`.
This module extracts that pipeline so both callers can reuse it without
either side importing from the other (which would create a server/api ->
services -> api cycle).
"""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.notifications.dispatcher import (
NotificationDispatcher,
TargetConfig,
)
from ..database.models import EventLog, NotificationTracker
from .deferred_dispatch import defer_event, is_deferrable
from .dispatch_helpers import (
GateReason,
apply_tracking_display_filters,
evaluate_event_gate,
get_app_timezone,
load_link_data,
)
_LOGGER = logging.getLogger(__name__)
# Filter signature: ``(event, tracker.filters dict) -> bool``. Returning False
# drops the event for that tracker before any DB writes happen. Callers pass
# provider-specific logic (Gitea sender allowlist, HA entity glob, etc.).
FilterFn = Callable[[ServiceEvent, dict[str, Any]], bool]
async def dispatch_provider_event(
engine: Any,
provider_id: int,
provider_name: str,
provider_config: dict[str, Any],
event: ServiceEvent,
detail_keys: tuple[str, ...],
filter_fn: FilterFn,
) -> int:
"""Load matching trackers, log, gate, defer, and dispatch one event.
Parameters
----------
engine:
SQLAlchemy async engine.
provider_id:
ID of the :class:`ServiceProvider` the event came from.
provider_name:
Human-readable name (for logging only).
provider_config:
``ServiceProvider.config`` dict; flowed into :class:`TargetConfig`.
event:
Parsed :class:`ServiceEvent` to dispatch.
detail_keys:
Keys from ``event.extra`` to copy into ``EventLog.details``.
filter_fn:
Per-event tracker-level filter. Returning False drops the event for
that tracker before any DB writes.
Returns
-------
int
Number of successfully dispatched notifications across all trackers.
"""
dispatched = 0
# Drain-scheduling is best-effort: a scheduling failure must not roll
# back the persisted defer rows (startup catch-up re-establishes them).
defers_to_schedule: set[Any] = set()
async with AsyncSession(engine) as session:
# App timezone is identical across trackers in one inbound event;
# pull it once.
app_tz = await get_app_timezone(session)
tracker_result = await session.exec(
select(NotificationTracker).where(
NotificationTracker.provider_id == provider_id,
NotificationTracker.enabled == True, # noqa: E712
)
)
trackers = tracker_result.all()
for tracker in trackers:
filters = tracker.filters or {}
if not filter_fn(event, filters):
_LOGGER.debug(
"Event filtered out for tracker %d (%s)", tracker.id, tracker.name
)
continue
link_data = await load_link_data(session, tracker.id)
if not link_data:
continue
extra_details = {k: v for k, v in event.extra.items() if k in detail_keys}
event_log_row = EventLog(
user_id=tracker.user_id,
tracker_id=tracker.id,
tracker_name=tracker.name,
provider_id=provider_id,
provider_name=provider_name,
event_type=event.event_type.value,
collection_id=event.collection_id,
collection_name=event.collection_name,
assets_count=0,
details={
"provider_type": event.provider_type.value,
**extra_details,
},
)
session.add(event_log_row)
await session.flush()
event_log_id = event_log_row.id
# Dedupe defers by parent link_id: broadcast links emit one
# link_data entry per child, sharing the same parent id — the
# deferred row is one-per-link, so we call defer_event only
# once per distinct id (earliest fire_at wins on ties).
groups: dict[int, tuple[Any, list[TargetConfig]]] = {}
defers_for_event: dict[int, Any] = {}
for ld in link_data:
tc = ld["tracking_config"]
if tc is not None:
outcome = evaluate_event_gate(event, tc, app_tz)
if outcome.reason is GateReason.QUIET_HOURS:
if (
is_deferrable(event.event_type.value)
and outcome.quiet_hours_end_at is not None
):
link_id = ld.get("link_id")
if link_id is not None:
prior = defers_for_event.get(link_id)
if prior is None or outcome.quiet_hours_end_at < prior:
defers_for_event[link_id] = outcome.quiet_hours_end_at
continue
if outcome.reason is GateReason.EVENT_TYPE_DISABLED:
continue
tmpl = ld["template_config"]
target_cfg = TargetConfig(
type=ld["target_type"],
config=ld["target_config"],
template_slots=ld["template_slots"],
date_format=tmpl.date_format if tmpl else "%d.%m.%Y, %H:%M UTC",
date_only_format=(
tmpl.date_only_format
if tmpl and tmpl.date_only_format
else "%d.%m.%Y"
),
provider_api_key=provider_config.get("api_token"),
provider_internal_url=provider_config.get("url", ""),
provider_external_url=provider_config.get("url", ""),
receivers=ld["receivers"],
)
key = id(tc) if tc is not None else 0
if key not in groups:
groups[key] = (tc, [])
groups[key][1].append(target_cfg)
# Persist defers + stamp event_log dispatch_status in the same
# session that holds the EventLog row, so the "deferred" badge
# only appears if the underlying queue rows actually exist.
if defers_for_event:
earliest = min(defers_for_event.values())
for link_id, fire_at in defers_for_event.items():
await defer_event(
session,
event=event,
user_id=tracker.user_id,
tracker_id=tracker.id,
link_id=link_id,
event_log_id=event_log_id,
fire_at=fire_at,
)
details = dict(event_log_row.details or {})
if not details.get("dispatch_status"):
details["dispatch_status"] = "deferred"
details["deferred_until"] = earliest.isoformat()
event_log_row.details = details
session.add(event_log_row)
defers_to_schedule.update(defers_for_event.values())
# Dispatch to targets. Isolate dispatcher exceptions per group so
# a failed remote call doesn't bubble out, abort the surrounding
# transaction, and roll back the just-written defers / event_log.
from .http_session import get_http_session
dispatcher = NotificationDispatcher(session=await get_http_session())
for tc, target_configs in groups.values():
if not target_configs:
continue
shaped_event = apply_tracking_display_filters(event, tc)
if shaped_event is None:
continue
try:
results = await dispatcher.dispatch(shaped_event, target_configs)
except Exception as err: # noqa: BLE001
_LOGGER.exception(
"Dispatcher raised for tracker %d: %s", tracker.id, err,
)
continue
for r in results:
if r.get("success"):
dispatched += 1
else:
_LOGGER.error(
"Notification failed for tracker %d: %s",
tracker.id, r.get("error", "unknown"),
)
await session.commit()
# Schedule drain jobs OUTSIDE the DB session so an APScheduler hiccup
# can't roll back the persisted defer rows.
if defers_to_schedule:
from .scheduler import schedule_deferred_drain
for fire_at in defers_to_schedule:
try:
schedule_deferred_drain(fire_at)
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to schedule deferred drain for %s", fire_at,
)
return dispatched
@@ -0,0 +1,293 @@
"""Home Assistant subscription manager.
Phase 1 of the HA provider lives here. For every enabled ``home_assistant``
:class:`ServiceProvider` row in the DB, this module spawns one long-running
asyncio task that:
1. Builds an :class:`HomeAssistantServiceProvider` from the provider row.
2. Calls ``provider.subscribe(emit)`` which loops forever connect,
authenticate, subscribe, drain events through ``emit`` and reconnects
with exponential backoff on any drop.
3. Each ``emit`` call hands the parsed :class:`ServiceEvent` to
:func:`dispatch_provider_event` (the shared dispatch helper that webhook
providers also use), so quiet hours, deferred dispatch, and event-log
writes all behave identically to the rest of the system.
Lifecycle is owned by ``main.py`` via :func:`start_all` and :func:`stop_all`.
Phase 1 does not reconcile against DB changes after boot adding,
modifying, or removing a HA provider requires a server restart. Phase 1.5
will add a CRUD-triggered :func:`reload_provider` hook.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from notify_bridge_core.models.events import ServiceEvent
from notify_bridge_core.providers.home_assistant import (
HomeAssistantAuthError,
HomeAssistantServiceProvider,
)
from ..database.engine import get_engine
from ..database.models import ServiceProvider
from .event_dispatch import dispatch_provider_event
from .http_session import get_http_session
_LOGGER = logging.getLogger(__name__)
# Per-provider running task. Keyed by provider_id so reload_provider() can
# find and replace a single task without disturbing the rest.
_running_tasks: dict[int, asyncio.Task[None]] = {}
# Keys from ``event.extra`` to copy into ``EventLog.details``. Anything not
# in this list is still available to templates via the merged extras, but
# the event-log row stays slim.
_HA_DETAIL_KEYS: tuple[str, ...] = (
"entity_id",
"friendly_name",
"domain",
"old_state",
"new_state",
"device_class",
"unit_of_measurement",
"area",
"ha_event_type",
"automation_name",
"service_called",
"target_entity",
)
def _ha_passes_filters(event: ServiceEvent, filters: dict[str, Any]) -> bool:
"""HA-specific tracker filter.
Three filter keys, all optional, evaluated as a union: if the entity
matches any one of them, the event passes. Empty filters mean "accept
everything" — different from the Gitea filter which is an intersection.
Filter shape:
* ``collections`` list of exact ``entity_id`` matches.
* ``entity_glob`` list of glob patterns (``light.*``, ``*_motion``).
* ``domain_allowlist`` list of HA domain prefixes (``light``).
"""
collections = filters.get("collections") or []
entity_globs = filters.get("entity_glob") or []
domain_allowlist = filters.get("domain_allowlist") or []
# No filters configured = accept everything.
if not collections and not entity_globs and not domain_allowlist:
return True
entity_id = event.collection_id
domain = event.extra.get("domain") or (
entity_id.split(".", 1)[0] if "." in entity_id else ""
)
if collections and entity_id in collections:
return True
if domain_allowlist and domain in domain_allowlist:
return True
if entity_globs:
from fnmatch import fnmatchcase
for pattern in entity_globs:
if isinstance(pattern, str) and fnmatchcase(entity_id, pattern):
return True
return False
async def _run_provider(provider_id: int) -> None:
"""One per-provider supervisor loop.
Reloads the provider row each iteration so config changes (URL, token,
event types) take effect on the next reconnect cycle no need for a
full restart in the simple case where only credentials changed.
The ``_emit`` closure is rebuilt every iteration. Its lifetime equals
one ``subscribe()`` call: the callback only runs while the HA client's
drain task is alive. ``provider_name`` is snapshotted at the start of
each (re)connect cycle, so renames take effect on the next reconnect
chatty enough for v1; revisit if longer-lived WS sessions need fresher
names mid-stream.
"""
assert provider_id is not None, "_run_provider requires a real provider id"
engine = get_engine()
while True:
try:
async with AsyncSession(engine) as session:
row = await session.get(ServiceProvider, provider_id)
if row is None or row.type != "home_assistant":
_LOGGER.info(
"HA provider %s removed or retyped, stopping supervisor",
provider_id,
)
return
config = dict(row.config or {})
provider_name = row.name
url = config.get("url", "")
access_token = config.get("access_token", "")
verify_tls = bool(config.get("verify_tls", True))
event_types = config.get("event_types") or None
if not url or not access_token:
_LOGGER.warning(
"HA provider %s missing url or access_token; retrying in 60s",
provider_id,
)
await asyncio.sleep(60)
continue
session_http = await get_http_session()
ha_provider = HomeAssistantServiceProvider(
session=session_http,
url=url,
access_token=access_token,
verify_tls=verify_tls,
event_types=event_types,
name=provider_name,
)
async def _emit(event: ServiceEvent) -> None:
# Shield the DB-writing dispatch from external cancellation
# (shutdown, supervisor restart). The shield ensures that
# once a transaction is mid-flight, it commits or rolls back
# cleanly instead of being torn down with the asyncio task
# at a write boundary. Worst case: shutdown waits up to one
# dispatch latency longer.
#
# Perf note (Phase 2 follow-up): dispatch_provider_event
# opens a fresh AsyncSession per call. For HA's chatty
# state_changed bus this hammers the pool; batch in a
# follow-up.
try:
await asyncio.shield(dispatch_provider_event(
engine=engine,
provider_id=provider_id,
provider_name=provider_name,
provider_config=config,
event=event,
detail_keys=_HA_DETAIL_KEYS,
filter_fn=_ha_passes_filters,
))
except asyncio.CancelledError:
# Shield re-raises CancelledError to the caller; let it
# propagate so the drain task exits cleanly.
raise
except Exception: # noqa: BLE001
_LOGGER.exception(
"Failed to dispatch HA event for provider %s",
provider_id,
)
_LOGGER.info(
"Starting HA subscription for provider %s (%s)",
provider_id, provider_name,
)
await ha_provider.subscribe(_emit)
except asyncio.CancelledError:
raise
except HomeAssistantAuthError as err:
# Fatal at the provider level — bad token. Sleep long and retry
# so the user has time to fix the token without us hammering HA.
# Error string already redacted by the client before re-raise.
_LOGGER.error(
"HA provider %s auth failed: %s — retrying in 5 minutes",
provider_id, err,
)
await asyncio.sleep(300)
except Exception: # noqa: BLE001
_LOGGER.exception(
"HA supervisor for provider %s crashed; restarting in 30s",
provider_id,
)
await asyncio.sleep(30)
def _make_done_callback(provider_id: int):
"""Return a done-callback that prunes the task from ``_running_tasks``.
Without this, finished supervisors (provider deleted, fatal auth error
after long sleep) leave stale entries in the dict across many
reload cycles the dict would grow unboundedly. The callback is
registered on every task spawned via ``start_all`` / ``reload_provider``.
"""
def _cb(task: asyncio.Task[None]) -> None:
current = _running_tasks.get(provider_id)
if current is task:
_running_tasks.pop(provider_id, None)
return _cb
async def start_all() -> None:
"""Start a supervisor task for every enabled HA provider."""
engine = get_engine()
async with AsyncSession(engine) as session:
result = await session.exec(
select(ServiceProvider).where(
ServiceProvider.type == "home_assistant",
)
)
providers = result.all()
for prov in providers:
if prov.id in _running_tasks and not _running_tasks[prov.id].done():
continue
task = asyncio.create_task(
_run_provider(prov.id),
name=f"ha-subscription-{prov.id}",
)
task.add_done_callback(_make_done_callback(prov.id))
_running_tasks[prov.id] = task
if providers:
_LOGGER.info(
"Started HA subscription manager: %d provider(s)", len(providers),
)
async def stop_all() -> None:
"""Cancel every HA supervisor task and wait for clean shutdown."""
if not _running_tasks:
return
for task in _running_tasks.values():
task.cancel()
# Wait for all to drain; swallow cancellation errors.
await asyncio.gather(*_running_tasks.values(), return_exceptions=True)
_running_tasks.clear()
_LOGGER.info("Stopped all HA subscription supervisors")
async def reload_provider(provider_id: int) -> None:
"""Stop and restart the supervisor for a single provider id.
Hook for the provider CRUD routes Phase 1.5 will wire it in. For Phase
1, configure-then-restart-backend is the supported flow.
"""
task = _running_tasks.pop(provider_id, None)
if task is not None:
task.cancel()
try:
await task
except (asyncio.CancelledError, Exception): # noqa: BLE001
pass
engine = get_engine()
async with AsyncSession(engine) as session:
prov = await session.get(ServiceProvider, provider_id)
if prov is None or prov.type != "home_assistant":
return
new_task = asyncio.create_task(
_run_provider(provider_id),
name=f"ha-subscription-{provider_id}",
)
new_task.add_done_callback(_make_done_callback(provider_id))
_running_tasks[provider_id] = new_task
@@ -213,4 +213,25 @@ _SAMPLE_CONTEXT = {
"raw_payload": {"action": "opened", "issue": {"title": "Bug report", "number": 1}, "sender": {"login": "user1"}}, "raw_payload": {"action": "opened", "issue": {"title": "Bug report", "number": 1}, "sender": {"login": "user1"}},
"event_type_raw": "webhook_received", "event_type_raw": "webhook_received",
"source_ip": "192.168.1.100", "source_ip": "192.168.1.100",
# Home Assistant variables (for home_assistant provider templates)
"friendly_name": "Front Door",
"entity_id": "binary_sensor.front_door",
"domain": "binary_sensor",
"old_state": "off",
"new_state": "on",
"attributes": {"friendly_name": "Front Door", "device_class": "door"},
"device_class": "door",
"unit_of_measurement": "",
"area": "Entrance",
"last_changed": "2026-05-13T12:34:56.789+00:00",
"last_updated": "2026-05-13T12:34:56.789+00:00",
"automation_name": "Front door notification",
"trigger_source": "state of binary_sensor.front_door",
"service_called": "light.turn_on",
"service_domain": "light",
"service_name": "turn_on",
"service_data": {"entity_id": "light.kitchen"},
"target_entity": "light.kitchen",
"ha_event_type": "state_changed",
"event_data": {"foo": "bar"},
} }
@@ -0,0 +1,98 @@
"""Unit tests for HA bot command helpers — Phase 2.
Focus on the security-sensitive bits the reviewer flagged: attribute
filtering, error-message redaction, and the sample-context shape that
flows through Jinja preview rendering.
"""
from __future__ import annotations
from notify_bridge_server.commands.home_assistant_handler import (
_filter_attributes,
_is_sensitive_attr,
_normalize_state,
)
def test_filter_attributes_drops_credential_keys() -> None:
"""HA camera entities expose an ``access_token`` attribute. The handler
MUST NOT surface it to the chat user via /state."""
raw = {
"friendly_name": "Front Camera",
"access_token": "real-camera-proxy-token",
"entity_picture": "/api/camera_proxy/...?token=abc",
"brightness": 200,
}
safe, hidden = _filter_attributes(raw)
assert "access_token" not in safe
# entity_picture contains 'token' substring → blocked.
assert "entity_picture" not in safe
# friendly_name is rendered as a top-level field, not iterated.
assert "friendly_name" not in safe
# brightness is a normal attribute, passes through.
assert safe["brightness"] == 200
assert hidden == 2
def test_filter_attributes_caps_count() -> None:
"""When an entity has dozens of attributes the renderer would overflow
Telegram's 4096-char message limit. Cap at 30 with overflow surfaced."""
raw = {f"attr_{i:03d}": i for i in range(50)}
safe, hidden = _filter_attributes(raw)
assert len(safe) == 30
assert hidden == 20
def test_is_sensitive_attr_case_insensitive() -> None:
"""Match should not depend on key casing — custom integrations are
inconsistent about capitalization."""
assert _is_sensitive_attr("Access_Token") is True
assert _is_sensitive_attr("API_KEY") is True
assert _is_sensitive_attr("password") is True
assert _is_sensitive_attr("brightness") is False
assert _is_sensitive_attr("color_mode") is False
def test_normalize_state_filters_attrs() -> None:
"""End-to-end: feed _normalize_state a malicious state row, verify the
output has redacted attributes + hidden_attr_count surfaced."""
state_row = {
"entity_id": "camera.front_door",
"state": "idle",
"attributes": {
"friendly_name": "Front Door Camera",
"access_token": "leaked",
"brand": "Reolink",
},
"last_changed": "2026-05-13T12:00:00+00:00",
"last_updated": "2026-05-13T12:00:00+00:00",
}
out = _normalize_state(state_row)
assert out["entity_id"] == "camera.front_door"
assert out["friendly_name"] == "Front Door Camera"
assert out["domain"] == "camera"
# Top-level fields preserved.
assert out["state"] == "idle"
# Attributes dict is filtered.
assert "access_token" not in out["attributes"]
assert out["attributes"].get("brand") == "Reolink"
# Hidden count reflects access_token (friendly_name is top-level, not redacted).
assert out["hidden_attr_count"] == 1
def test_normalize_state_handles_missing_attributes() -> None:
"""A state row with no attributes dict should not crash."""
out = _normalize_state({"entity_id": "sensor.x", "state": "1"})
assert out["attributes"] == {}
assert out["hidden_attr_count"] == 0
def test_redact_ha_message_strips_userinfo() -> None:
"""The Phase 1 redact helper is re-exported via the HA package and used
by /entities, /state, /areas before surfacing errors. Make sure the
re-export still works and the contract is what we expect."""
from notify_bridge_core.providers.home_assistant import redact_ha_message
msg = "Cannot connect to https://leak-token@homeassistant.local:8123/api/websocket"
out = redact_ha_message(msg)
assert "leak-token@" not in out
assert "homeassistant.local:8123" in out
@@ -0,0 +1,80 @@
"""Tests for the HA-specific tracker filter (entity_glob, domain_allowlist).
The Gitea filter is an intersection of senders/collections. The HA filter
is intentionally a *union* across the three keys any match passes so a
user can mix exact entity ids with glob patterns and domain allowlists
without each one narrowing the others.
"""
from __future__ import annotations
from datetime import datetime, timezone
from notify_bridge_core.models.events import EventType, ServiceEvent
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_server.services.ha_subscription import _ha_passes_filters
def _ha_event(entity_id: str, domain: str | None = None) -> ServiceEvent:
return ServiceEvent(
event_type=EventType.HA_STATE_CHANGED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name="HA",
collection_id=entity_id,
collection_name=entity_id,
timestamp=datetime.now(timezone.utc),
extra={"domain": domain or (entity_id.split(".", 1)[0] if "." in entity_id else "")},
)
def test_empty_filters_accept_everything() -> None:
assert _ha_passes_filters(_ha_event("light.kitchen"), {}) is True
def test_exact_entity_match() -> None:
filters = {"collections": ["light.kitchen", "switch.lamp"]}
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
def test_entity_glob_match() -> None:
filters = {"entity_glob": ["binary_sensor.*_motion", "light.kitchen*"]}
assert _ha_passes_filters(_ha_event("binary_sensor.hallway_motion"), filters) is True
assert _ha_passes_filters(_ha_event("light.kitchen_main"), filters) is True
assert _ha_passes_filters(_ha_event("light.bedroom"), filters) is False
def test_domain_allowlist() -> None:
filters = {"domain_allowlist": ["light", "switch"]}
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
assert _ha_passes_filters(_ha_event("switch.lamp"), filters) is True
assert _ha_passes_filters(_ha_event("sensor.temp"), filters) is False
def test_union_across_keys() -> None:
"""If collections names a specific sensor.* but domain_allowlist names
'light', BOTH should be acceptable that's the difference from the
Gitea-style intersection filter."""
filters = {
"collections": ["sensor.outdoor_temp"],
"domain_allowlist": ["light"],
}
assert _ha_passes_filters(_ha_event("sensor.outdoor_temp"), filters) is True
assert _ha_passes_filters(_ha_event("light.kitchen"), filters) is True
# Neither matches:
assert _ha_passes_filters(_ha_event("binary_sensor.door"), filters) is False
def test_domain_derived_when_extra_missing() -> None:
"""If the parser didn't populate extra.domain (e.g. malformed event),
the filter must still infer it from the entity_id prefix."""
evt = ServiceEvent(
event_type=EventType.HA_STATE_CHANGED,
provider_type=ServiceProviderType.HOME_ASSISTANT,
provider_name="HA",
collection_id="light.kitchen",
collection_name="light.kitchen",
timestamp=datetime.now(timezone.utc),
extra={}, # No 'domain' key.
)
assert _ha_passes_filters(evt, {"domain_allowlist": ["light"]}) is True
@@ -0,0 +1,187 @@
"""Unit tests for the Home Assistant event parser.
These tests don't need a database or HA server — the parser is a pure
function from ``ha_event_dict`` to :class:`ServiceEvent`.
"""
from __future__ import annotations
from notify_bridge_core.models.events import EventType
from notify_bridge_core.providers.base import ServiceProviderType
from notify_bridge_core.providers.home_assistant.event_parser import parse_event
def _ha_event_envelope(event_type: str, data: dict) -> dict:
return {
"event_type": event_type,
"data": data,
"time_fired": "2026-05-13T12:34:56.789Z",
}
def test_state_changed_basic() -> None:
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "binary_sensor.front_door",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {
"friendly_name": "Front Door",
"device_class": "door",
},
},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_STATE_CHANGED
assert evt.provider_type is ServiceProviderType.HOME_ASSISTANT
assert evt.collection_id == "binary_sensor.front_door"
assert evt.collection_name == "Front Door"
assert evt.extra["old_state"] == "off"
assert evt.extra["new_state"] == "on"
assert evt.extra["domain"] == "binary_sensor"
assert evt.extra["device_class"] == "door"
# Area was not provided in lookup -> None.
assert evt.extra["area"] is None
def test_state_changed_with_area_lookup() -> None:
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "light.kitchen",
"old_state": {"state": "off", "attributes": {}},
"new_state": {
"state": "on",
"attributes": {"friendly_name": "Kitchen Light"},
},
},
)
evt = parse_event(
payload,
provider_name="HA",
area_lookup={"light.kitchen": "Kitchen"},
)
assert evt is not None
assert evt.extra["area"] == "Kitchen"
def test_state_changed_entity_removed() -> None:
"""new_state=None means HA removed the entity. Surface as 'removed' so
templates can branch on it; collection_name falls back to old_state."""
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "sensor.dropped",
"old_state": {
"state": "online",
"attributes": {"friendly_name": "Dropped Sensor"},
},
"new_state": None,
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.extra["new_state"] == "removed"
assert evt.collection_name == "Dropped Sensor"
def test_automation_triggered() -> None:
payload = _ha_event_envelope(
"automation_triggered",
{
"name": "Front door notification",
"entity_id": "automation.front_door_notify",
"source": "state of binary_sensor.front_door",
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_AUTOMATION_TRIGGERED
assert evt.collection_name == "Front door notification"
assert evt.collection_id == "automation.front_door_notify"
assert evt.extra["automation_name"] == "Front door notification"
assert evt.extra["trigger_source"] == "state of binary_sensor.front_door"
def test_call_service_with_target() -> None:
payload = _ha_event_envelope(
"call_service",
{
"domain": "light",
"service": "turn_on",
"service_data": {"entity_id": "light.kitchen"},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_SERVICE_CALLED
assert evt.collection_id == "light.turn_on"
assert evt.extra["target_entity"] == "light.kitchen"
assert evt.extra["service_domain"] == "light"
assert evt.extra["service_name"] == "turn_on"
def test_call_service_with_multi_target() -> None:
"""When the call hits multiple entities, the parser comma-joins them
so templates can render ``{{ target_entity }}`` without iterating."""
payload = _ha_event_envelope(
"call_service",
{
"domain": "light",
"service": "turn_off",
"service_data": {
"entity_id": ["light.kitchen", "light.living_room"],
},
},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.extra["target_entity"] == "light.kitchen, light.living_room"
def test_generic_event_fallback() -> None:
"""Any event_type not in the known set becomes ha_event_fired with the
raw event_type stashed in extras so loud catch-all subscriptions work."""
payload = _ha_event_envelope(
"custom_event_xyz",
{"foo": "bar"},
)
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.event_type is EventType.HA_EVENT_FIRED
assert evt.extra["ha_event_type"] == "custom_event_xyz"
assert evt.extra["event_data"] == {"foo": "bar"}
def test_malformed_payload_returns_none() -> None:
assert parse_event({}, provider_name="HA") is None
assert parse_event("not a dict", provider_name="HA") is None # type: ignore[arg-type]
# state_changed without entity_id is unrecoverable
bad = _ha_event_envelope("state_changed", {"new_state": None})
assert parse_event(bad, provider_name="HA") is None
# call_service without domain/service is unrecoverable
bad2 = _ha_event_envelope("call_service", {"service": "turn_on"})
assert parse_event(bad2, provider_name="HA") is None
def test_time_fired_iso_with_z_suffix_parses() -> None:
"""HA uses ``Z`` suffix; older Python ``fromisoformat`` rejects it.
The parser must handle both forms or we'd lose the timestamp."""
from datetime import timezone
payload = _ha_event_envelope(
"state_changed",
{
"entity_id": "sensor.temp",
"old_state": {"state": "20", "attributes": {}},
"new_state": {"state": "21", "attributes": {}},
},
)
payload["time_fired"] = "2026-05-13T12:34:56.789Z"
evt = parse_event(payload, provider_name="HA")
assert evt is not None
assert evt.timestamp.tzinfo is not None
assert evt.timestamp.utcoffset() == timezone.utc.utcoffset(None)
@@ -0,0 +1,193 @@
"""Tests for the HA WS session helper and slice-before-normalize path.
The reviewer flagged two perf-shaped concerns that we've now addressed:
1. ``/status`` and ``/areas`` previously opened 3 and 2 separate WS
connections respectively. With ``HomeAssistantSession`` they share one
socket these tests pin the contract.
2. ``/entities`` used to normalize every matching entity before slicing to
``count``. For HA installs with 1000+ entities this materialized 1000+
normalized dicts to throw most away. The optimization moves the slice
*before* normalize; this test exercises a 200-entity fixture and
verifies only the ``count`` survivors get normalized.
"""
from __future__ import annotations
import asyncio
from typing import Any
from unittest.mock import patch
import pytest
from notify_bridge_core.providers.home_assistant.client import HomeAssistantSession
from notify_bridge_server.commands import home_assistant_handler as handler
# ---------------------------------------------------------------------------
# Session class — surface contract
# ---------------------------------------------------------------------------
def test_session_class_has_expected_methods() -> None:
"""Anyone consuming ``HomeAssistantSession`` can rely on this surface."""
expected = {"send", "get_states", "get_area_registry", "get_entity_registry"}
actual = {name for name in dir(HomeAssistantSession) if not name.startswith("_")}
assert expected <= actual, f"missing: {expected - actual}"
@pytest.mark.asyncio
async def test_session_get_states_routes_through_send() -> None:
"""``get_states`` is a thin wrapper around ``send`` with the canonical payload."""
sent: list[dict[str, Any]] = []
class _FakeClient:
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
sent.append(payload)
return 1
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
return [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
result = await sess.get_states()
assert sent == [{"type": "get_states"}]
assert result == [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
@pytest.mark.asyncio
async def test_session_methods_use_distinct_payloads() -> None:
"""Each session-scoped method sends the right HA command name."""
sent: list[dict[str, Any]] = []
class _FakeClient:
async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
sent.append(payload)
return len(sent)
async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
return []
sess = HomeAssistantSession(_FakeClient(), ws=object()) # type: ignore[arg-type]
await sess.get_states()
await sess.get_area_registry()
await sess.get_entity_registry()
assert [p["type"] for p in sent] == [
"get_states",
"config/area_registry/list",
"config/entity_registry/list",
]
# ---------------------------------------------------------------------------
# slice-before-normalize — perf contract for /entities
# ---------------------------------------------------------------------------
class _FakeAsyncSession:
"""A fake HA session that returns a canned state list."""
def __init__(self, states: list[dict[str, Any]]) -> None:
self._states = states
async def get_states(self) -> list[dict[str, Any]]:
return self._states
class _FakeClient:
"""A fake client whose ``session()`` yields a ``_FakeAsyncSession``."""
def __init__(self, states: list[dict[str, Any]]) -> None:
self._states = states
def session(self): # noqa: D401 — mimics real client signature
states = self._states
class _CM:
async def __aenter__(self_inner):
return _FakeAsyncSession(states)
async def __aexit__(self_inner, *_exc):
return False
return _CM()
def _state_row(entity_id: str, n_attrs: int = 2) -> dict[str, Any]:
return {
"entity_id": entity_id,
"state": "on",
"attributes": {f"attr_{i}": i for i in range(n_attrs)},
}
@pytest.mark.asyncio
async def test_cmd_entities_slices_before_normalizing(monkeypatch: pytest.MonkeyPatch) -> None:
"""200 raw entities, count=10. Normalize must run only 10 times.
We instrument ``_normalize_state`` with a counter to prove the slice
happens before the per-row transform. The total field still reports
all 200 so the user knows the result is truncated.
"""
states = [_state_row(f"light.bulb_{i:03d}") for i in range(200)]
fake_client = _FakeClient(states)
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
calls = {"count": 0}
real_normalize = handler._normalize_state
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
calls["count"] += 1
return real_normalize(row)
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
# ``get_http_session`` opens a real aiohttp session in the bg; bypass
# it since our fake client never uses the session arg.
async def _fake_http_session() -> Any:
return None
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
result = await handler._cmd_entities(provider, args="", count=10)
assert result["total"] == 200
assert result["shown"] == 10
assert len(result["entities"]) == 10
assert calls["count"] == 10, (
f"normalize should run once per survivor; ran {calls['count']} times"
)
@pytest.mark.asyncio
async def test_cmd_entities_glob_filter_still_normalizes_only_survivors(monkeypatch: pytest.MonkeyPatch) -> None:
"""200 raw entities mixed across 2 domains; glob narrows to one.
Normalize count = min(count, matching_total). Demonstrates the
optimization composes with the filter step.
"""
states = [
_state_row(f"light.bulb_{i:03d}") for i in range(100)
] + [
_state_row(f"sensor.temp_{i:03d}") for i in range(100)
]
fake_client = _FakeClient(states)
monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)
calls = {"count": 0}
real_normalize = handler._normalize_state
def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
calls["count"] += 1
return real_normalize(row)
monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)
async def _fake_http_session() -> Any:
return None
monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
provider = type("FakeProvider", (), {"config": {}, "name": "HA"})()
result = await handler._cmd_entities(provider, args="light.*", count=5)
assert result["total"] == 100 # all light.* entities counted
assert result["shown"] == 5 # but only 5 normalized
assert calls["count"] == 5
assert all(e["entity_id"].startswith("light.") for e in result["entities"])