feat(value-sources): add sandboxed-Jinja template combinator
A new `template` value source evaluates a hardened, sandboxed Jinja expression over the live values of other value sources — the system's first float combinator. Backend: - Shared engine (utils/template_expr.py): ImmutableSandboxedEnvironment with filters/tests and auto-injected globals stripped; only min/max/abs/round/ clamp exposed; rejects **, string/collection-literal repetition, attribute access and non-global calls; NaN/inf-safe result coercion. - TemplateValueSource model + TemplateValueStream runtime: compile-once, primitives-only eval context, raw[name] exposure, eval_interval throttle, ref-counted input acquire/release, rename-safe hot-update. - Validation: unbound-variable + reserved-name rejection, reference cycle/depth guards (depth-only at create, full cycle at update), runtime acquire() depth backstop, and delete referential-integrity. - API: Create/Update/Response schemas + discriminated unions, _RESPONSE_MAP, and an advisory POST /value-sources/validate-template endpoint. - Demo seed: a static source plus a template combinator example. Frontend: - Editor modal section: repeatable inputs list (EntitySelect rows), a zero-dependency Jinja syntax highlighter, a hints/reference panel, and a debounced live validator that gates Save (stale-response-safe). - Graph editor: read-only template node with one edge per input. - i18n (en/ru/zh), icon, and card rendering. Tests: engine, stream, factory/cycle, validate endpoint, and demo seed.
This commit is contained in:
@@ -0,0 +1,182 @@
|
||||
"""Tests for template value source API: CRUD, validate-template, delete-protection."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ledgrab.api import dependencies as deps
|
||||
from ledgrab.api.routes.value_sources import router
|
||||
from ledgrab.storage.value_source_store import ValueSourceStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _route_db(tmp_path):
|
||||
from ledgrab.storage.database import Database
|
||||
|
||||
db = Database(tmp_path / "test.db")
|
||||
yield db
|
||||
db.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(_route_db):
|
||||
return ValueSourceStore(_route_db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(store):
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
from ledgrab.api.auth import verify_api_key
|
||||
|
||||
app.dependency_overrides[verify_api_key] = lambda: "test-user"
|
||||
app.dependency_overrides[deps.get_value_source_store] = lambda: store
|
||||
app.dependency_overrides[deps.get_processor_manager] = lambda: MagicMock()
|
||||
app.dependency_overrides[deps.get_output_target_store] = lambda: MagicMock(
|
||||
get_all_targets=lambda: []
|
||||
)
|
||||
|
||||
deps._deps["processor_manager"] = MagicMock()
|
||||
return TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _create(client, **over):
|
||||
body = {
|
||||
"source_type": "template",
|
||||
"name": "Combo",
|
||||
"template": "min(a * 2, 1)",
|
||||
"inputs": [{"name": "a", "value_source_id": ""}],
|
||||
"default_value": 0.2,
|
||||
}
|
||||
body.update(over)
|
||||
return client.post("/api/v1/value-sources", json=body)
|
||||
|
||||
|
||||
class TestCRUD:
|
||||
def test_create_get_list_roundtrip(self, client):
|
||||
r = _create(client)
|
||||
assert r.status_code == 201, r.text
|
||||
body = r.json()
|
||||
assert body["source_type"] == "template"
|
||||
assert body["return_type"] == "float"
|
||||
assert body["template"] == "min(a * 2, 1)"
|
||||
assert body["inputs"] == [{"name": "a", "value_source_id": ""}]
|
||||
assert body["default_value"] == 0.2
|
||||
sid = body["id"]
|
||||
|
||||
got = client.get(f"/api/v1/value-sources/{sid}").json()
|
||||
assert got["template"] == "min(a * 2, 1)"
|
||||
|
||||
lst = client.get("/api/v1/value-sources").json()
|
||||
assert any(s["id"] == sid and s["source_type"] == "template" for s in lst["sources"])
|
||||
|
||||
def test_update(self, client):
|
||||
sid = _create(client).json()["id"]
|
||||
r = client.put(
|
||||
f"/api/v1/value-sources/{sid}",
|
||||
json={"source_type": "template", "template": "clamp(a * 3)"},
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
assert r.json()["template"] == "clamp(a * 3)"
|
||||
|
||||
def test_create_compile_error_returns_400(self, client):
|
||||
r = _create(client, template="a +")
|
||||
assert r.status_code == 400
|
||||
|
||||
def test_create_reserved_name_returns_400(self, client):
|
||||
r = _create(client, inputs=[{"name": "min", "value_source_id": ""}])
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
class TestDeleteProtection:
|
||||
def test_delete_blocked_when_referenced(self, client):
|
||||
base = client.post(
|
||||
"/api/v1/value-sources",
|
||||
json={"source_type": "static", "name": "Base", "value": 0.5},
|
||||
).json()
|
||||
_create(
|
||||
client,
|
||||
name="Uses",
|
||||
template="b",
|
||||
inputs=[{"name": "b", "value_source_id": base["id"]}],
|
||||
)
|
||||
r = client.delete(f"/api/v1/value-sources/{base['id']}")
|
||||
assert r.status_code == 400
|
||||
assert "referenced by" in r.json()["detail"]
|
||||
|
||||
|
||||
class TestValidateEndpoint:
|
||||
def _validate(self, client, **body):
|
||||
return client.post("/api/v1/value-sources/validate-template", json=body)
|
||||
|
||||
def test_valid_expression(self, client):
|
||||
r = self._validate(
|
||||
client,
|
||||
template="min(a, b)",
|
||||
inputs=[{"name": "a", "value_source_id": ""}, {"name": "b", "value_source_id": ""}],
|
||||
)
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["valid"] is True
|
||||
assert set(data["variables"]) == {"a", "b"}
|
||||
|
||||
def test_compile_error(self, client):
|
||||
r = self._validate(client, template="a +", inputs=[])
|
||||
data = r.json()
|
||||
assert data["valid"] is False
|
||||
assert data["error"]
|
||||
|
||||
def test_reserved_name(self, client):
|
||||
r = self._validate(
|
||||
client, template="min(0,1)", inputs=[{"name": "raw", "value_source_id": ""}]
|
||||
)
|
||||
assert r.json()["valid"] is False
|
||||
|
||||
def test_missing_input_is_warning_not_error(self, client):
|
||||
r = self._validate(
|
||||
client, template="a", inputs=[{"name": "a", "value_source_id": "vs_nope"}]
|
||||
)
|
||||
data = r.json()
|
||||
assert data["valid"] is True
|
||||
assert data["warnings"]
|
||||
|
||||
def test_unbound_variable_is_error(self, client):
|
||||
# Typo: expression uses 'ha_enti' but the input is named 'ha_entity'.
|
||||
r = self._validate(
|
||||
client, template="ha_enti", inputs=[{"name": "ha_entity", "value_source_id": ""}]
|
||||
)
|
||||
data = r.json()
|
||||
assert data["valid"] is False
|
||||
assert any("unbound" in e for e in data["errors"])
|
||||
|
||||
def test_cycle_detected_with_id(self, client):
|
||||
t1 = _create(client, name="T1", template="clamp(0.5)", inputs=[]).json()
|
||||
t2 = _create(
|
||||
client,
|
||||
name="T2",
|
||||
template="x",
|
||||
inputs=[{"name": "x", "value_source_id": t1["id"]}],
|
||||
).json()
|
||||
# Editing t1 to point at t2 would close a cycle.
|
||||
r = self._validate(
|
||||
client, template="x", inputs=[{"name": "x", "value_source_id": t2["id"]}], id=t1["id"]
|
||||
)
|
||||
assert r.json()["valid"] is False
|
||||
|
||||
|
||||
class TestResponseMapCoverage:
|
||||
def test_template_in_response_map(self):
|
||||
from ledgrab.api.routes.value_sources import _RESPONSE_MAP
|
||||
from ledgrab.storage.value_source import TemplateValueSource
|
||||
|
||||
assert TemplateValueSource in _RESPONSE_MAP
|
||||
|
||||
def test_template_in_all_unions(self):
|
||||
from ledgrab.api.schemas import value_sources as sch
|
||||
|
||||
for union_name in ("ValueSourceResponse", "ValueSourceCreate", "ValueSourceUpdate"):
|
||||
src = repr(getattr(sch, union_name))
|
||||
assert "template" in src.lower() or "Template" in src
|
||||
@@ -0,0 +1,231 @@
|
||||
"""Tests for TemplateValueStream (the Jinja combinator runtime)."""
|
||||
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from ledgrab.core.processing.value_stream import (
|
||||
TemplateValueStream,
|
||||
ValueStreamManager,
|
||||
)
|
||||
from ledgrab.storage.value_source import TemplateValueSource
|
||||
|
||||
|
||||
# --- Fakes for precise control over input values / raw -----------------------
|
||||
|
||||
|
||||
class _FakeStream:
|
||||
_NO_RAW = object()
|
||||
|
||||
def __init__(self, value, raw=_NO_RAW):
|
||||
self._value = value
|
||||
self._raw = raw
|
||||
|
||||
def get_value(self):
|
||||
return self._value
|
||||
|
||||
# get_raw_value only exists when a raw value was provided
|
||||
def __getattr__(self, name):
|
||||
if name == "get_raw_value" and self._raw is not _FakeStream._NO_RAW:
|
||||
return lambda: self._raw
|
||||
raise AttributeError(name)
|
||||
|
||||
|
||||
class _FakeVSM:
|
||||
def __init__(self, streams):
|
||||
self._streams = streams # id -> _FakeStream
|
||||
self.refcounts = defaultdict(int)
|
||||
|
||||
def acquire(self, vs_id):
|
||||
self.refcounts[vs_id] += 1
|
||||
return self._streams[vs_id]
|
||||
|
||||
def release(self, vs_id):
|
||||
self.refcounts[vs_id] -= 1
|
||||
|
||||
|
||||
def _inputs(*pairs):
|
||||
return [{"name": n, "value_source_id": i} for n, i in pairs]
|
||||
|
||||
|
||||
def _make(template, inputs, streams, default_value=0.0, eval_interval=None):
|
||||
vsm = _FakeVSM(streams)
|
||||
stream = TemplateValueStream(
|
||||
template=template,
|
||||
inputs=inputs,
|
||||
default_value=default_value,
|
||||
eval_interval=eval_interval,
|
||||
value_stream_manager=vsm,
|
||||
)
|
||||
stream.start()
|
||||
return stream, vsm
|
||||
|
||||
|
||||
class TestEvaluation:
|
||||
def test_eval_with_inputs(self):
|
||||
stream, vsm = _make("min(a * 2, 1)", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.3)})
|
||||
assert vsm.refcounts["vs_a"] == 1
|
||||
assert stream.get_value() == pytest.approx(0.6)
|
||||
|
||||
def test_clamps_out_of_range(self):
|
||||
stream, _ = _make("a * 10", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.5)})
|
||||
assert stream.get_value() == 1.0 # 5.0 clamped
|
||||
|
||||
def test_two_inputs(self):
|
||||
stream, _ = _make(
|
||||
"(a + b) / 2",
|
||||
_inputs(("a", "vs_a"), ("b", "vs_b")),
|
||||
{"vs_a": _FakeStream(0.2), "vs_b": _FakeStream(0.8)},
|
||||
)
|
||||
assert stream.get_value() == pytest.approx(0.5)
|
||||
|
||||
def test_shared_id_single_ref(self):
|
||||
# Two variables bound to the same source share one acquisition.
|
||||
stream, vsm = _make(
|
||||
"min(a + b, 1)",
|
||||
_inputs(("a", "vs_x"), ("b", "vs_x")),
|
||||
{"vs_x": _FakeStream(0.3)},
|
||||
)
|
||||
assert vsm.refcounts["vs_x"] == 1
|
||||
assert stream.get_value() == pytest.approx(0.6)
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
def test_div_by_zero_returns_default(self):
|
||||
stream, _ = _make(
|
||||
"a / 0", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.5)}, default_value=0.25
|
||||
)
|
||||
assert stream.get_value() == 0.25
|
||||
|
||||
def test_missing_variable_returns_default(self):
|
||||
# template references 'b' but only 'a' is bound
|
||||
stream, _ = _make(
|
||||
"a + b", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.5)}, default_value=0.1
|
||||
)
|
||||
assert stream.get_value() == 0.1
|
||||
|
||||
def test_nan_returns_default(self):
|
||||
stream, _ = _make(
|
||||
"a - a", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(float("inf"))}, default_value=0.3
|
||||
)
|
||||
# inf - inf = nan -> default
|
||||
assert stream.get_value() == 0.3
|
||||
|
||||
def test_invalid_template_uses_default(self):
|
||||
stream, _ = _make(
|
||||
"a +", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.5)}, default_value=0.42
|
||||
)
|
||||
assert stream.get_value() == 0.42
|
||||
|
||||
|
||||
class TestRawExposure:
|
||||
def test_raw_present_when_stream_exposes_it(self):
|
||||
stream, _ = _make(
|
||||
"raw['t'] / 100",
|
||||
_inputs(("t", "vs_t")),
|
||||
{"vs_t": _FakeStream(0.5, raw=42.0)},
|
||||
)
|
||||
assert stream.get_value() == pytest.approx(0.42)
|
||||
|
||||
def test_raw_absent_without_getter(self):
|
||||
# input stream has no get_raw_value -> raw['t'] -> None -> error -> default
|
||||
stream, _ = _make(
|
||||
"raw['t'] / 100",
|
||||
_inputs(("t", "vs_t")),
|
||||
{"vs_t": _FakeStream(0.5)},
|
||||
default_value=0.2,
|
||||
)
|
||||
assert stream.get_value() == 0.2
|
||||
|
||||
def test_non_numeric_raw_is_dropped(self):
|
||||
# raw value is a string -> never crosses into sandbox -> raw['t'] absent
|
||||
stream, _ = _make(
|
||||
"raw['t'] / 100",
|
||||
_inputs(("t", "vs_t")),
|
||||
{"vs_t": _FakeStream(0.5, raw="playing")},
|
||||
default_value=0.15,
|
||||
)
|
||||
assert stream.get_value() == 0.15
|
||||
|
||||
|
||||
class TestLifecycle:
|
||||
def test_stop_releases_all(self):
|
||||
stream, vsm = _make(
|
||||
"min(a + b, 1)",
|
||||
_inputs(("a", "vs_a"), ("b", "vs_b")),
|
||||
{"vs_a": _FakeStream(0.1), "vs_b": _FakeStream(0.2)},
|
||||
)
|
||||
stream.stop()
|
||||
assert vsm.refcounts["vs_a"] == 0
|
||||
assert vsm.refcounts["vs_b"] == 0
|
||||
|
||||
def test_eval_interval_caches(self):
|
||||
backing = _FakeStream(0.2)
|
||||
stream, _ = _make("a", _inputs(("a", "vs_a")), {"vs_a": backing}, eval_interval=3600.0)
|
||||
first = stream.get_value()
|
||||
backing._value = 0.9 # change the live input
|
||||
# Cached within the interval -> still the first value.
|
||||
assert stream.get_value() == pytest.approx(first)
|
||||
|
||||
|
||||
class TestHotUpdate:
|
||||
def test_swap_input_releases_old_acquires_new(self):
|
||||
stream, vsm = _make(
|
||||
"a", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.1), "vs_b": _FakeStream(0.9)}
|
||||
)
|
||||
assert vsm.refcounts["vs_a"] == 1
|
||||
new_src = TemplateValueSource(
|
||||
id="t1",
|
||||
name="t",
|
||||
source_type="template",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
template="a",
|
||||
inputs=_inputs(("a", "vs_b")),
|
||||
default_value=0.0,
|
||||
)
|
||||
stream.update_source(new_src)
|
||||
assert vsm.refcounts["vs_a"] == 0 # old released
|
||||
assert vsm.refcounts["vs_b"] == 1 # new acquired
|
||||
assert stream.get_value() == pytest.approx(0.9)
|
||||
|
||||
def test_rename_keeps_same_source(self):
|
||||
stream, vsm = _make("a", _inputs(("a", "vs_a")), {"vs_a": _FakeStream(0.7)})
|
||||
renamed = TemplateValueSource(
|
||||
id="t1",
|
||||
name="t",
|
||||
source_type="template",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
template="b", # variable renamed a -> b, same source id
|
||||
inputs=_inputs(("b", "vs_a")),
|
||||
default_value=0.0,
|
||||
)
|
||||
stream.update_source(renamed)
|
||||
assert vsm.refcounts["vs_a"] == 1 # not re-acquired (unchanged id)
|
||||
assert stream.get_value() == pytest.approx(0.7)
|
||||
|
||||
|
||||
class TestAcquireDepthBackstop:
|
||||
def test_self_reference_does_not_overflow(self):
|
||||
"""A cycle that bypassed storage validation must not stack-overflow."""
|
||||
now = datetime.now(timezone.utc)
|
||||
src = TemplateValueSource(
|
||||
id="vs_cycle",
|
||||
name="cycle",
|
||||
source_type="template",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
template="x",
|
||||
inputs=_inputs(("x", "vs_cycle")),
|
||||
default_value=0.0,
|
||||
)
|
||||
|
||||
class _CycleStore:
|
||||
def get_source(self, vs_id):
|
||||
return src
|
||||
|
||||
manager = ValueStreamManager(value_source_store=_CycleStore())
|
||||
stream = manager.acquire("vs_cycle") # must terminate, not recurse forever
|
||||
assert isinstance(stream.get_value(), float)
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Demo-seed regression tests (value sources, incl. the template combinator)."""
|
||||
|
||||
from ledgrab.core.demo_seed import seed_demo_data
|
||||
from ledgrab.storage.database import Database
|
||||
from ledgrab.storage.value_source import StaticValueSource, TemplateValueSource
|
||||
from ledgrab.storage.value_source_store import ValueSourceStore
|
||||
|
||||
|
||||
def _seed(tmp_path):
|
||||
db = Database(tmp_path / "demo.db")
|
||||
seed_demo_data(db)
|
||||
return db
|
||||
|
||||
|
||||
def test_demo_seeds_template_value_source(tmp_path):
|
||||
db = _seed(tmp_path)
|
||||
try:
|
||||
store = ValueSourceStore(db)
|
||||
by_id = {s.id: s for s in store.get_all_sources()}
|
||||
|
||||
base = by_id["vs_demo0001"]
|
||||
boost = by_id["vs_demo0002"]
|
||||
assert isinstance(base, StaticValueSource)
|
||||
assert isinstance(boost, TemplateValueSource)
|
||||
assert boost.template == "clamp(level * 1.5)"
|
||||
assert boost.inputs == [{"name": "level", "value_source_id": "vs_demo0001"}]
|
||||
|
||||
# The reference graph is intact and consistent.
|
||||
assert store.get_transitive_dependencies("vs_demo0002") == {"vs_demo0001"}
|
||||
assert store.find_referencing_sources("vs_demo0001") == [boost.name]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_demo_template_evaluates_through_manager(tmp_path):
|
||||
"""The seeded template must actually evaluate over its seeded input."""
|
||||
from ledgrab.core.processing.value_stream import ValueStreamManager
|
||||
|
||||
db = _seed(tmp_path)
|
||||
try:
|
||||
store = ValueSourceStore(db)
|
||||
vsm = ValueStreamManager(value_source_store=store)
|
||||
stream = vsm.acquire("vs_demo0002")
|
||||
try:
|
||||
# base level 0.5 -> clamp(0.5 * 1.5) = 0.75
|
||||
assert abs(stream.get_value() - 0.75) < 1e-6
|
||||
finally:
|
||||
vsm.release("vs_demo0002")
|
||||
finally:
|
||||
db.close()
|
||||
@@ -0,0 +1,231 @@
|
||||
"""Tests for the template value source: model, factory, cycle/depth, refs."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from ledgrab.storage.value_source import TemplateValueSource
|
||||
|
||||
|
||||
class TestModelRoundTrip:
|
||||
def _make(self, **over):
|
||||
now = datetime.now(timezone.utc)
|
||||
defaults = dict(
|
||||
id="vs_t1",
|
||||
name="Combo",
|
||||
source_type="template",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
template="min(a * 2, 1)",
|
||||
inputs=[{"name": "a", "value_source_id": "vs_a"}],
|
||||
default_value=0.2,
|
||||
eval_interval=1.5,
|
||||
)
|
||||
defaults.update(over)
|
||||
return TemplateValueSource(**defaults)
|
||||
|
||||
def test_to_from_dict_idempotent(self):
|
||||
src = self._make()
|
||||
rebuilt = TemplateValueSource.from_dict(src.to_dict())
|
||||
assert rebuilt.template == src.template
|
||||
assert rebuilt.inputs == src.inputs
|
||||
assert rebuilt.default_value == src.default_value
|
||||
assert rebuilt.eval_interval == src.eval_interval
|
||||
assert rebuilt.to_dict()["return_type"] == "float"
|
||||
|
||||
def test_old_row_deserializes_with_defaults(self):
|
||||
"""A row written before template fields existed must load safely."""
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
src = TemplateValueSource.from_dict(
|
||||
{
|
||||
"id": "vs_old",
|
||||
"name": "Old",
|
||||
"source_type": "template",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
)
|
||||
assert src.template == ""
|
||||
assert src.inputs == []
|
||||
assert src.default_value == 0.0
|
||||
assert src.eval_interval is None
|
||||
|
||||
def test_dirty_scalars_coerce_to_defaults(self):
|
||||
"""Non-numeric stored scalars must not drop the whole row on load."""
|
||||
src = TemplateValueSource.from_dict(
|
||||
{
|
||||
"id": "x",
|
||||
"name": "n",
|
||||
"source_type": "template",
|
||||
"template": "a",
|
||||
"default_value": "not-a-number",
|
||||
"eval_interval": "bad",
|
||||
}
|
||||
)
|
||||
assert src.default_value == 0.0
|
||||
assert src.eval_interval is None
|
||||
|
||||
def test_inputs_normalized_from_dirty_data(self):
|
||||
src = TemplateValueSource.from_dict(
|
||||
{
|
||||
"id": "x",
|
||||
"name": "n",
|
||||
"source_type": "template",
|
||||
"inputs": [{"name": "a", "value_source_id": "vs_a"}, "junk", {"bad": 1}],
|
||||
}
|
||||
)
|
||||
# non-dict entries dropped; dict entries coerced to {name, value_source_id}
|
||||
assert src.inputs == [
|
||||
{"name": "a", "value_source_id": "vs_a"},
|
||||
{"name": "", "value_source_id": ""},
|
||||
]
|
||||
|
||||
|
||||
class TestFactoryCreate:
|
||||
def test_create_valid(self, value_source_store):
|
||||
src = value_source_store.create_source(
|
||||
"Combo",
|
||||
"template",
|
||||
template="min(a * 2, 1)",
|
||||
inputs=[{"name": "a", "value_source_id": ""}],
|
||||
default_value=0.3,
|
||||
)
|
||||
assert isinstance(src, TemplateValueSource)
|
||||
assert src.id.startswith("vs_")
|
||||
assert src.default_value == 0.3
|
||||
|
||||
def test_empty_template_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source("X", "template", template=" ", inputs=[])
|
||||
|
||||
def test_compile_error_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source("X", "template", template="a +", inputs=[])
|
||||
|
||||
def test_cost_bomb_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source("X", "template", template="10 ** 10", inputs=[])
|
||||
|
||||
def test_reserved_input_name_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source(
|
||||
"X",
|
||||
"template",
|
||||
template="min(0, 1)",
|
||||
inputs=[{"name": "min", "value_source_id": "vs_a"}],
|
||||
)
|
||||
|
||||
def test_duplicate_input_name_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source(
|
||||
"X",
|
||||
"template",
|
||||
template="a",
|
||||
inputs=[
|
||||
{"name": "a", "value_source_id": "vs_a"},
|
||||
{"name": "a", "value_source_id": "vs_b"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_default_value_out_of_range_rejected(self, value_source_store):
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source(
|
||||
"X",
|
||||
"template",
|
||||
template="a",
|
||||
inputs=[{"name": "a", "value_source_id": ""}],
|
||||
default_value=5.0,
|
||||
)
|
||||
|
||||
def test_unbound_variable_rejected(self, value_source_store):
|
||||
# 'ha_enti' is referenced but only 'ha_entity' is bound (typo) → reject.
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.create_source(
|
||||
"X",
|
||||
"template",
|
||||
template="ha_enti",
|
||||
inputs=[{"name": "ha_entity", "value_source_id": ""}],
|
||||
)
|
||||
|
||||
|
||||
class TestFactoryUpdate:
|
||||
def test_partial_update_template_only(self, value_source_store):
|
||||
src = value_source_store.create_source(
|
||||
"X",
|
||||
"template",
|
||||
template="a",
|
||||
inputs=[{"name": "a", "value_source_id": ""}],
|
||||
default_value=0.1,
|
||||
)
|
||||
updated = value_source_store.update_source(src.id, template="clamp(a * 3)")
|
||||
assert updated.template == "clamp(a * 3)"
|
||||
assert updated.default_value == 0.1 # unchanged
|
||||
|
||||
def test_update_invalid_template_rejected(self, value_source_store):
|
||||
src = value_source_store.create_source("X", "template", template="clamp(0.5)", inputs=[])
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.update_source(src.id, template="a |")
|
||||
|
||||
|
||||
class TestCycleAndDepth:
|
||||
def test_self_reference_rejected(self, value_source_store):
|
||||
t = value_source_store.create_source("T", "template", template="clamp(0.5)", inputs=[])
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.update_source(t.id, inputs=[{"name": "x", "value_source_id": t.id}])
|
||||
|
||||
def test_circular_reference_rejected(self, value_source_store):
|
||||
t1 = value_source_store.create_source("T1", "template", template="clamp(0.5)", inputs=[])
|
||||
t2 = value_source_store.create_source(
|
||||
"T2",
|
||||
"template",
|
||||
template="x",
|
||||
inputs=[{"name": "x", "value_source_id": t1.id}],
|
||||
)
|
||||
# t1 -> t2 -> t1 would be a cycle
|
||||
with pytest.raises(ValueError):
|
||||
value_source_store.update_source(
|
||||
t1.id, inputs=[{"name": "x", "value_source_id": t2.id}]
|
||||
)
|
||||
|
||||
def test_deep_chain_rejected(self, value_source_store):
|
||||
prev = value_source_store.create_source("L0", "template", template="clamp(0.5)", inputs=[])
|
||||
created = 1
|
||||
with pytest.raises(ValueError):
|
||||
for i in range(1, 12):
|
||||
node = value_source_store.create_source(
|
||||
f"L{i}",
|
||||
"template",
|
||||
template="x",
|
||||
inputs=[{"name": "x", "value_source_id": prev.id}],
|
||||
)
|
||||
prev = node
|
||||
created += 1
|
||||
# Should have rejected before building an unbounded chain.
|
||||
assert created <= 8
|
||||
|
||||
def test_get_transitive_dependencies(self, value_source_store):
|
||||
leaf = value_source_store.create_source(
|
||||
"leaf", "template", template="clamp(0.5)", inputs=[]
|
||||
)
|
||||
mid = value_source_store.create_source(
|
||||
"mid", "template", template="x", inputs=[{"name": "x", "value_source_id": leaf.id}]
|
||||
)
|
||||
top = value_source_store.create_source(
|
||||
"top", "template", template="x", inputs=[{"name": "x", "value_source_id": mid.id}]
|
||||
)
|
||||
deps = value_source_store.get_transitive_dependencies(top.id)
|
||||
assert deps == {mid.id, leaf.id}
|
||||
|
||||
|
||||
class TestReferencingSources:
|
||||
def test_find_referencing_sources(self, value_source_store):
|
||||
base = value_source_store.create_source("Base", "static", value=0.5)
|
||||
tmpl = value_source_store.create_source(
|
||||
"Uses",
|
||||
"template",
|
||||
template="b",
|
||||
inputs=[{"name": "b", "value_source_id": base.id}],
|
||||
)
|
||||
refs = value_source_store.find_referencing_sources(base.id)
|
||||
assert tmpl.name in refs
|
||||
assert value_source_store.find_referencing_sources(tmpl.id) == []
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Tests for the hardened sandboxed-Jinja expression engine."""
|
||||
|
||||
import pytest
|
||||
|
||||
from ledgrab.utils.template_expr import (
|
||||
GLOBALS,
|
||||
RESERVED_NAMES,
|
||||
TemplateValidationError,
|
||||
clamp,
|
||||
compile_template,
|
||||
extract_variables,
|
||||
finalize_result,
|
||||
validate_input_name,
|
||||
validate_template_expression,
|
||||
)
|
||||
|
||||
|
||||
class TestCompileAndEval:
|
||||
def test_basic_eval(self):
|
||||
assert compile_template("min(a * 2, 1)")(a=0.3, raw={}) == pytest.approx(0.6)
|
||||
|
||||
def test_clamp_global(self):
|
||||
assert compile_template("clamp((t - 18) / 10)")(t=22.5, raw={}) == pytest.approx(0.45)
|
||||
|
||||
def test_raw_subscript(self):
|
||||
assert compile_template("raw['t'] / 100")(raw={"t": 42.0}) == pytest.approx(0.42)
|
||||
|
||||
def test_ternary_and_comparison(self):
|
||||
expr = compile_template("a if a > 0.5 else b")
|
||||
assert expr(a=0.8, b=0.1, raw={}) == pytest.approx(0.8)
|
||||
assert expr(a=0.2, b=0.1, raw={}) == pytest.approx(0.1)
|
||||
|
||||
def test_all_globals_callable(self):
|
||||
for tpl in ("min(a, b)", "max(a, b)", "abs(a - b)", "round(a, 1)", "clamp(a)"):
|
||||
compile_template(tpl)(a=0.4, b=0.6, raw={})
|
||||
|
||||
|
||||
class TestRejections:
|
||||
@pytest.mark.parametrize(
|
||||
"tpl",
|
||||
[
|
||||
"",
|
||||
" ",
|
||||
"a +", # syntax error
|
||||
"10 ** 3", # power bomb
|
||||
"'a' * 1000", # string repetition
|
||||
"a | pprint", # filter
|
||||
"a is defined", # test
|
||||
"a.__class__", # attribute access
|
||||
"raw['s'].format(1)", # str gadget via attribute
|
||||
"dict(x=1)", # non-global call
|
||||
"namespace(x=1)",
|
||||
"range(3)",
|
||||
"cycler(1, 2)",
|
||||
"[0] * 1000000", # list-literal repetition (memory bomb)
|
||||
"(1,) * 1000000", # tuple-literal repetition (memory bomb)
|
||||
"[1, 2, 3]", # bare list literal
|
||||
"{1: 2}", # dict literal
|
||||
],
|
||||
)
|
||||
def test_rejected(self, tpl):
|
||||
with pytest.raises(TemplateValidationError):
|
||||
validate_template_expression(tpl)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tpl",
|
||||
[
|
||||
"min(a * 2, 1)",
|
||||
"(a + b) / 2",
|
||||
"clamp((t - 18) / 10, 0, 1)",
|
||||
"raw['x'] / 100",
|
||||
"a if a > b else b",
|
||||
"abs(a - b)",
|
||||
],
|
||||
)
|
||||
def test_accepted(self, tpl):
|
||||
validate_template_expression(tpl) # must not raise
|
||||
|
||||
|
||||
class TestFinalizeResult:
|
||||
def test_nan_returns_default(self):
|
||||
assert finalize_result(float("nan"), 0.25) == 0.25
|
||||
|
||||
def test_inf_returns_default(self):
|
||||
assert finalize_result(float("inf"), 0.25) == 0.25
|
||||
assert finalize_result(float("-inf"), 0.25) == 0.25
|
||||
|
||||
def test_non_numeric_returns_default(self):
|
||||
assert finalize_result("nope", 0.25) == 0.25
|
||||
assert finalize_result(None, 0.25) == 0.25
|
||||
|
||||
def test_overflow_returns_default(self):
|
||||
# float() of a multi-hundred-digit int (chained big-int multiply) raises
|
||||
# OverflowError, not ValueError — must still fall back, not propagate.
|
||||
assert finalize_result(10**400, 0.25) == 0.25
|
||||
|
||||
def test_clamps_to_unit(self):
|
||||
assert finalize_result(5.0, 0.0) == 1.0
|
||||
assert finalize_result(-1.0, 0.0) == 0.0
|
||||
assert finalize_result(0.5, 0.0) == pytest.approx(0.5)
|
||||
|
||||
def test_clamp_helper(self):
|
||||
assert clamp(2.0) == 1.0
|
||||
assert clamp(-2.0) == 0.0
|
||||
assert clamp(5.0, 0.0, 10.0) == 5.0
|
||||
|
||||
|
||||
class TestInputNames:
|
||||
@pytest.mark.parametrize("name", ["audio", "cpu_load", "_x", "Temp2"])
|
||||
def test_valid(self, name):
|
||||
validate_input_name(name)
|
||||
|
||||
@pytest.mark.parametrize("name", ["", "1bad", "has space", "a-b", "a.b"])
|
||||
def test_invalid_identifier(self, name):
|
||||
with pytest.raises(TemplateValidationError):
|
||||
validate_input_name(name)
|
||||
|
||||
@pytest.mark.parametrize("name", sorted(RESERVED_NAMES))
|
||||
def test_reserved(self, name):
|
||||
with pytest.raises(TemplateValidationError):
|
||||
validate_input_name(name)
|
||||
|
||||
def test_globals_are_reserved(self):
|
||||
assert set(GLOBALS).issubset(RESERVED_NAMES)
|
||||
assert "raw" in RESERVED_NAMES
|
||||
|
||||
|
||||
class TestExtractVariables:
|
||||
def test_excludes_globals_and_raw(self):
|
||||
assert extract_variables("min(a, raw['x']) + b") == ["a", "b"]
|
||||
|
||||
def test_empty_for_uncompilable(self):
|
||||
assert extract_variables("a +") == []
|
||||
|
||||
def test_constant_expression(self):
|
||||
assert extract_variables("clamp(0.5)") == []
|
||||
Reference in New Issue
Block a user