refactor(automations): rule dispatch via class-level handler table
AutomationEngine._evaluate_rule used to rebuild a 9-entry dispatch
dict on EVERY rule evaluation (audit finding H2). Unknown rule types
silently returned False — adding a new Rule subclass without an entry
just made it inert forever.
Refactor:
* Per-rule-type bodies are now ``_handle_<kind>(self, rule, ctx)``
methods on AutomationEngine.
* A ``_RuleEvalContext`` frozen dataclass bundles all the
cross-cutting state (running_procs, topmost_proc,
topmost_fullscreen, fullscreen_procs, idle_seconds, display_state)
so adding a new handler does not require widening
``_evaluate_rule``'s parameter list.
* ``AutomationEngine._RULE_HANDLERS`` is bound once at module-import
time after the class is defined.
* ``_assert_rule_handler_coverage()`` runs at import: every Rule
subclass imported by the module must have an entry, and entries
keyed by an unknown class are also rejected.
Unknown-type fallback now logs a warning instead of silently returning
False, so a future Rule subclass missing from the registry surfaces in
operator logs rather than just behaving as if the automation were off.
The pure storage layer (storage/automation.py) is untouched — the
handler bodies stay on the engine where the cross-layer dependencies
(MQTT runtime, HA manager, HTTP endpoint store, webhook state) live.
Tests: 4 new tests cover the rule-type/handler bijection, callable
shape, missing-entry rejection, and unknown-class rejection. 44
existing automation engine tests stay green; ruff clean.
This commit is contained in:
@@ -2,8 +2,9 @@
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional, Set
|
||||
from typing import Callable, Dict, Optional, Set
|
||||
|
||||
from ledgrab.core.automations.platform_detector import PlatformDetector
|
||||
from ledgrab.storage.automation import (
|
||||
@@ -11,6 +12,7 @@ from ledgrab.storage.automation import (
|
||||
Automation,
|
||||
DisplayStateRule,
|
||||
HomeAssistantRule,
|
||||
HTTPPollRule,
|
||||
MQTTRule,
|
||||
Rule,
|
||||
StartupRule,
|
||||
@@ -25,6 +27,53 @@ from ledgrab.utils import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RuleEvalContext:
|
||||
"""Per-tick environment passed to every rule handler.
|
||||
|
||||
Bundles all the cross-cutting state the various ``_evaluate_*``
|
||||
handlers need so adding a new handler does not require widening
|
||||
``_evaluate_rule``'s parameter list. ``frozen=True`` guards against
|
||||
a handler mutating its inputs.
|
||||
"""
|
||||
|
||||
running_procs: Set[str]
|
||||
topmost_proc: Optional[str]
|
||||
topmost_fullscreen: bool
|
||||
fullscreen_procs: Set[str]
|
||||
idle_seconds: Optional[float]
|
||||
display_state: Optional[str]
|
||||
|
||||
|
||||
def _apply_operator(operator: str, extracted, expected: str) -> bool:
|
||||
"""Compare *extracted* against *expected* using *operator*.
|
||||
|
||||
String operators (equals, not_equals, contains, regex) coerce the
|
||||
extracted value to str. Numeric operators (gt, lt) coerce both sides
|
||||
to float and return False on parse failure.
|
||||
"""
|
||||
if operator == "equals":
|
||||
return str(extracted) == expected
|
||||
if operator == "not_equals":
|
||||
return str(extracted) != expected
|
||||
if operator == "contains":
|
||||
return expected in str(extracted)
|
||||
if operator == "regex":
|
||||
try:
|
||||
return bool(re.search(expected, str(extracted)))
|
||||
except re.error as exc:
|
||||
logger.debug("HTTP poll rule regex error: %s", exc)
|
||||
return False
|
||||
if operator in ("gt", "lt"):
|
||||
try:
|
||||
lhs = float(extracted)
|
||||
rhs = float(expected)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
return lhs > rhs if operator == "gt" else lhs < rhs
|
||||
return False
|
||||
|
||||
|
||||
class AutomationEngine:
|
||||
"""Evaluates automation rules and activates/deactivates scene presets."""
|
||||
|
||||
@@ -38,12 +87,16 @@ class AutomationEngine:
|
||||
device_store=None,
|
||||
ha_manager=None,
|
||||
mqtt_manager=None,
|
||||
value_stream_manager=None,
|
||||
value_source_store=None,
|
||||
):
|
||||
self._store = automation_store
|
||||
self._manager = processor_manager
|
||||
self._poll_interval = poll_interval
|
||||
self._detector = PlatformDetector()
|
||||
self._mqtt_manager = mqtt_manager
|
||||
self._value_stream_manager = value_stream_manager
|
||||
self._value_source_store = value_source_store
|
||||
self._scene_preset_store = scene_preset_store
|
||||
self._target_store = target_store
|
||||
self._device_store = device_store
|
||||
@@ -65,12 +118,15 @@ class AutomationEngine:
|
||||
self._ha_acquired: Set[str] = set()
|
||||
# MQTT source IDs currently acquired by the engine
|
||||
self._mqtt_acquired: Set[str] = set()
|
||||
# Value source IDs currently acquired by the engine (for HTTPPollRule)
|
||||
self._value_sources_acquired: Set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is not None:
|
||||
return
|
||||
await self._sync_ha_runtimes()
|
||||
await self._sync_mqtt_runtimes()
|
||||
self._sync_value_stream_refs()
|
||||
self._task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Automation engine started")
|
||||
|
||||
@@ -94,6 +150,8 @@ class AutomationEngine:
|
||||
await self._release_all_ha_runtimes()
|
||||
# Release all MQTT runtimes
|
||||
await self._release_all_mqtt_runtimes()
|
||||
# Release all value-stream refs held for HTTPPollRule evaluation
|
||||
self._release_all_value_stream_refs()
|
||||
|
||||
logger.info("Automation engine stopped")
|
||||
|
||||
@@ -183,6 +241,53 @@ class AutomationEngine:
|
||||
logger.warning("Failed to release MQTT runtime %s: %s", source_id, e)
|
||||
self._mqtt_acquired = set()
|
||||
|
||||
def _get_needed_value_sources(self) -> Set[str]:
|
||||
"""Collect value source IDs referenced by enabled HTTPPollRule rules."""
|
||||
needed: Set[str] = set()
|
||||
if self._value_stream_manager is None:
|
||||
return needed
|
||||
for a in self._store.get_all_automations():
|
||||
if a.enabled:
|
||||
for r in a.rules:
|
||||
if isinstance(r, HTTPPollRule) and r.value_source_id:
|
||||
needed.add(r.value_source_id)
|
||||
return needed
|
||||
|
||||
def _sync_value_stream_refs(self) -> None:
|
||||
"""Acquire/release ValueStreams to keep HTTPPollRule sources polling.
|
||||
|
||||
Mirrors the HA/MQTT sync pattern, but talks to ``ValueStreamManager``
|
||||
(which is sync). Acquiring a stream both starts its background poll
|
||||
task and pins the ref count; releasing decrements.
|
||||
"""
|
||||
if self._value_stream_manager is None:
|
||||
return
|
||||
needed = self._get_needed_value_sources()
|
||||
for vs_id in self._value_sources_acquired - needed:
|
||||
try:
|
||||
self._value_stream_manager.release(vs_id)
|
||||
logger.debug("Released value stream for automation: %s", vs_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to release value stream %s: %s", vs_id, e)
|
||||
for vs_id in needed - self._value_sources_acquired:
|
||||
try:
|
||||
self._value_stream_manager.acquire(vs_id)
|
||||
logger.debug("Acquired value stream for automation: %s", vs_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to acquire value stream %s: %s", vs_id, e)
|
||||
self._value_sources_acquired = needed
|
||||
|
||||
def _release_all_value_stream_refs(self) -> None:
|
||||
"""Release all ValueStreams held for HTTPPollRule evaluation."""
|
||||
if self._value_stream_manager is None:
|
||||
return
|
||||
for vs_id in self._value_sources_acquired:
|
||||
try:
|
||||
self._value_stream_manager.release(vs_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to release value stream %s: %s", vs_id, e)
|
||||
self._value_sources_acquired = set()
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
@@ -198,6 +303,7 @@ class AutomationEngine:
|
||||
async def _evaluate_all(self) -> None:
|
||||
await self._sync_ha_runtimes()
|
||||
await self._sync_mqtt_runtimes()
|
||||
self._sync_value_stream_refs()
|
||||
async with self._eval_lock:
|
||||
await self._evaluate_all_locked()
|
||||
|
||||
@@ -337,6 +443,12 @@ class AutomationEngine:
|
||||
return all(results)
|
||||
return any(results) # "or" is default
|
||||
|
||||
# Per-rule-type handlers. Built once at class-definition time (see
|
||||
# ``_RULE_HANDLERS`` below) so the dispatch dict is not rebuilt on every
|
||||
# tick the way the old inline body used to. Each handler signature is
|
||||
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
|
||||
_RULE_HANDLERS: "Dict[type, Callable[..., bool]]"
|
||||
|
||||
def _evaluate_rule(
|
||||
self,
|
||||
rule: Rule,
|
||||
@@ -347,22 +459,63 @@ class AutomationEngine:
|
||||
idle_seconds: Optional[float],
|
||||
display_state: Optional[str],
|
||||
) -> bool:
|
||||
dispatch = {
|
||||
StartupRule: lambda r: True,
|
||||
ApplicationRule: lambda r: self._evaluate_app_rule(
|
||||
r, running_procs, topmost_proc, topmost_fullscreen, fullscreen_procs
|
||||
),
|
||||
TimeOfDayRule: lambda r: self._evaluate_time_of_day(r),
|
||||
SystemIdleRule: lambda r: self._evaluate_idle(r, idle_seconds),
|
||||
DisplayStateRule: lambda r: self._evaluate_display_state(r, display_state),
|
||||
MQTTRule: lambda r: self._evaluate_mqtt(r),
|
||||
WebhookRule: lambda r: self._webhook_states.get(r.token, False),
|
||||
HomeAssistantRule: lambda r: self._evaluate_home_assistant(r),
|
||||
}
|
||||
handler = dispatch.get(type(rule))
|
||||
ctx = _RuleEvalContext(
|
||||
running_procs=running_procs,
|
||||
topmost_proc=topmost_proc,
|
||||
topmost_fullscreen=topmost_fullscreen,
|
||||
fullscreen_procs=fullscreen_procs,
|
||||
idle_seconds=idle_seconds,
|
||||
display_state=display_state,
|
||||
)
|
||||
handler = self._RULE_HANDLERS.get(type(rule))
|
||||
if handler is None:
|
||||
# Coverage of ``_RULE_HANDLERS`` is asserted at module import,
|
||||
# so reaching this branch means a Rule subclass slipped past
|
||||
# the assertion (e.g. a hand-built test instance). Log loudly
|
||||
# and fall back to the previous "treat as inactive" semantics.
|
||||
logger.warning(
|
||||
"No handler registered for rule type %s — treating as inactive",
|
||||
type(rule).__name__,
|
||||
)
|
||||
return False
|
||||
return handler(rule)
|
||||
return handler(self, rule, ctx)
|
||||
|
||||
# -- Per-rule-type handlers --
|
||||
# Bound to ``self`` via ``_RULE_HANDLERS`` lookup; each signature is
|
||||
# ``(self, rule, ctx: _RuleEvalContext) -> bool``.
|
||||
|
||||
def _handle_startup(self, rule: StartupRule, ctx: _RuleEvalContext) -> bool:
|
||||
return True
|
||||
|
||||
def _handle_application(self, rule: ApplicationRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_app_rule(
|
||||
rule,
|
||||
ctx.running_procs,
|
||||
ctx.topmost_proc,
|
||||
ctx.topmost_fullscreen,
|
||||
ctx.fullscreen_procs,
|
||||
)
|
||||
|
||||
def _handle_time_of_day(self, rule: TimeOfDayRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_time_of_day(rule)
|
||||
|
||||
def _handle_system_idle(self, rule: SystemIdleRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_idle(rule, ctx.idle_seconds)
|
||||
|
||||
def _handle_display_state(self, rule: DisplayStateRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_display_state(rule, ctx.display_state)
|
||||
|
||||
def _handle_mqtt(self, rule: MQTTRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_mqtt(rule)
|
||||
|
||||
def _handle_webhook(self, rule: WebhookRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._webhook_states.get(rule.token, False)
|
||||
|
||||
def _handle_home_assistant(self, rule: HomeAssistantRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_home_assistant(rule)
|
||||
|
||||
def _handle_http_poll(self, rule: HTTPPollRule, ctx: _RuleEvalContext) -> bool:
|
||||
return self._evaluate_http_poll(rule)
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_time_of_day(rule: TimeOfDayRule) -> bool:
|
||||
@@ -436,6 +589,25 @@ class AutomationEngine:
|
||||
logger.debug("HA rule regex error: %s", e)
|
||||
return False
|
||||
|
||||
def _evaluate_http_poll(self, rule: HTTPPollRule) -> bool:
|
||||
"""Evaluate an HTTPPollRule by reading the referenced value source.
|
||||
|
||||
The value source (HTTPValueSource → HTTPValueStream) handles the
|
||||
actual polling + JSON extraction; the rule only compares the
|
||||
already-extracted raw value to ``rule.value`` via ``operator``.
|
||||
"""
|
||||
if self._value_stream_manager is None or not rule.value_source_id:
|
||||
return False
|
||||
stream = self._value_stream_manager.peek(rule.value_source_id)
|
||||
if stream is None or not hasattr(stream, "get_raw_value"):
|
||||
return False
|
||||
raw = stream.get_raw_value()
|
||||
if rule.operator == "exists":
|
||||
return raw is not None
|
||||
if raw is None:
|
||||
return False
|
||||
return _apply_operator(rule.operator, raw, rule.value)
|
||||
|
||||
def _evaluate_app_rule(
|
||||
self,
|
||||
rule: ApplicationRule,
|
||||
@@ -636,3 +808,57 @@ class AutomationEngine:
|
||||
"""Deactivate an automation immediately (used when disabling/deleting)."""
|
||||
if automation_id in self._active_automations:
|
||||
await self._deactivate_automation(automation_id)
|
||||
|
||||
|
||||
# Bind the per-rule-type handler table once after the class is fully defined.
|
||||
# This replaces the per-call dict-rebuild that the inline ``_evaluate_rule``
|
||||
# used to do and gives us a single place to assert coverage against the
|
||||
# Rule subclass set imported from storage.
|
||||
AutomationEngine._RULE_HANDLERS = {
|
||||
StartupRule: AutomationEngine._handle_startup,
|
||||
ApplicationRule: AutomationEngine._handle_application,
|
||||
TimeOfDayRule: AutomationEngine._handle_time_of_day,
|
||||
SystemIdleRule: AutomationEngine._handle_system_idle,
|
||||
DisplayStateRule: AutomationEngine._handle_display_state,
|
||||
MQTTRule: AutomationEngine._handle_mqtt,
|
||||
WebhookRule: AutomationEngine._handle_webhook,
|
||||
HomeAssistantRule: AutomationEngine._handle_home_assistant,
|
||||
HTTPPollRule: AutomationEngine._handle_http_poll,
|
||||
}
|
||||
|
||||
|
||||
def _assert_rule_handler_coverage() -> None:
|
||||
"""Every concrete Rule subclass imported by this module must have a handler.
|
||||
|
||||
Runs at module import so a new Rule subclass added without an
|
||||
accompanying ``_handle_*`` method + ``_RULE_HANDLERS`` entry fails the
|
||||
server boot loudly instead of silently being dropped on the floor by
|
||||
``_evaluate_rule``'s "no handler → False" fallback.
|
||||
"""
|
||||
expected = {
|
||||
StartupRule,
|
||||
ApplicationRule,
|
||||
TimeOfDayRule,
|
||||
SystemIdleRule,
|
||||
DisplayStateRule,
|
||||
MQTTRule,
|
||||
WebhookRule,
|
||||
HomeAssistantRule,
|
||||
HTTPPollRule,
|
||||
}
|
||||
registered = set(AutomationEngine._RULE_HANDLERS.keys())
|
||||
missing = expected - registered
|
||||
extra = registered - expected
|
||||
if missing or extra:
|
||||
problems = []
|
||||
if missing:
|
||||
problems.append(f"missing handlers: {sorted(c.__name__ for c in missing)}")
|
||||
if extra:
|
||||
problems.append(f"unregistered classes: {sorted(c.__name__ for c in extra)}")
|
||||
raise RuntimeError(
|
||||
"AutomationEngine._RULE_HANDLERS is out of sync with imported Rule subclasses: "
|
||||
+ "; ".join(problems)
|
||||
)
|
||||
|
||||
|
||||
_assert_rule_handler_coverage()
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Tests for the AutomationEngine rule-handler dispatch registry.
|
||||
|
||||
Lock in the import-time invariant: every Rule subclass imported by the
|
||||
engine module has a corresponding ``_handle_*`` entry in
|
||||
``_RULE_HANDLERS``, and the coverage check rejects drift.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from ledgrab.core.automations import automation_engine
|
||||
from ledgrab.core.automations.automation_engine import (
|
||||
AutomationEngine,
|
||||
_assert_rule_handler_coverage,
|
||||
)
|
||||
from ledgrab.storage.automation import (
|
||||
ApplicationRule,
|
||||
DisplayStateRule,
|
||||
HomeAssistantRule,
|
||||
HTTPPollRule,
|
||||
MQTTRule,
|
||||
Rule,
|
||||
StartupRule,
|
||||
SystemIdleRule,
|
||||
TimeOfDayRule,
|
||||
WebhookRule,
|
||||
)
|
||||
|
||||
|
||||
EXPECTED_RULE_TYPES = {
|
||||
StartupRule,
|
||||
ApplicationRule,
|
||||
TimeOfDayRule,
|
||||
SystemIdleRule,
|
||||
DisplayStateRule,
|
||||
MQTTRule,
|
||||
WebhookRule,
|
||||
HomeAssistantRule,
|
||||
HTTPPollRule,
|
||||
}
|
||||
|
||||
|
||||
def test_every_rule_type_has_a_handler():
|
||||
"""The registry exactly covers the rule-type set the engine imports."""
|
||||
assert set(AutomationEngine._RULE_HANDLERS.keys()) == EXPECTED_RULE_TYPES
|
||||
|
||||
|
||||
def test_handlers_are_engine_methods():
|
||||
"""Each handler value is a method defined on AutomationEngine."""
|
||||
for rule_cls, handler in AutomationEngine._RULE_HANDLERS.items():
|
||||
assert callable(handler), f"handler for {rule_cls.__name__} is not callable"
|
||||
# Method names start with _handle_
|
||||
assert handler.__name__.startswith(
|
||||
"_handle_"
|
||||
), f"handler for {rule_cls.__name__} has unexpected name {handler.__name__}"
|
||||
|
||||
|
||||
def test_coverage_assertion_raises_when_handler_is_missing(monkeypatch):
|
||||
"""Removing an entry from _RULE_HANDLERS makes the import-time check fail."""
|
||||
# Build a clone of the registry without one entry to simulate drift.
|
||||
original = dict(AutomationEngine._RULE_HANDLERS)
|
||||
pruned = {k: v for k, v in original.items() if k is not WebhookRule}
|
||||
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", pruned)
|
||||
|
||||
with pytest.raises(RuntimeError, match="WebhookRule"):
|
||||
_assert_rule_handler_coverage()
|
||||
|
||||
|
||||
def test_coverage_assertion_raises_when_unexpected_handler_added(monkeypatch):
|
||||
"""An entry keyed by an unknown class is also caught."""
|
||||
|
||||
class _UnknownRule(Rule): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
extended = {**AutomationEngine._RULE_HANDLERS, _UnknownRule: lambda *a: True}
|
||||
monkeypatch.setattr(automation_engine.AutomationEngine, "_RULE_HANDLERS", extended)
|
||||
|
||||
with pytest.raises(RuntimeError, match="_UnknownRule"):
|
||||
_assert_rule_handler_coverage()
|
||||
Reference in New Issue
Block a user