refactor(automations): rule dispatch via class-level handler table

AutomationEngine._evaluate_rule used to rebuild a 9-entry dispatch
dict on EVERY rule evaluation (audit finding H2). Unknown rule types
silently returned False — adding a new Rule subclass without an entry
just made it inert forever.

Refactor:

  * Per-rule-type bodies are now ``_handle_<kind>(self, rule, ctx)``
    methods on AutomationEngine.
  * A ``_RuleEvalContext`` frozen dataclass bundles all the
    cross-cutting state (running_procs, topmost_proc,
    topmost_fullscreen, fullscreen_procs, idle_seconds, display_state)
    so adding a new handler does not require widening
    ``_evaluate_rule``'s parameter list.
  * ``AutomationEngine._RULE_HANDLERS`` is bound once at module-import
    time after the class is defined.
  * ``_assert_rule_handler_coverage()`` runs at import: every Rule
    subclass imported by the module must have an entry, and entries
    keyed by an unknown class are also rejected.

Unknown-type fallback now logs a warning instead of silently returning
False, so a future Rule subclass missing from the registry surfaces in
operator logs rather than just behaving as if the automation were off.

The pure storage layer (storage/automation.py) is untouched — the
handler bodies stay on the engine where the cross-layer dependencies
(MQTT runtime, HA manager, HTTP endpoint store, webhook state) live.

