"""Tests for the hardened sandboxed-Jinja expression engine.""" import pytest from ledgrab.utils.template_expr import ( GLOBALS, RESERVED_NAMES, TemplateValidationError, clamp, compile_template, extract_variables, finalize_result, validate_input_name, validate_template_expression, ) class TestCompileAndEval: def test_basic_eval(self): assert compile_template("min(a * 2, 1)")(a=0.3, raw={}) == pytest.approx(0.6) def test_clamp_global(self): assert compile_template("clamp((t - 18) / 10)")(t=22.5, raw={}) == pytest.approx(0.45) def test_raw_subscript(self): assert compile_template("raw['t'] / 100")(raw={"t": 42.0}) == pytest.approx(0.42) def test_ternary_and_comparison(self): expr = compile_template("a if a > 0.5 else b") assert expr(a=0.8, b=0.1, raw={}) == pytest.approx(0.8) assert expr(a=0.2, b=0.1, raw={}) == pytest.approx(0.1) def test_all_globals_callable(self): for tpl in ("min(a, b)", "max(a, b)", "abs(a - b)", "round(a, 1)", "clamp(a)"): compile_template(tpl)(a=0.4, b=0.6, raw={}) class TestRejections: @pytest.mark.parametrize( "tpl", [ "", " ", "a +", # syntax error "10 ** 3", # power bomb "'a' * 1000", # string repetition "a | pprint", # filter "a is defined", # test "a.__class__", # attribute access "raw['s'].format(1)", # str gadget via attribute "dict(x=1)", # non-global call "namespace(x=1)", "range(3)", "cycler(1, 2)", "[0] * 1000000", # list-literal repetition (memory bomb) "(1,) * 1000000", # tuple-literal repetition (memory bomb) "[1, 2, 3]", # bare list literal "{1: 2}", # dict literal ], ) def test_rejected(self, tpl): with pytest.raises(TemplateValidationError): validate_template_expression(tpl) @pytest.mark.parametrize( "tpl", [ "min(a * 2, 1)", "(a + b) / 2", "clamp((t - 18) / 10, 0, 1)", "raw['x'] / 100", "a if a > b else b", "abs(a - b)", ], ) def test_accepted(self, tpl): validate_template_expression(tpl) # must not raise class TestFinalizeResult: def test_nan_returns_default(self): assert finalize_result(float("nan"), 0.25) == 0.25 def test_inf_returns_default(self): assert finalize_result(float("inf"), 0.25) == 0.25 assert finalize_result(float("-inf"), 0.25) == 0.25 def test_non_numeric_returns_default(self): assert finalize_result("nope", 0.25) == 0.25 assert finalize_result(None, 0.25) == 0.25 def test_overflow_returns_default(self): # float() of a multi-hundred-digit int (chained big-int multiply) raises # OverflowError, not ValueError — must still fall back, not propagate. assert finalize_result(10**400, 0.25) == 0.25 def test_clamps_to_unit(self): assert finalize_result(5.0, 0.0) == 1.0 assert finalize_result(-1.0, 0.0) == 0.0 assert finalize_result(0.5, 0.0) == pytest.approx(0.5) def test_clamp_helper(self): assert clamp(2.0) == 1.0 assert clamp(-2.0) == 0.0 assert clamp(5.0, 0.0, 10.0) == 5.0 class TestInputNames: @pytest.mark.parametrize("name", ["audio", "cpu_load", "_x", "Temp2"]) def test_valid(self, name): validate_input_name(name) @pytest.mark.parametrize("name", ["", "1bad", "has space", "a-b", "a.b"]) def test_invalid_identifier(self, name): with pytest.raises(TemplateValidationError): validate_input_name(name) @pytest.mark.parametrize("name", sorted(RESERVED_NAMES)) def test_reserved(self, name): with pytest.raises(TemplateValidationError): validate_input_name(name) def test_globals_are_reserved(self): assert set(GLOBALS).issubset(RESERVED_NAMES) assert "raw" in RESERVED_NAMES class TestExtractVariables: def test_excludes_globals_and_raw(self): assert extract_variables("min(a, raw['x']) + b") == ["a", "b"] def test_empty_for_uncompilable(self): assert extract_variables("a +") == [] def test_constant_expression(self): assert extract_variables("clamp(0.5)") == []