"""Tests for group device type: cycle detection, LED count resolution, GroupLEDClient."""

import numpy as np
import pytest

from ledgrab.core.devices.device_config import GroupConfig
from ledgrab.core.devices.led_client import ProviderDeps
from ledgrab.storage.database import Database
from ledgrab.storage.device_store import Device, DeviceStore


# ── Fixtures ──────────────────────────────────────────────────────────


@pytest.fixture
def tmp_db(tmp_path):
    """Yield a fresh Database file under pytest's ``tmp_path``; closed on teardown."""
    db = Database(tmp_path / "test.db")
    yield db
    db.close()


@pytest.fixture
def store(tmp_db):
    """DeviceStore backed by the per-test temporary database."""
    return DeviceStore(tmp_db)


def _create_device(store: DeviceStore, name: str, led_count: int = 30, **kwargs) -> Device:
    """Create and persist a device named *name* with URL ``mock://<name>``.

    ``device_type`` defaults to ``"mock"``. Pass ``device_type="group"`` together
    with ``group_device_ids`` / ``group_mode`` to create a group device instead.
    Any other keyword arguments are forwarded unchanged to ``store.create_device``.
    """
    return store.create_device(
        name=name,
        url=f"mock://{name}",
        led_count=led_count,
        device_type=kwargs.pop("device_type", "mock"),
        **kwargs,
    )


def _create_group(
    store: DeviceStore, name: str, led_count: int, child_ids, mode: str = "sequence"
) -> Device:
    """Create a group device containing *child_ids* in the given *mode*."""
    return _create_device(
        store,
        name,
        led_count,
        device_type="group",
        group_device_ids=list(child_ids),
        group_mode=mode,
    )


# ── Cycle Detection ──────────────────────────────────────────────────


class TestCycleDetection:
    def test_valid_flat_group(self, store):
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        # No cycle — should not raise
        store.validate_group_no_cycles(None, [d1.id, d2.id])

    def test_self_reference(self, store):
        d1 = _create_device(store, "d1", 30)
        with pytest.raises(ValueError, match="Circular group reference"):
            store.validate_group_no_cycles(d1.id, [d1.id])

    def test_simple_cycle(self, store):
        """Two-group loop: g2 already contains g1, so g1 may not take g2 as a child."""
        d1 = _create_device(store, "d1", 30)
        g1 = _create_group(store, "g1", 30, [d1.id])
        g2 = _create_group(store, "g2", 30, [g1.id])
        # Editing g1 to contain g2 would close the loop g1 → g2 → g1.
        with pytest.raises(ValueError, match="Circular group reference"):
            store.validate_group_no_cycles(g1.id, [g2.id])

    def test_deep_cycle(self, store):
        d1 = _create_device(store, "d1", 30)
        g1 = _create_group(store, "g1", 30, [d1.id])
        g2 = _create_group(store, "g2", 30, [g1.id])
        # g3 wants g2, and we're editing g1 to contain g3 → cycle: g1→g3→g2→g1
        g3 = _create_group(store, "g3", 30, [g2.id])
        with pytest.raises(ValueError, match="Circular group reference"):
            store.validate_group_no_cycles(g1.id, [g3.id])

    def test_diamond_dag_allowed(self, store):
        """Diamond shape (A→B, A→C, B→D, C→D) is NOT a cycle."""
        d = _create_device(store, "d", 30)
        g_b = _create_group(store, "g_b", 30, [d.id])
        g_c = _create_group(store, "g_c", 30, [d.id])
        # g_a contains both g_b and g_c, which both contain d — diamond, not cycle
        store.validate_group_no_cycles(None, [g_b.id, g_c.id])

    def test_nonexistent_child_raises(self, store):
        with pytest.raises(ValueError, match="Referenced device not found"):
            store.validate_group_no_cycles(None, ["nonexistent_device"])

    def test_valid_nested_groups(self, store):
        """Groups can contain other groups without cycles."""
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        g_inner = _create_group(store, "g_inner", 90, [d1.id, d2.id])
        # Outer group containing inner group + another device — valid
        d3 = _create_device(store, "d3", 20)
        store.validate_group_no_cycles(None, [g_inner.id, d3.id])


# ── LED Count Resolution ─────────────────────────────────────────────


class TestLedCountResolution:
    def test_flat_sequence(self, store):
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        d3 = _create_device(store, "d3", 10)
        total = store.resolve_group_led_count([d1.id, d2.id, d3.id])
        assert total == 100

    def test_nested_sequence_groups(self, store):
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        g_inner = _create_group(store, "g_inner", 90, [d1.id, d2.id])
        d3 = _create_device(store, "d3", 20)
        total = store.resolve_group_led_count([g_inner.id, d3.id])
        assert total == 110  # 30+60+20

    def test_independent_child_uses_own_led_count(self, store):
        """Independent mode child group contributes its own led_count (not recursed)."""
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        g_independent = _create_group(store, "g_ind", 100, [d1.id, d2.id], mode="independent")
        d3 = _create_device(store, "d3", 20)
        # g_independent is in independent mode, so its led_count=100 is used directly
        total = store.resolve_group_led_count([g_independent.id, d3.id])
        assert total == 120  # 100+20

    def test_max_led_count(self, store):
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        d3 = _create_device(store, "d3", 10)
        max_count = store.resolve_group_max_led_count([d1.id, d2.id, d3.id])
        assert max_count == 60

    def test_max_led_count_empty(self, store):
        assert store.resolve_group_max_led_count([]) == 1

    def test_missing_device_skipped(self, store):
        d1 = _create_device(store, "d1", 30)
        total = store.resolve_group_led_count([d1.id, "nonexistent"])
        assert total == 30


