"""Tests for the HA WS session helper and the slice-before-normalize path.

A review flagged two performance-shaped concerns, both now addressed:

1. ``/status`` and ``/areas`` previously opened three and two separate WS
   connections respectively. With ``HomeAssistantSession`` they share one
   socket; the session tests below pin the ``HomeAssistantSession`` surface
   those handlers rely on.
2. ``/entities`` used to normalize every matching entity before slicing to
   ``count``. On HA installs with 1000+ entities this materialized 1000+
   normalized dicts only to throw most of them away. The optimization moves
   the slice *before* normalization; the tests below use a 200-entity
   fixture and verify that only the ``count`` survivors get normalized.
"""

from __future__ import annotations

import asyncio
from typing import Any
from unittest.mock import patch

import pytest

from notify_bridge_core.providers.home_assistant.client import HomeAssistantSession
from notify_bridge_server.commands import home_assistant_handler as handler


# ---------------------------------------------------------------------------
# Session class — surface contract
# ---------------------------------------------------------------------------


def test_session_class_has_expected_methods() -> None:
    """Anyone consuming ``HomeAssistantSession`` can rely on this surface."""
    expected = {"send", "get_states", "get_area_registry", "get_entity_registry"}
    actual = {name for name in dir(HomeAssistantSession) if not name.startswith("_")}
    assert expected <= actual, f"missing: {expected - actual}"


@pytest.mark.asyncio
async def test_session_get_states_routes_through_send() -> None:
    """``get_states`` is a thin wrapper around ``send`` with the canonical payload."""
    sent: list[dict[str, Any]] = []

    class _FakeClient:
        async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
            sent.append(payload)
            return 1

        async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
            return [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]

    sess = HomeAssistantSession(_FakeClient(), ws=object())  # type: ignore[arg-type]
    result = await sess.get_states()

    assert sent == [{"type": "get_states"}]
    assert result == [{"entity_id": "light.kitchen", "state": "on", "attributes": {}}]


@pytest.mark.asyncio
async def test_session_methods_use_distinct_payloads() -> None:
    """Each session-scoped method sends the right HA command name."""
    sent: list[dict[str, Any]] = []

    class _FakeClient:
        async def _send_command(self, ws: Any, payload: dict[str, Any]) -> int:
            sent.append(payload)
            # Distinct message ids per call, like a real client would hand out.
            return len(sent)

        async def _await_result(self, ws: Any, msg_id: int, timeout: float = 15.0) -> Any:
            return []

    sess = HomeAssistantSession(_FakeClient(), ws=object())  # type: ignore[arg-type]
    await sess.get_states()
    await sess.get_area_registry()
    await sess.get_entity_registry()

    assert [p["type"] for p in sent] == [
        "get_states",
        "config/area_registry/list",
        "config/entity_registry/list",
    ]


# ---------------------------------------------------------------------------
# slice-before-normalize — perf contract for /entities
# ---------------------------------------------------------------------------


class _FakeAsyncSession:
    """A fake HA session that returns a canned state list."""

    def __init__(self, states: list[dict[str, Any]]) -> None:
        self._states = states

    async def get_states(self) -> list[dict[str, Any]]:
        return self._states


class _FakeClient:
    """A fake client whose ``session()`` yields a ``_FakeAsyncSession``."""

    def __init__(self, states: list[dict[str, Any]]) -> None:
        self._states = states

    def session(self):  # noqa: D401 — mimics real client signature
        states = self._states

        class _CM:
            async def __aenter__(self_inner):
                return _FakeAsyncSession(states)

            async def __aexit__(self_inner, *_exc):
                # Returning False lets exceptions from the ``async with`` body propagate.
                return False

        return _CM()


def _state_row(entity_id: str, n_attrs: int = 2) -> dict[str, Any]:
    """Build one raw HA state row with ``n_attrs`` integer attributes."""
    return {
        "entity_id": entity_id,
        "state": "on",
        "attributes": {f"attr_{i}": i for i in range(n_attrs)},
    }


def _install_entities_fakes(
    monkeypatch: pytest.MonkeyPatch, states: list[dict[str, Any]]
) -> dict[str, int]:
    """Route ``handler._cmd_entities`` through fakes and count normalize calls.

    Patches three attributes on ``handler``:

    * ``_make_ws_client`` returns a ``_FakeClient`` serving ``states``.
    * ``_normalize_state`` is wrapped with a call counter; the real function
      still runs, so the returned entity shape is unchanged.
    * ``get_http_session`` is replaced by a coroutine returning ``None``. The
      real one opens an aiohttp session, and the fake client never uses its
      session argument.

    Returns the counter dict; read ``["count"]`` after invoking the handler.
    """
    fake_client = _FakeClient(states)
    monkeypatch.setattr(handler, "_make_ws_client", lambda provider, session: fake_client)

    calls = {"count": 0}
    real_normalize = handler._normalize_state

    def _counting_normalize(row: dict[str, Any]) -> dict[str, Any]:
        calls["count"] += 1
        return real_normalize(row)

    monkeypatch.setattr(handler, "_normalize_state", _counting_normalize)

    async def _fake_http_session() -> Any:
        return None

    monkeypatch.setattr(handler, "get_http_session", _fake_http_session)
    return calls


def _fake_provider() -> Any:
    """Build a minimal provider object exposing only ``config`` and ``name``."""
    return type("FakeProvider", (), {"config": {}, "name": "HA"})()


@pytest.mark.asyncio
async def test_cmd_entities_slices_before_normalizing(monkeypatch: pytest.MonkeyPatch) -> None:
    """200 raw entities, count=10. Normalize must run only 10 times.

    We instrument ``_normalize_state`` with a counter to prove the slice
    happens before the per-row transform. The ``total`` field still reports
    all 200 so the user knows the result is truncated.
    """
    states = [_state_row(f"light.bulb_{i:03d}") for i in range(200)]
    calls = _install_entities_fakes(monkeypatch, states)

    result = await handler._cmd_entities(_fake_provider(), args="", count=10)

    assert result["total"] == 200
    assert result["shown"] == 10
    assert len(result["entities"]) == 10
    assert calls["count"] == 10, (
        f"normalize should run once per survivor; ran {calls['count']} times"
    )


@pytest.mark.asyncio
async def test_cmd_entities_glob_filter_still_normalizes_only_survivors(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """200 raw entities mixed across 2 domains; the glob narrows to one.

    Normalize count = min(count, matching_total). Demonstrates that the
    optimization composes with the filter step.
    """
    states = [_state_row(f"light.bulb_{i:03d}") for i in range(100)] + [
        _state_row(f"sensor.temp_{i:03d}") for i in range(100)
    ]
    calls = _install_entities_fakes(monkeypatch, states)

    result = await handler._cmd_entities(_fake_provider(), args="light.*", count=5)

    assert result["total"] == 100  # all light.* entities counted
    assert result["shown"] == 5  # but only 5 shown...
    assert calls["count"] == 5  # ...and only those 5 normalized
    # Pin the length explicitly; without it the all(...) check below would
    # pass vacuously on an empty list.
    assert len(result["entities"]) == 5
    assert all(e["entity_id"].startswith("light.") for e in result["entities"])