22127e2a59
Adds Home Assistant as a service provider with two coordinated surfaces:

Notifications (subscription):
- Long-lived WebSocket client (aiohttp ws_connect) with auth handshake, exponential-backoff reconnect, bounded event queue, and area-registry enrichment cached per (re)connect
- ServiceProvider ABC gains an optional `subscribe()` method for push-style providers; HomeAssistantServiceProvider uses it via a per-provider supervisor task started in the FastAPI lifespan
- 4 event types (state_changed, automation_triggered, call_service, event_fired), 4 default Jinja templates (en + ru), HA-specific tracker filters (entity_glob, domain_allowlist, exact entity IDs)
- Extracted shared dispatch pipeline (api/webhooks.py → services/event_dispatch.py) so subscription and webhook ingest share the same event_log + deferred-dispatch + quiet-hours code path

Bot commands:
- /status, /entities [glob], /state <entity_id>, /areas
- Multi-command WS session so /status and /areas cost one handshake
- Sensitive-attribute blocklist (camera access_token, entity_picture, etc.) and 30-attribute cap to keep /state output safe and within Telegram's message size limit
- Error-message redaction strips URL userinfo before surfacing to chat

Frontend:
- HA descriptor with toggle ConfigField type (new) and tag-input filter mode for free-text glob/domain lists (new TagInput component)
- 15 command slots + 4 notification slots wired into the existing template-config UI
194 lines
7.1 KiB
Python
194 lines
7.1 KiB
Python
"""Tests for the HA WS session helper and the slice-before-normalize path.

The reviewer flagged two performance concerns that are now addressed:

1. ``/status`` and ``/areas`` previously opened three and two separate WS
   connections, respectively. With ``HomeAssistantSession`` they share one
   socket; the session tests below pin the session's contract (public
   method surface and the command payload each method sends).

2. ``/entities`` used to normalize every matching entity before slicing to
   ``count``. For HA installs with 1000+ entities this materialized 1000+
   normalized dicts only to discard most of them. The optimization moves the
   slice *before* normalization; the ``_cmd_entities`` tests below use
   200-entity fixtures and verify that only the ``count`` survivors are
   normalized.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from notify_bridge_core.providers.home_assistant.client import HomeAssistantSession
|
|
from notify_bridge_server.commands import home_assistant_handler as handler
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session class — surface contract
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_session_class_has_expected_methods() -> None:
    """Pin the public method surface that ``HomeAssistantSession`` consumers rely on."""
    public_names = set()
    for attr in dir(HomeAssistantSession):
        if not attr.startswith("_"):
            public_names.add(attr)

    required = {"send", "get_states", "get_area_registry", "get_entity_registry"}
    assert required <= public_names, f"missing: {required - public_names}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_get_states_routes_through_send() -> None:
    """``get_states`` must delegate to the client's send/await pair with the canonical payload."""
    recorded: list[dict[str, Any]] = []

    class _FakeClient:
        async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
            # Capture every payload the session hands to the client.
            recorded.append(payload)
            return 1

        async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
            return [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]

    session = HomeAssistantSession(_FakeClient(), ws=object())  # type: ignore[arg-type]
    states = await session.get_states()

    expected_states = [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]
    assert recorded == [{"type": "get_states"}]
    assert states == expected_states
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_methods_use_distinct_payloads() -> None:
    """Every session helper must emit its own HA WebSocket command type."""
    recorded: list[dict[str, Any]] = []

    class _FakeClient:
        async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
            recorded.append(payload)
            # Hand back a distinct id per command sent.
            return len(recorded)

        async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
            return []

    session = HomeAssistantSession(_FakeClient(), ws=object())  # type: ignore[arg-type]
    for call in (session.get_states, session.get_area_registry, session.get_entity_registry):
        await call()

    command_types = [payload["type"] for payload in recorded]
    assert command_types == [
        "get_states",
        "config/area_registry/list",
        "config/entity_registry/list",
    ]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# slice-before-normalize — perf contract for /entities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _FakeAsyncSession:
    """Stand-in HA session whose ``get_states`` returns the list it was built with."""

    def __init__(self, states: list[dict[str, Any]]) -> None:
        # Kept by reference (no copy): get_states hands back this exact list.
        self._states = states

    async def get_states(self) -> list[dict[str, Any]]:
        canned = self._states
        return canned
