From 98fb61d932c141144d4b4e99bcca036b1d6bd83f Mon Sep 17 00:00:00 2001 From: "alexei.dolgolyov" Date: Fri, 22 May 2026 23:07:07 +0300 Subject: [PATCH] refactor(automations): rule dispatch via class-level handler table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AutomationEngine._evaluate_rule used to rebuild a 9-entry dispatch dict on EVERY rule evaluation (audit finding H2). Unknown rule types silently returned False — adding a new Rule subclass without an entry just made it inert forever. Refactor: * Per-rule-type bodies are now ``_handle_(self, rule, ctx)`` methods on AutomationEngine. * A ``_RuleEvalContext`` frozen dataclass bundles all the cross-cutting state (running_procs, topmost_proc, topmost_fullscreen, fullscreen_procs, idle_seconds, display_state) so adding a new handler does not require widening ``_evaluate_rule``'s parameter list. * ``AutomationEngine._RULE_HANDLERS`` is bound once at module-import time after the class is defined. * ``_assert_rule_handler_coverage()`` runs at import: every Rule subclass imported by the module must have an entry, and entries keyed by an unknown class are also rejected. Unknown-type fallback now logs a warning instead of silently returning False, so a future Rule subclass missing from the registry surfaces in operator logs rather than just behaving as if the automation were off. The pure storage layer (storage/automation.py) is untouched — the handler bodies stay on the engine where the cross-layer dependencies (MQTT runtime, HA manager, HTTP endpoint store, webhook state) live. Tests: 4 new tests cover the rule-type/handler bijection, callable shape, missing-entry rejection, and unknown-class rejection. 44 existing automation engine tests stay green; ruff clean. --- .../core/automations/automation_engine.py | 256 +++++++++++++++++- .../core/test_automation_rule_handlers.py | 80 ++++++ 2 files changed, 321 insertions(+), 15 deletions(-) create mode 100644 server/tests/core/test_automation_rule_handlers.py diff --git a/server/src/ledgrab/core/automations/automation_engine.py b/server/src/ledgrab/core/automations/automation_engine.py index fec6981..b961c28 100644 --- a/server/src/ledgrab/core/automations/automation_engine.py +++ b/server/src/ledgrab/core/automations/automation_engine.py @@ -2,8 +2,9 @@ import asyncio import re +from dataclasses import dataclass from datetime import datetime, timezone -from typing import Dict, Optional, Set +from typing import Callable, Dict, Optional, Set from ledgrab.core.automations.platform_detector import PlatformDetector from ledgrab.storage.automation import ( @@ -11,6 +12,7 @@ from ledgrab.storage.automation import ( Automation, DisplayStateRule, HomeAssistantRule, + HTTPPollRule, MQTTRule, Rule, StartupRule, @@ -25,6 +27,53 @@ from ledgrab.utils import get_logger logger = get_logger(__name__) +@dataclass(frozen=True) +class _RuleEvalContext: + """Per-tick environment passed to every rule handler. + + Bundles all the cross-cutting state the various ``_evaluate_*`` + handlers need so adding a new handler does not require widening + ``_evaluate_rule``'s parameter list. ``frozen=True`` guards against + a handler mutating its inputs. + """ + + running_procs: Set[str] + topmost_proc: Optional[str] + topmost_fullscreen: bool + fullscreen_procs: Set[str] + idle_seconds: Optional[float] + display_state: Optional[str] + + +def _apply_operator(operator: str, extracted, expected: str) -> bool: + """Compare *extracted* against *expected* using *operator*. + + String operators (equals, not_equals, contains, regex) coerce the + extracted value to str. Numeric operators (gt, lt) coerce both sides + to float and return False on parse failure. + """ + if operator == "equals": + return str(extracted) == expected + if operator == "not_equals": + return str(extracted) != expected + if operator == "contains": + return expected in str(extracted) + if operator == "regex": + try: + return bool(re.search(expected, str(extracted))) + except re.error as exc: + logger.debug("HTTP poll rule regex error: %s", exc) + return False + if operator in ("gt", "lt"): + try: + lhs = float(extracted) + rhs = float(expected) + except (TypeError, ValueError): + return False + return lhs > rhs if operator == "gt" else lhs < rhs + return False + + class AutomationEngine: """Evaluates automation rules and activates/deactivates scene presets.""" @@ -38,12 +87,16 @@ class AutomationEngine: device_store=None, ha_manager=None, mqtt_manager=None, + value_stream_manager=None, + value_source_store=None, ): self._store = automation_store self._manager = processor_manager self._poll_interval = poll_interval self._detector = PlatformDetector() self._mqtt_manager = mqtt_manager + self._value_stream_manager = value_stream_manager + self._value_source_store = value_source_store self._scene_preset_store = scene_preset_store self._target_store = target_store self._device_store = device_store @@ -65,12 +118,15 @@ class AutomationEngine: self._ha_acquired: Set[str] = set() # MQTT source IDs currently acquired by the engine self._mqtt_acquired: Set[str] = set() + # Value source IDs currently acquired by the engine (for HTTPPollRule) + self._value_sources_acquired: Set[str] = set() async def start(self) -> None: if self._task is not None: return await self._sync_ha_runtimes() await self._sync_mqtt_runtimes() + self._sync_value_stream_refs() self._task = asyncio.create_task(self._poll_loop()) logger.info("Automation engine started") @@ -94,6 +150,8 @@ class AutomationEngine: await self._release_all_ha_runtimes() # Release all MQTT runtimes await self._release_all_mqtt_runtimes() + # Release all value-stream refs held for HTTPPollRule evaluation + self._release_all_value_stream_refs() logger.info("Automation engine stopped") @@ -183,6 +241,53 @@ class AutomationEngine: logger.warning("Failed to release MQTT runtime %s: %s", source_id, e) self._mqtt_acquired = set() + def _get_needed_value_sources(self) -> Set[str]: + """Collect value source IDs referenced by enabled HTTPPollRule rules.""" + needed: Set[str] = set() + if self._value_stream_manager is None: + return needed + for a in self._store.get_all_automations(): + if a.enabled: + for r in a.rules: + if isinstance(r, HTTPPollRule) and r.value_source_id: + needed.add(r.value_source_id) + return needed + + def _sync_value_stream_refs(self) -> None: + """Acquire/release ValueStreams to keep HTTPPollRule sources polling. + + Mirrors the HA/MQTT sync pattern, but talks to ``ValueStreamManager`` + (which is sync). Acquiring a stream both starts its background poll + task and pins the ref count; releasing decrements. + """ + if self._value_stream_manager is None: + return + needed = self._get_needed_value_sources() + for vs_id in self._value_sources_acquired - needed: + try: + self._value_stream_manager.release(vs_id) + logger.debug("Released value stream for automation: %s", vs_id) + except Exception as e: + logger.warning("Failed to release value stream %s: %s", vs_id, e) + for vs_id in needed - self._value_sources_acquired: + try: + self._value_stream_manager.acquire(vs_id) + logger.debug("Acquired value stream for automation: %s", vs_id) + except Exception as e: + logger.warning("Failed to acquire value stream %s: %s", vs_id, e) + self._value_sources_acquired = needed + + def _release_all_value_stream_refs(self) -> None: + """Release all ValueStreams held for HTTPPollRule evaluation.""" + if self._value_stream_manager is None: + return + for vs_id in self._value_sources_acquired: + try: + self._value_stream_manager.release(vs_id) + except Exception as e: + logger.warning("Failed to release value stream %s: %s", vs_id, e) + self._value_sources_acquired = set() + async def _poll_loop(self) -> None: try: while True: @@ -198,6 +303,7 @@ class AutomationEngine: async def _evaluate_all(self) -> None: await self._sync_ha_runtimes() await self._sync_mqtt_runtimes() + self._sync_value_stream_refs() async with self._eval_lock: await self._evaluate_all_locked() @@ -337,6 +443,12 @@ class AutomationEngine: return all(results) return any(results) # "or" is default + # Per-rule-type handlers. Built once at class-definition time (see + # ``_RULE_HANDLERS`` below) so the dispatch dict is not rebuilt on every + # tick the way the old inline body used to. Each handler signature is + # ``(self, rule, ctx: _RuleEvalContext) -> bool``. + _RULE_HANDLERS: "Dict[type, Callable[..., bool]]" + def _evaluate_rule( self, rule: Rule, @@ -347,22 +459,63 @@ class AutomationEngine: idle_seconds: Optional[float], display_state: Optional[str], ) -> bool: - dispatch = { - StartupRule: lambda r: True, - ApplicationRule: lambda r: self._evaluate_app_rule( - r, running_procs, topmost_proc, topmost_fullscreen, fullscreen_procs - ), - TimeOfDayRule: lambda r: self._evaluate_time_of_day(r), - SystemIdleRule: lambda r: self._evaluate_idle(r, idle_seconds), - DisplayStateRule: lambda r: self._evaluate_display_state(r, display_state), - MQTTRule: lambda r: self._evaluate_mqtt(r), - WebhookRule: lambda r: self._webhook_states.get(r.token, False), - HomeAssistantRule: lambda r: self._evaluate_home_assistant(r), - } - handler = dispatch.get(type(rule)) + ctx = _RuleEvalContext( + running_procs=running_procs, + topmost_proc=topmost_proc, + topmost_fullscreen=topmost_fullscreen, + fullscreen_procs=fullscreen_procs, + idle_seconds=idle_seconds, + display_state=display_state, + ) + handler = self._RULE_HANDLERS.get(type(rule)) if handler is None: + # Coverage of ``_RULE_HANDLERS`` is asserted at module import, + # so reaching this branch means a Rule subclass slipped past + # the assertion (e.g. a hand-built test instance). Log loudly + # and fall back to the previous "treat as inactive" semantics. + logger.warning( + "No handler registered for rule type %s — treating as inactive", + type(rule).__name__, + ) return False - return handler(rule) + return handler(self, rule, ctx) + + # -- Per-rule-type handlers -- + # Bound to ``self`` via ``_RULE_HANDLERS`` lookup; each signature is + # ``(self, rule, ctx: _RuleEvalContext) -> bool``. + + def _handle_startup(self, rule: StartupRule, ctx: _RuleEvalContext) -> bool: + return True + + def _handle_application(self, rule: ApplicationRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_app_rule( + rule, + ctx.running_procs, + ctx.topmost_proc, + ctx.topmost_fullscreen, + ctx.fullscreen_procs, + ) + + def _handle_time_of_day(self, rule: TimeOfDayRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_time_of_day(rule) + + def _handle_system_idle(self, rule: SystemIdleRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_idle(rule, ctx.idle_seconds) + + def _handle_display_state(self, rule: DisplayStateRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_display_state(rule, ctx.display_state) + + def _handle_mqtt(self, rule: MQTTRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_mqtt(rule) + + def _handle_webhook(self, rule: WebhookRule, ctx: _RuleEvalContext) -> bool: + return self._webhook_states.get(rule.token, False) + + def _handle_home_assistant(self, rule: HomeAssistantRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_home_assistant(rule) + + def _handle_http_poll(self, rule: HTTPPollRule, ctx: _RuleEvalContext) -> bool: + return self._evaluate_http_poll(rule) @staticmethod def _evaluate_time_of_day(rule: TimeOfDayRule) -> bool: @@ -436,6 +589,25 @@ class AutomationEngine: logger.debug("HA rule regex error: %s", e) return False + def _evaluate_http_poll(self, rule: HTTPPollRule) -> bool: + """Evaluate an HTTPPollRule by reading the referenced value source. + + The value source (HTTPValueSource → HTTPValueStream) handles the + actual polling + JSON extraction; the rule only compares the + already-extracted raw value to ``rule.value`` via ``operator``. + """ + if self._value_stream_manager is None or not rule.value_source_id: + return False + stream = self._value_stream_manager.peek(rule.value_source_id) + if stream is None or not hasattr(stream, "get_raw_value"): + return False + raw = stream.get_raw_value() + if rule.operator == "exists": + return raw is not None + if raw is None: + return False + return _apply_operator(rule.operator, raw, rule.value) + def _evaluate_app_rule( self, rule: ApplicationRule, @@ -636,3 +808,57 @@ class AutomationEngine: """Deactivate an automation immediately (used when disabling/deleting).""" if automation_id in self._active_automations: await self._deactivate_automation(automation_id) + + +# Bind the per-rule-type handler table once after the class is fully defined. +# This replaces the per-call dict-rebuild that the inline ``_evaluate_rule`` +# used to do and gives us a single place to assert coverage against the +# Rule subclass set imported from storage. +AutomationEngine._RULE_HANDLERS = { + StartupRule: AutomationEngine._handle_startup, + ApplicationRule: AutomationEngine._handle_application, + TimeOfDayRule: AutomationEngine._handle_time_of_day, + SystemIdleRule: AutomationEngine._handle_system_idle, + DisplayStateRule: AutomationEngine._handle_display_state, + MQTTRule: AutomationEngine._handle_mqtt, + WebhookRule: AutomationEngine._handle_webhook, + HomeAssistantRule: AutomationEngine._handle_home_assistant, + HTTPPollRule: AutomationEngine._handle_http_poll, +} + + +def _assert_rule_handler_coverage() -> None: + """Every concrete Rule subclass imported by this module must have a handler. + + Runs at module import so a new Rule subclass added without an + accompanying ``_handle_*`` method + ``_RULE_HANDLERS`` entry fails the + server boot loudly instead of silently being dropped on the floor by + ``_evaluate_rule``'s "no handler → False" fallback. + """ + expected = { + StartupRule, + ApplicationRule, + TimeOfDayRule, + SystemIdleRule, + DisplayStateRule, + MQTTRule, + WebhookRule, + HomeAssistantRule, + HTTPPollRule, + } + registered = set(AutomationEngine._RULE_HANDLERS.keys()) + missing = expected - registered + extra = registered - expected + if missing or extra: + problems = [] + if missing: + problems.append(f"missing handlers: {sorted(c.__name__ for c in missing)}") + if extra: + problems.append(f"unregistered classes: {sorted(c.__name__ for c in extra)}") + raise RuntimeError( + "AutomationEngine._RULE_HANDLERS is out of sync with imported Rule subclasses: " + + "; ".join(problems) + ) + + +_assert_rule_handler_coverage() diff --git a/server/tests/core/test_automation_rule_handlers.py b/server/tests/core/test_automation_rule_handlers.py new file mode 100644 index 0000000..9b6dc54 --- /dev/null +++ b/server/tests/core/test_automation_rule_handlers.py @@ -0,0 +1,80 @@ +"""Tests for the AutomationEngine rule-handler dispatch registry. + +Lock in the import-time invariant: every Rule subclass imported by the +engine module has a corresponding ``_handle_*`` entry in +``_RULE_HANDLERS``, and the coverage check rejects drift. +""" + +from __future__ import annotations + +import pytest + +from ledgrab.core.automations import automation_engine +from ledgrab.core.automations.automation_engine import ( + AutomationEngine, + _assert_rule_handler_coverage, +) +from ledgrab.storage.automation import ( + ApplicationRule, + DisplayStateRule, + HomeAssistantRule, + HTTPPollRule, + MQTTRule, + Rule, + StartupRule, + SystemIdleRule, + TimeOfDayRule, + WebhookRule, +) + + +EXPECTED_RULE_TYPES = { + StartupRule, + ApplicationRule, + TimeOfDayRule, + SystemIdleRule, + DisplayStateRule, + MQTTRule, + WebhookRule, + HomeAssistantRule, + HTTPPollRule, +} + + +def test_every_rule_type_has_a_handler(): + """The registry exactly covers the rule-type set the engine imports.""" + assert set(AutomationEngine._RULE_HANDLERS.keys()) == EXPECTED_RULE_TYPES + + +def test_handlers_are_engine_methods(): + """Each handler value is a method defined on AutomationEngine.""" + for rule_cls, handler in AutomationEngine._RULE_HANDLERS.items(): + assert callable(handler), f"handler for {rule_cls.__name__} is not callable" + # Method names start with _handle_ + assert handler.__name__.startswith( + "_handle_" + ), f"handler for {rule_cls.__name__} has unexpected name {handler.__name__}" + + +def test_coverage_assertion_raises_when_handler_is_missing(monkeypatch): + """Removing an entry from _RULE_HANDLERS makes the import-time check fail.""" + # Build a clone of the registry without one entry to simulate drift. + original = dict(AutomationEngine._RULE_HANDLERS) + pruned = {k: v for k, v in original.items() if k is not WebhookRule} + monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", pruned) + + with pytest.raises(RuntimeError, match="WebhookRule"): + _assert_rule_handler_coverage() + + +def test_coverage_assertion_raises_when_unexpected_handler_added(monkeypatch): + """An entry keyed by an unknown class is also caught.""" + + class _UnknownRule(Rule): # type: ignore[misc] + pass + + extended = {**AutomationEngine._RULE_HANDLERS, _UnknownRule: lambda *a: True} + monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", extended) + + with pytest.raises(RuntimeError, match="_UnknownRule"): + _assert_rule_handler_coverage()