Tests: 4 new tests cover the rule-type/handler bijection, callable
shape, missing-entry rejection, and unknown-class rejection. 44
existing automation engine tests stay green; ruff clean.
This commit is contained in:
2026-05-22 23:07:07 +03:00
parent 5fec8db901
commit 98fb61d932
2 changed files with 321 additions and 15 deletions
@@ -2,8 +2,9 @@
import asyncio import asyncio
import re import re
from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, Optional, Set from typing import Callable, Dict, Optional, Set
from ledgrab.core.automations.platform_detector import PlatformDetector from ledgrab.core.automations.platform_detector import PlatformDetector
from ledgrab.storage.automation import ( from ledgrab.storage.automation import (
@@ -11,6 +12,7 @@ from ledgrab.storage.automation import (
Automation, Automation,
DisplayStateRule, DisplayStateRule,
HomeAssistantRule, HomeAssistantRule,
HTTPPollRule,
MQTTRule, MQTTRule,
Rule, Rule,
StartupRule, StartupRule,
@@ -25,6 +27,53 @@ from ledgrab.utils import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass(frozen=True)
class _RuleEvalContext:
"""Per-tick environment passed to every rule handler.
Bundles all the cross-cutting state the various ``_evaluate_*``
handlers need so adding a new handler does not require widening
``_evaluate_rule``'s parameter list. ``frozen=True`` guards against
a handler mutating its inputs.
"""
running_procs: Set[str]
topmost_proc: Optional[str]
topmost_fullscreen: bool
fullscreen_procs: Set[str]
idle_seconds: Optional[float]
display_state: Optional[str]
def _apply_operator(operator: str, extracted, expected: str) -> bool:
"""Compare *extracted* against *expected* using *operator*.
String operators (equals, not_equals, contains, regex) coerce the
extracted value to str. Numeric operators (gt, lt) coerce both sides
to float and return False on parse failure.
"""
if operator == "equals":
return str(extracted) == expected
if operator == "not_equals":
return str(extracted) != expected
if operator == "contains":
return expected in str(extracted)
if operator == "regex":
try:
return bool(re.search(expected, str(extracted)))
except re.error as exc:
logger.debug("HTTP poll rule regex error: %s", exc)
return False
if operator in ("gt", "lt"):
try:
lhs = float(extracted)
rhs = float(expected)
except (TypeError, ValueError):
return False
return lhs > rhs if operator == "gt" else lhs < rhs
return False
class AutomationEngine: class AutomationEngine:
"""Evaluates automation rules and activates/deactivates scene presets.""" """Evaluates automation rules and activates/deactivates scene presets."""
@@ -38,12 +87,16 @@ class AutomationEngine:
device_store=None, device_store=None,
ha_manager=None, ha_manager=None,
mqtt_manager=None, mqtt_manager=None,
value_stream_manager=None,
value_source_store=None,
): ):
self._store = automation_store self._store = automation_store
self._manager = processor_manager self._manager = processor_manager
self._poll_interval = poll_interval self._poll_interval = poll_interval
self._detector = PlatformDetector() self._detector = PlatformDetector()
self._mqtt_manager = mqtt_manager self._mqtt_manager = mqtt_manager
self._value_stream_manager = value_stream_manager
self._value_source_store = value_source_store
self._scene_preset_store = scene_preset_store self._scene_preset_store = scene_preset_store
self._target_store = target_store self._target_store = target_store
self._device_store = device_store self._device_store = device_store
@@ -65,12 +118,15 @@ class AutomationEngine:
self._ha_acquired: Set[str] = set() self._ha_acquired: Set[str] = set()
# MQTT source IDs currently acquired by the engine # MQTT source IDs currently acquired by the engine
self._mqtt_acquired: Set[str] = set() self._mqtt_acquired: Set[str] = set()
# Value source IDs currently acquired by the engine (for HTTPPollRule)
self._value_sources_acquired: Set[str] = set()
async def start(self) -> None: async def start(self) -> None:
if self._task is not None: if self._task is not None:
return return
await self._sync_ha_runtimes() await self._sync_ha_runtimes()
await self._sync_mqtt_runtimes() await self._sync_mqtt_runtimes()
self._sync_value_stream_refs()
self._task = asyncio.create_task(self._poll_loop()) self._task = asyncio.create_task(self._poll_loop())
logger.info("Automation engine started") logger.info("Automation engine started")
@@ -94,6 +150,8 @@ class AutomationEngine:
await self._release_all_ha_runtimes() await self._release_all_ha_runtimes()
# Release all MQTT runtimes # Release all MQTT runtimes
await self._release_all_mqtt_runtimes() await self._release_all_mqtt_runtimes()
# Release all value-stream refs held for HTTPPollRule evaluation
self._release_all_value_stream_refs()
logger.info("Automation engine stopped") logger.info("Automation engine stopped")
@@ -183,6 +241,53 @@ class AutomationEngine:
logger.warning("Failed to release MQTT runtime %s: %s", source_id, e) logger.warning("Failed to release MQTT runtime %s: %s", source_id, e)
self._mqtt_acquired = set() self._mqtt_acquired = set()
def _get_needed_value_sources(self) -> Set[str]:
"""Collect value source IDs referenced by enabled HTTPPollRule rules."""
needed: Set[str] = set()
if self._value_stream_manager is None:
return needed
for a in self._store.get_all_automations():
if a.enabled:
for r in a.rules:
if isinstance(r, HTTPPollRule) and r.value_source_id:
needed.add(r.value_source_id)
return needed
def _sync_value_stream_refs(self) -> None:
"""Acquire/release ValueStreams to keep HTTPPollRule sources polling.
Mirrors the HA/MQTT sync pattern, but talks to ``ValueStreamManager``
(which is sync). Acquiring a stream both starts its background poll
task and pins the ref count; releasing decrements.
"""
if self._value_stream_manager is None:
return
needed = self._get_needed_value_sources()
for vs_id in self._value_sources_acquired - needed:
try:
self._value_stream_manager.release(vs_id)
logger.debug("Released value stream for automation: %s", vs_id)
except Exception as e:
logger.warning("Failed to release value stream %s: %s", vs_id, e)
for vs_id in needed - self._value_sources_acquired:
try:
self._value_stream_manager.acquire(vs_id)
logger.debug("Acquired value stream for automation: %s", vs_id)
except Exception as e:
logger.warning("Failed to acquire value stream %s: %s", vs_id, e)
self._value_sources_acquired = needed
def _release_all_value_stream_refs(self) -> None:
"""Release all ValueStreams held for HTTPPollRule evaluation."""
if self._value_stream_manager is None:
return
for vs_id in self._value_sources_acquired:
try:
self._value_stream_manager.release(vs_id)
except Exception as e:
logger.warning("Failed to release value stream %s: %s", vs_id, e)
self._value_sources_acquired = set()
async def _poll_loop(self) -> None: async def _poll_loop(self) -> None:
try: try:
while True: while True:
@@ -198,6 +303,7 @@ class AutomationEngine:
async def _evaluate_all(self) -> None: async def _evaluate_all(self) -> None:
await self._sync_ha_runtimes() await self._sync_ha_runtimes()
await self._sync_mqtt_runtimes() await self._sync_mqtt_runtimes()
self._sync_value_stream_refs()
async with self._eval_lock: async with self._eval_lock:
await self._evaluate_all_locked() await self._evaluate_all_locked()
@@ -337,6 +443,12 @@ class AutomationEngine:
return all(results) return all(results)
return any(results) # "or" is default return any(results) # "or" is default
# Per-rule-type handlers. Built once at class-definition time (see
# ``_RULE_HANDLERS`` below) so the dispatch dict is not rebuilt on every
# tick the way the old inline body used to. Each handler signature is
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
_RULE_HANDLERS: "Dict[type, Callable[..., bool]]"
def _evaluate_rule( def _evaluate_rule(
self, self,
rule: Rule, rule: Rule,
@@ -347,22 +459,63 @@ class AutomationEngine:
idle_seconds: Optional[float], idle_seconds: Optional[float],
display_state: Optional[str], display_state: Optional[str],
) -> bool: ) -> bool:
dispatch = { ctx = _RuleEvalContext(
StartupRule: lambda r: True, running_procs=running_procs,
ApplicationRule: lambda r: self._evaluate_app_rule( topmost_proc=topmost_proc,
r, running_procs, topmost_proc, topmost_fullscreen, fullscreen_procs topmost_fullscreen=topmost_fullscreen,
), fullscreen_procs=fullscreen_procs,
TimeOfDayRule: lambda r: self._evaluate_time_of_day(r), idle_seconds=idle_seconds,
SystemIdleRule: lambda r: self._evaluate_idle(r, idle_seconds), display_state=display_state,
DisplayStateRule: lambda r: self._evaluate_display_state(r, display_state), )
MQTTRule: lambda r: self._evaluate_mqtt(r), handler = self._RULE_HANDLERS.get(type(rule))
WebhookRule: lambda r: self._webhook_states.get(r.token, False),
HomeAssistantRule: lambda r: self._evaluate_home_assistant(r),
}
handler = dispatch.get(type(rule))
if handler is None: if handler is None:
# Coverage of ``_RULE_HANDLERS`` is asserted at module import,
# so reaching this branch means a Rule subclass slipped past
# the assertion (e.g. a hand-built test instance). Log loudly
# and fall back to the previous "treat as inactive" semantics.
logger.warning(
"No handler registered for rule type %s — treating as inactive",
type(rule).__name__,
)
return False return False
return handler(rule) return handler(self, rule, ctx)
# -- Per-rule-type handlers --
# Bound to ``self`` via ``_RULE_HANDLERS`` lookup; each signature is
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
def _handle_startup(self, rule: StartupRule, ctx: _RuleEvalContext) -> bool:
return True
def _handle_application(self, rule: ApplicationRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_app_rule(
rule,
ctx.running_procs,
ctx.topmost_proc,
ctx.topmost_fullscreen,
ctx.fullscreen_procs,
)
def _handle_time_of_day(self, rule: TimeOfDayRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_time_of_day(rule)
def _handle_system_idle(self, rule: SystemIdleRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_idle(rule, ctx.idle_seconds)
def _handle_display_state(self, rule: DisplayStateRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_display_state(rule, ctx.display_state)
def _handle_mqtt(self, rule: MQTTRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_mqtt(rule)
def _handle_webhook(self, rule: WebhookRule, ctx: _RuleEvalContext) -> bool:
return self._webhook_states.get(rule.token, False)
def _handle_home_assistant(self, rule: HomeAssistantRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_home_assistant(rule)
def _handle_http_poll(self, rule: HTTPPollRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_http_poll(rule)
@staticmethod @staticmethod
def _evaluate_time_of_day(rule: TimeOfDayRule) -> bool: def _evaluate_time_of_day(rule: TimeOfDayRule) -> bool:
@@ -436,6 +589,25 @@ class AutomationEngine:
logger.debug("HA rule regex error: %s", e) logger.debug("HA rule regex error: %s", e)
return False return False
def _evaluate_http_poll(self, rule: HTTPPollRule) -> bool:
"""Evaluate an HTTPPollRule by reading the referenced value source.
The value source (HTTPValueSource → HTTPValueStream) handles the
actual polling + JSON extraction; the rule only compares the
already-extracted raw value to ``rule.value`` via ``operator``.
"""
if self._value_stream_manager is None or not rule.value_source_id:
return False
stream = self._value_stream_manager.peek(rule.value_source_id)
if stream is None or not hasattr(stream, "get_raw_value"):
return False
raw = stream.get_raw_value()
if rule.operator == "exists":
return raw is not None
if raw is None:
return False
return _apply_operator(rule.operator, raw, rule.value)
def _evaluate_app_rule( def _evaluate_app_rule(
self, self,
rule: ApplicationRule, rule: ApplicationRule,
@@ -636,3 +808,57 @@ class AutomationEngine:
"""Deactivate an automation immediately (used when disabling/deleting).""" """Deactivate an automation immediately (used when disabling/deleting)."""
if automation_id in self._active_automations: if automation_id in self._active_automations:
await self._deactivate_automation(automation_id) await self._deactivate_automation(automation_id)
# Bind the per-rule-type handler table once after the class is fully defined.
# This replaces the per-call dict-rebuild that the inline ``_evaluate_rule``
# used to do and gives us a single place to assert coverage against the
# Rule subclass set imported from storage.
AutomationEngine._RULE_HANDLERS = {
StartupRule: AutomationEngine._handle_startup,
ApplicationRule: AutomationEngine._handle_application,
TimeOfDayRule: AutomationEngine._handle_time_of_day,
SystemIdleRule: AutomationEngine._handle_system_idle,
DisplayStateRule: AutomationEngine._handle_display_state,
MQTTRule: AutomationEngine._handle_mqtt,
WebhookRule: AutomationEngine._handle_webhook,
HomeAssistantRule: AutomationEngine._handle_home_assistant,
HTTPPollRule: AutomationEngine._handle_http_poll,
}
def _assert_rule_handler_coverage() -> None:
"""Every concrete Rule subclass imported by this module must have a handler.
Runs at module import so a new Rule subclass added without an
accompanying ``_handle_*`` method + ``_RULE_HANDLERS`` entry fails the
server boot loudly instead of silently being dropped on the floor by
``_evaluate_rule``'s "no handler → False" fallback.
"""
expected = {
StartupRule,
ApplicationRule,
TimeOfDayRule,
SystemIdleRule,
DisplayStateRule,
MQTTRule,
WebhookRule,
HomeAssistantRule,
HTTPPollRule,
}
registered = set(AutomationEngine._RULE_HANDLERS.keys())
missing = expected - registered
extra = registered - expected
if missing or extra:
problems = []
if missing:
problems.append(f"missing handlers: {sorted(c.__name__ for c in missing)}")
if extra:
problems.append(f"unregistered classes: {sorted(c.__name__ for c in extra)}")
raise RuntimeError(
"AutomationEngine._RULE_HANDLERS is out of sync with imported Rule subclasses: "
+ "; ".join(problems)
)
_assert_rule_handler_coverage()
@@ -0,0 +1,80 @@
"""Tests for the AutomationEngine rule-handler dispatch registry.
Lock in the import-time invariant: every Rule subclass imported by the
engine module has a corresponding ``_handle_*`` entry in
``_RULE_HANDLERS``, and the coverage check rejects drift.
"""
from __future__ import annotations
import pytest
from ledgrab.core.automations import automation_engine
from ledgrab.core.automations.automation_engine import (
AutomationEngine,
_assert_rule_handler_coverage,
)
from ledgrab.storage.automation import (
ApplicationRule,
DisplayStateRule,
HomeAssistantRule,
HTTPPollRule,
MQTTRule,
Rule,
StartupRule,
SystemIdleRule,
TimeOfDayRule,
WebhookRule,
)
EXPECTED_RULE_TYPES = {
StartupRule,
ApplicationRule,
TimeOfDayRule,
SystemIdleRule,
DisplayStateRule,
MQTTRule,
WebhookRule,
HomeAssistantRule,
HTTPPollRule,
}
def test_every_rule_type_has_a_handler():
"""The registry exactly covers the rule-type set the engine imports."""
assert set(AutomationEngine._RULE_HANDLERS.keys()) == EXPECTED_RULE_TYPES
def test_handlers_are_engine_methods():
"""Each handler value is a method defined on AutomationEngine."""
for rule_cls, handler in AutomationEngine._RULE_HANDLERS.items():
assert callable(handler), f"handler for {rule_cls.__name__} is not callable"
# Method names start with _handle_
assert handler.__name__.startswith(
"_handle_"
), f"handler for {rule_cls.__name__} has unexpected name {handler.__name__}"
def test_coverage_assertion_raises_when_handler_is_missing(monkeypatch):
"""Removing an entry from _RULE_HANDLERS makes the import-time check fail."""
# Build a clone of the registry without one entry to simulate drift.
original = dict(AutomationEngine._RULE_HANDLERS)
pruned = {k: v for k, v in original.items() if k is not WebhookRule}
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", pruned)
with pytest.raises(RuntimeError, match="WebhookRule"):
_assert_rule_handler_coverage()
def test_coverage_assertion_raises_when_unexpected_handler_added(monkeypatch):
"""An entry keyed by an unknown class is also caught."""
class _UnknownRule(Rule): # type: ignore[misc]
pass
extended = {**AutomationEngine._RULE_HANDLERS, _UnknownRule: lambda *a: True}
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", extended)
with pytest.raises(RuntimeError, match="_UnknownRule"):
_assert_rule_handler_coverage()