# ── Group References ──────────────────────────────────────────────────


class TestGroupReferences:
    def test_get_groups_referencing(self, store):
        d1 = _create_device(store, "d1", 30)
        d2 = _create_device(store, "d2", 60)
        g1 = _create_group(store, "g1", 90, [d1.id, d2.id])
        g2 = _create_group(store, "g2", 30, [d1.id])
        refs = store.get_groups_referencing(d1.id)
        ref_ids = {r.id for r in refs}
        assert ref_ids == {g1.id, g2.id}

    def test_no_groups_referencing(self, store):
        d1 = _create_device(store, "d1", 30)
        assert store.get_groups_referencing(d1.id) == []


# ── GroupLEDClient ────────────────────────────────────────────────────


def _capture_child_sends(client):
    """Wrap every child's ``send_pixels`` so the frames it receives are recorded.

    Returns one list per child, in ``client._children`` order. Each inner list
    collects a copy of every pixel array that child was asked to send. Recording
    per child index keeps assertions independent of the order in which the group
    dispatches to its children. Copying guards against the caller reusing a buffer.
    """
    received = [[] for _ in client._children]
    for frames, (child_client, _) in zip(received, client._children):
        original_send = child_client.send_pixels

        # Loop variables are bound as defaults to avoid late-binding closures.
        async def capture_send(pixels, brightness, _orig=original_send, _frames=frames):
            _frames.append(np.array(pixels, copy=True))
            return await _orig(pixels, brightness)

        child_client.send_pixels = capture_send
    return received


def _single_frames(received):
    """Assert each child was sent exactly one frame and return those frames in order."""
    assert [len(frames) for frames in received] == [1] * len(received)
    return [frames[0] for frames in received]


class TestGroupLEDClient:
    @pytest.fixture
    def mock_store(self, store):
        """Store with 3 mock devices (10, 20 and 30 LEDs) for client tests."""
        d1 = _create_device(store, "d1", 10)
        d2 = _create_device(store, "d2", 20)
        d3 = _create_device(store, "d3", 30)
        return store, [d1, d2, d3]

    def _make_client(self, store, devices, mode="sequence"):
        """Build an unconnected GroupLEDClient whose children are *devices*."""
        from ledgrab.core.devices.group_client import GroupLEDClient

        config = GroupConfig(
            device_id="test_group",
            device_url="group://test_group",
            led_count=sum(d.led_count for d in devices),
            group_mode=mode,
            group_device_ids=[d.id for d in devices],
        )
        return GroupLEDClient(config=config, deps=ProviderDeps(device_store=store))

    @pytest.mark.asyncio
    async def test_connect_creates_children(self, mock_store):
        store, devices = mock_store
        client = self._make_client(store, devices)
        await client.connect()
        try:
            assert client.is_connected
            assert client.device_led_count == 60  # 10+20+30
            assert len(client._children) == 3
        finally:
            await client.close()

    @pytest.mark.asyncio
    async def test_sequence_mode_slices(self, mock_store):
        store, devices = mock_store
        client = self._make_client(store, devices)
        await client.connect()
        try:
            received = _capture_child_sends(client)

            # Create a 60-pixel gradient (values 0..179 fit in uint8)
            pixels = np.arange(60 * 3, dtype=np.uint8).reshape(60, 3)
            await client.send_pixels(pixels, 255)

            first, second, third = _single_frames(received)
            np.testing.assert_array_equal(first, pixels[0:10])
            np.testing.assert_array_equal(second, pixels[10:30])
            np.testing.assert_array_equal(third, pixels[30:60])
        finally:
            await client.close()

    @pytest.mark.asyncio
    async def test_independent_mode_resamples(self, mock_store):
        store, devices = mock_store
        client = self._make_client(store, devices, mode="independent")
        await client.connect()
        try:
            received = _capture_child_sends(client)

            # Send 5 pixels — each child should get its own resampled version
            pixels = np.array(
                [[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 255, 0], [0, 255, 255]],
                dtype=np.uint8,
            )
            await client.send_pixels(pixels, 255)

            first, second, third = _single_frames(received)
            assert first.shape == (10, 3)  # resampled to 10 LEDs
            assert second.shape == (20, 3)  # resampled to 20 LEDs
            assert third.shape == (30, 3)  # resampled to 30 LEDs
        finally:
            await client.close()

    @pytest.mark.asyncio
    async def test_close_cleans_up(self, mock_store):
        store, devices = mock_store
        client = self._make_client(store, devices)
        await client.connect()
        assert client.is_connected
        await client.close()
        assert not client.is_connected
        assert len(client._children) == 0

    @pytest.mark.asyncio
    async def test_sequence_pads_short_pixels(self, mock_store):
        store, devices = mock_store
        client = self._make_client(store, devices)
        await client.connect()
        try:
            received = _capture_child_sends(client)

            # Send only 15 pixels (less than 60 total needed)
            pixels = np.full((15, 3), 128, dtype=np.uint8)
            await client.send_pixels(pixels, 255)

            first, second, third = _single_frames(received)
            assert first.shape == (10, 3)
            assert second.shape == (20, 3)
            assert third.shape == (30, 3)

            # First child gets full 10 pixels
            np.testing.assert_array_equal(first, np.full((10, 3), 128, dtype=np.uint8))
            # Second child gets 5 real + 15 black
            np.testing.assert_array_equal(second[:5], np.full((5, 3), 128, dtype=np.uint8))
            np.testing.assert_array_equal(second[5:], np.zeros((15, 3), dtype=np.uint8))
            # Third child lies entirely past the input, so it is all black
            np.testing.assert_array_equal(third, np.zeros((30, 3), dtype=np.uint8))
        finally:
            await client.close()