|
|
|
|
|
|
class _FakeClient:
    """Stand-in HA WS client; ``session()`` yields a ``_FakeAsyncSession`` over fixed states."""

    def __init__(self, states: list[dict[str, Any]]) -> None:
        self._states = states

    def session(self):  # noqa: D401 — mimics real client signature
        captured = self._states

        class _SessionContext:
            # Minimal async context manager: no setup or teardown, and
            # __aexit__ returns False so exceptions from the body propagate.
            async def __aenter__(self):
                return _FakeAsyncSession(captured)

            async def __aexit__(self, *_exc_info):
                return False

        return _SessionContext()
|
|
|
|
|
|
def _state_row(entity_id: str, n_attrs: int = 2) -> dict[str, Any]:
    """Build a raw HA state row: state ``"on"`` plus ``n_attrs`` numbered attributes.

    Attributes are ``attr_0 .. attr_{n_attrs-1}``, each mapped to its index.
    """
    attributes: dict[str, int] = {}
    for idx in range(n_attrs):
        attributes[f"attr_{idx}"] = idx
    return {"entity_id": entity_id, "state": "on", "attributes": attributes}
|
|
|
|
|
|
def _patch_entities_handler(
    monkeypatch: pytest.MonkeyPatch, states: list[dict[str, Any]]
) -> dict[str, int]:
    """Wire ``handler`` to serve ``states`` and count ``_normalize_state`` calls.

    Replaces the handler's WS-client factory with a ``_FakeClient`` over
    ``states``. Stubs ``get_http_session``, because it opens a real aiohttp
    session and the fake client never uses the session argument. Wraps
    ``_normalize_state`` with a call counter.

    Returns:
        A mutable ``{"count": int}`` dict incremented on every normalize call.
    """
    fake_client = _FakeClient(states)
    monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)

    calls = {"count": 0}
    real_normalize = handler._normalize_state

    def _counting_normalize(*args: Any, **kwargs: Any) -> Any:
        # Forward every argument so the counter keeps working whether the
        # handler passes the row positionally, by keyword, or with extras.
        calls["count"] += 1
        return real_normalize(*args, **kwargs)

    monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)

    async def _fake_http_session() -> Any:
        return None

    monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
    return calls


def _fake_provider() -> Any:
    """Return a minimal provider object exposing only ``config`` and ``name``."""
    return type("FakeProvider", (), {"config": {}, "name": "HA"})()


@pytest.mark.asyncio
async def test_cmd_entities_slices_before_normalizing(monkeypatch: pytest.MonkeyPatch) -> None:
    """200 raw entities, count=10. Normalize must run only 10 times.

    ``_normalize_state`` is instrumented with a counter to prove the slice
    happens before the per-row transform. The total field still reports
    all 200 so the user knows the result is truncated.
    """
    states = [_state_row(f"light.bulb_{i:03d}") for i in range(200)]
    calls = _patch_entities_handler(monkeypatch, states)

    result = await handler._cmd_entities(_fake_provider(), args="", count=10)

    assert result["total"] == 200
    assert result["shown"] == 10
    assert len(result["entities"]) == 10
    assert calls["count"] == 10, (
        f"normalize should run once per survivor; ran {calls['count']} times"
    )


@pytest.mark.asyncio
async def test_cmd_entities_glob_filter_still_normalizes_only_survivors(monkeypatch: pytest.MonkeyPatch) -> None:
    """200 raw entities mixed across 2 domains; glob narrows to one.

    Normalize count = min(count, matching_total). Demonstrates the
    optimization composes with the filter step.
    """
    states = [
        _state_row(f"light.bulb_{i:03d}") for i in range(100)
    ] + [
        _state_row(f"sensor.temp_{i:03d}") for i in range(100)
    ]
    calls = _patch_entities_handler(monkeypatch, states)

    result = await handler._cmd_entities(_fake_provider(), args="light.*", count=5)

    assert result["total"] == 100  # all light.* entities counted
    assert result["shown"] == 5  # but only 5 normalized
    assert len(result["entities"]) == 5
    assert calls["count"] == 5, (
        f"normalize should run once per survivor; ran {calls['count']} times"
    )
    assert all(e["entity_id"].startswith("light.") for e in result["entities"])
|