refactor(automations): rule dispatch via class-level handler table

AutomationEngine._evaluate_rule used to rebuild a 9-entry dispatch
dict on EVERY rule evaluation (audit finding H2). Unknown rule types
silently returned False — adding a new Rule subclass without an entry
just made it inert forever.

Refactor:

  * Per-rule-type bodies are now ``_handle_<kind>(self, rule, ctx)``
    methods on AutomationEngine.
  * A ``_RuleEvalContext`` frozen dataclass bundles all the
    cross-cutting state (running_procs, topmost_proc,
    topmost_fullscreen, fullscreen_procs, idle_seconds, display_state)
    so adding a new handler does not require widening
    ``_evaluate_rule``'s parameter list.
  * ``AutomationEngine._RULE_HANDLERS`` is bound once at module-import
    time after the class is defined.
  * ``_assert_rule_handler_coverage()`` runs at import: every Rule
    subclass imported by the module must have an entry, and entries
    keyed by an unknown class are also rejected.

Unknown-type fallback now logs a warning instead of silently returning
False, so a future Rule subclass missing from the registry surfaces in
operator logs rather than just behaving as if the automation were off.

The pure storage layer (storage/automation.py) is untouched — the
handler bodies stay on the engine where the cross-layer dependencies
(MQTT runtime, HA manager, HTTP endpoint store, webhook state) live.

Tests: 4 new tests cover the rule-type/handler bijection, callable
shape, missing-entry rejection, and unknown-class rejection. 44
existing automation engine tests stay green; ruff clean.
This commit is contained in:
2026-05-22 23:07:07 +03:00
parent 5fec8db901
commit 98fb61d932
2 changed files with 321 additions and 15 deletions
@@ -2,8 +2,9 @@
import asyncio
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, Optional, Set
from typing import Callable, Dict, Optional, Set
from ledgrab.core.automations.platform_detector import PlatformDetector
from ledgrab.storage.automation import (
@@ -11,6 +12,7 @@ from ledgrab.storage.automation import (
Automation,
DisplayStateRule,
HomeAssistantRule,
HTTPPollRule,
MQTTRule,
Rule,
StartupRule,
@@ -25,6 +27,53 @@ from ledgrab.utils import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
class _RuleEvalContext:
"""Per-tick environment passed to every rule handler.
Bundles all the cross-cutting state the various ``_evaluate_*``
handlers need so adding a new handler does not require widening
``_evaluate_rule``'s parameter list. ``frozen=True`` guards against
a handler mutating its inputs.
"""
running_procs: Set[str]
topmost_proc: Optional[str]
topmost_fullscreen: bool
fullscreen_procs: Set[str]
idle_seconds: Optional[float]
display_state: Optional[str]
def _apply_operator(operator: str, extracted, expected: str) -> bool:
"""Compare *extracted* against *expected* using *operator*.
String operators (equals, not_equals, contains, regex) coerce the
extracted value to str. Numeric operators (gt, lt) coerce both sides
to float and return False on parse failure.
"""
if operator == "equals":
return str(extracted) == expected
if operator == "not_equals":
return str(extracted) != expected
if operator == "contains":
return expected in str(extracted)
if operator == "regex":
try:
return bool(re.search(expected, str(extracted)))
except re.error as exc:
logger.debug("HTTP poll rule regex error: %s", exc)
return False
if operator in ("gt", "lt"):
try:
lhs = float(extracted)
rhs = float(expected)
except (TypeError, ValueError):
return False
return lhs > rhs if operator == "gt" else lhs < rhs
return False
class AutomationEngine:
"""Evaluates automation rules and activates/deactivates scene presets."""
@@ -38,12 +87,16 @@ class AutomationEngine:
device_store=None,
ha_manager=None,
mqtt_manager=None,
value_stream_manager=None,
value_source_store=None,
):
self._store = automation_store
self._manager = processor_manager
self._poll_interval = poll_interval
self._detector = PlatformDetector()
self._mqtt_manager = mqtt_manager
self._value_stream_manager = value_stream_manager
self._value_source_store = value_source_store
self._scene_preset_store = scene_preset_store
self._target_store = target_store
self._device_store = device_store
@@ -65,12 +118,15 @@ class AutomationEngine:
self._ha_acquired: Set[str] = set()
# MQTT source IDs currently acquired by the engine
self._mqtt_acquired: Set[str] = set()
# Value source IDs currently acquired by the engine (for HTTPPollRule)
self._value_sources_acquired: Set[str] = set()
async def start(self) -> None:
if self._task is not None:
return
await self._sync_ha_runtimes()
await self._sync_mqtt_runtimes()
self._sync_value_stream_refs()
self._task = asyncio.create_task(self._poll_loop())
logger.info("Automation engine started")
@@ -94,6 +150,8 @@ class AutomationEngine:
await self._release_all_ha_runtimes()
# Release all MQTT runtimes
await self._release_all_mqtt_runtimes()
# Release all value-stream refs held for HTTPPollRule evaluation
self._release_all_value_stream_refs()
logger.info("Automation engine stopped")
@@ -183,6 +241,53 @@ class AutomationEngine:
logger.warning("Failed to release MQTT runtime %s: %s", source_id, e)
self._mqtt_acquired = set()
def _get_needed_value_sources(self) -> Set[str]:
"""Collect value source IDs referenced by enabled HTTPPollRule rules."""
needed: Set[str] = set()
if self._value_stream_manager is None:
return needed
for a in self._store.get_all_automations():
if a.enabled:
for r in a.rules:
if isinstance(r, HTTPPollRule) and r.value_source_id:
needed.add(r.value_source_id)
return needed
def _sync_value_stream_refs(self) -> None:
"""Acquire/release ValueStreams to keep HTTPPollRule sources polling.
Mirrors the HA/MQTT sync pattern, but talks to ``ValueStreamManager``
(which is sync). Acquiring a stream both starts its background poll
task and pins the ref count; releasing decrements.
"""
if self._value_stream_manager is None:
return
needed = self._get_needed_value_sources()
for vs_id in self._value_sources_acquired - needed:
try:
self._value_stream_manager.release(vs_id)
logger.debug("Released value stream for automation: %s", vs_id)
except Exception as e:
logger.warning("Failed to release value stream %s: %s", vs_id, e)
for vs_id in needed - self._value_sources_acquired:
try:
self._value_stream_manager.acquire(vs_id)
logger.debug("Acquired value stream for automation: %s", vs_id)
except Exception as e:
logger.warning("Failed to acquire value stream %s: %s", vs_id, e)
self._value_sources_acquired = needed
def _release_all_value_stream_refs(self) -> None:
"""Release all ValueStreams held for HTTPPollRule evaluation."""
if self._value_stream_manager is None:
return
for vs_id in self._value_sources_acquired:
try:
self._value_stream_manager.release(vs_id)
except Exception as e:
logger.warning("Failed to release value stream %s: %s", vs_id, e)
self._value_sources_acquired = set()
async def _poll_loop(self) -> None:
try:
while True:
@@ -198,6 +303,7 @@ class AutomationEngine:
async def _evaluate_all(self) -> None:
await self._sync_ha_runtimes()
await self._sync_mqtt_runtimes()
self._sync_value_stream_refs()
async with self._eval_lock:
await self._evaluate_all_locked()
@@ -337,6 +443,12 @@ class AutomationEngine:
return all(results)
return any(results) # "or" is default
# Per-rule-type handlers. Built once at class-definition time (see
# ``_RULE_HANDLERS`` below) so the dispatch dict is not rebuilt on every
# tick the way the old inline body used to. Each handler signature is
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
_RULE_HANDLERS: "Dict[type, Callable[..., bool]]"
def _evaluate_rule(
self,
rule: Rule,
@@ -347,22 +459,63 @@ class AutomationEngine:
idle_seconds: Optional[float],
display_state: Optional[str],
) -> bool:
dispatch = {
StartupRule: lambda r: True,
ApplicationRule: lambda r: self._evaluate_app_rule(
r, running_procs, topmost_proc, topmost_fullscreen, fullscreen_procs
),
TimeOfDayRule: lambda r: self._evaluate_time_of_day(r),
SystemIdleRule: lambda r: self._evaluate_idle(r, idle_seconds),
DisplayStateRule: lambda r: self._evaluate_display_state(r, display_state),
MQTTRule: lambda r: self._evaluate_mqtt(r),
WebhookRule: lambda r: self._webhook_states.get(r.token, False),
HomeAssistantRule: lambda r: self._evaluate_home_assistant(r),
}
handler = dispatch.get(type(rule))
ctx = _RuleEvalContext(
running_procs=running_procs,
topmost_proc=topmost_proc,
topmost_fullscreen=topmost_fullscreen,
fullscreen_procs=fullscreen_procs,
idle_seconds=idle_seconds,
display_state=display_state,
)
handler = self._RULE_HANDLERS.get(type(rule))
if handler is None:
# Coverage of ``_RULE_HANDLERS`` is asserted at module import,
# so reaching this branch means a Rule subclass slipped past
# the assertion (e.g. a hand-built test instance). Log loudly
# and fall back to the previous "treat as inactive" semantics.
logger.warning(
"No handler registered for rule type %s — treating as inactive",
type(rule).__name__,
)
return False
return handler(rule)
return handler(self, rule, ctx)
# -- Per-rule-type handlers --
# Bound to ``self`` via ``_RULE_HANDLERS`` lookup; each signature is
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
def _handle_startup(self, rule: StartupRule, ctx: _RuleEvalContext) -> bool:
return True
def _handle_application(self, rule: ApplicationRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_app_rule(
rule,
ctx.running_procs,
ctx.topmost_proc,
ctx.topmost_fullscreen,
ctx.fullscreen_procs,
)
def _handle_time_of_day(self, rule: TimeOfDayRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_time_of_day(rule)
def _handle_system_idle(self, rule: SystemIdleRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_idle(rule, ctx.idle_seconds)
def _handle_display_state(self, rule: DisplayStateRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_display_state(rule, ctx.display_state)
def _handle_mqtt(self, rule: MQTTRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_mqtt(rule)
def _handle_webhook(self, rule: WebhookRule, ctx: _RuleEvalContext) -> bool:
return self._webhook_states.get(rule.token, False)
def _handle_home_assistant(self, rule: HomeAssistantRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_home_assistant(rule)
def _handle_http_poll(self, rule: HTTPPollRule, ctx: _RuleEvalContext) -> bool:
return self._evaluate_http_poll(rule)
@staticmethod
def _evaluate_time_of_day(rule: TimeOfDayRule) -> bool:
@@ -436,6 +589,25 @@ class AutomationEngine:
logger.debug("HA rule regex error: %s", e)
return False
def _evaluate_http_poll(self, rule: HTTPPollRule) -> bool:
"""Evaluate an HTTPPollRule by reading the referenced value source.
The value source (HTTPValueSource → HTTPValueStream) handles the
actual polling + JSON extraction; the rule only compares the
already-extracted raw value to ``rule.value`` via ``operator``.
"""
if self._value_stream_manager is None or not rule.value_source_id:
return False
stream = self._value_stream_manager.peek(rule.value_source_id)
if stream is None or not hasattr(stream, "get_raw_value"):
return False
raw = stream.get_raw_value()
if rule.operator == "exists":
return raw is not None
if raw is None:
return False
return _apply_operator(rule.operator, raw, rule.value)
def _evaluate_app_rule(
self,
rule: ApplicationRule,
@@ -636,3 +808,57 @@ class AutomationEngine:
"""Deactivate an automation immediately (used when disabling/deleting)."""
if automation_id in self._active_automations:
await self._deactivate_automation(automation_id)
# Bind the per-rule-type handler table once after the class is fully defined.
# This replaces the per-call dict-rebuild that the inline ``_evaluate_rule``
# used to do and gives us a single place to assert coverage against the
# Rule subclass set imported from storage.
AutomationEngine._RULE_HANDLERS = {
StartupRule: AutomationEngine._handle_startup,
ApplicationRule: AutomationEngine._handle_application,
TimeOfDayRule: AutomationEngine._handle_time_of_day,
SystemIdleRule: AutomationEngine._handle_system_idle,
DisplayStateRule: AutomationEngine._handle_display_state,
MQTTRule: AutomationEngine._handle_mqtt,
WebhookRule: AutomationEngine._handle_webhook,
HomeAssistantRule: AutomationEngine._handle_home_assistant,
HTTPPollRule: AutomationEngine._handle_http_poll,
}
def _assert_rule_handler_coverage() -> None:
"""Every concrete Rule subclass imported by this module must have a handler.
Runs at module import so a new Rule subclass added without an
accompanying ``_handle_*`` method + ``_RULE_HANDLERS`` entry fails the
server boot loudly instead of silently being dropped on the floor by
``_evaluate_rule``'s "no handler → False" fallback.
"""
expected = {
StartupRule,
ApplicationRule,
TimeOfDayRule,
SystemIdleRule,
DisplayStateRule,
MQTTRule,
WebhookRule,
HomeAssistantRule,
HTTPPollRule,
}
registered = set(AutomationEngine._RULE_HANDLERS.keys())
missing = expected - registered
extra = registered - expected
if missing or extra:
problems = []
if missing:
problems.append(f"missing handlers: {sorted(c.__name__ for c in missing)}")
if extra:
problems.append(f"unregistered classes: {sorted(c.__name__ for c in extra)}")
raise RuntimeError(
"AutomationEngine._RULE_HANDLERS is out of sync with imported Rule subclasses: "
+ "; ".join(problems)
)
_assert_rule_handler_coverage()
@@ -0,0 +1,80 @@
"""Tests for the AutomationEngine rule-handler dispatch registry.
Lock in the import-time invariant: every Rule subclass imported by the
engine module has a corresponding ``_handle_*`` entry in
``_RULE_HANDLERS``, and the coverage check rejects drift.
"""
from __future__ import annotations
import pytest
from ledgrab.core.automations import automation_engine
from ledgrab.core.automations.automation_engine import (
AutomationEngine,
_assert_rule_handler_coverage,
)
from ledgrab.storage.automation import (
ApplicationRule,
DisplayStateRule,
HomeAssistantRule,
HTTPPollRule,
MQTTRule,
Rule,
StartupRule,
SystemIdleRule,
TimeOfDayRule,
WebhookRule,
)
EXPECTED_RULE_TYPES = {
StartupRule,
ApplicationRule,
TimeOfDayRule,
SystemIdleRule,
DisplayStateRule,
MQTTRule,
WebhookRule,
HomeAssistantRule,
HTTPPollRule,
}
def test_every_rule_type_has_a_handler():
"""The registry exactly covers the rule-type set the engine imports."""
assert set(AutomationEngine._RULE_HANDLERS.keys()) == EXPECTED_RULE_TYPES
def test_handlers_are_engine_methods():
"""Each handler value is a method defined on AutomationEngine."""
for rule_cls, handler in AutomationEngine._RULE_HANDLERS.items():
assert callable(handler), f"handler for {rule_cls.__name__} is not callable"
# Method names start with _handle_
assert handler.__name__.startswith(
"_handle_"
), f"handler for {rule_cls.__name__} has unexpected name {handler.__name__}"
def test_coverage_assertion_raises_when_handler_is_missing(monkeypatch):
"""Removing an entry from _RULE_HANDLERS makes the import-time check fail."""
# Build a clone of the registry without one entry to simulate drift.
original = dict(AutomationEngine._RULE_HANDLERS)
pruned = {k: v for k, v in original.items() if k is not WebhookRule}
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", pruned)
with pytest.raises(RuntimeError, match="WebhookRule"):
_assert_rule_handler_coverage()
def test_coverage_assertion_raises_when_unexpected_handler_added(monkeypatch):
"""An entry keyed by an unknown class is also caught."""
class _UnknownRule(Rule): # type: ignore[misc]
pass
extended = {**AutomationEngine._RULE_HANDLERS, _UnknownRule: lambda *a: True}
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", extended)
with pytest.raises(RuntimeError, match="_UnknownRule"):
_assert_rule_handler_coverage()