"""Tests for BaseJsonStore — the shared data-layer base class.""" import json from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path import pytest from wled_controller.storage.base_store import BaseJsonStore, EntityNotFoundError # --------------------------------------------------------------------------- # Minimal concrete store for testing the base class # --------------------------------------------------------------------------- @dataclass class _Item: id: str name: str value: int = 0 def to_dict(self) -> dict: return {"id": self.id, "name": self.name, "value": self.value} @staticmethod def from_dict(data: dict) -> "_Item": return _Item(id=data["id"], name=data["name"], value=data.get("value", 0)) class _TestStore(BaseJsonStore[_Item]): _json_key = "items" _entity_name = "Item" def __init__(self, file_path: str): super().__init__(file_path, _Item.from_dict) def add(self, item: _Item) -> None: with self._lock: self._check_name_unique(item.name) self._items[item.id] = item self._save() class _LegacyStore(BaseJsonStore[_Item]): """Store that supports legacy JSON keys for migration testing.""" _json_key = "items_v2" _entity_name = "Item" _legacy_json_keys = ["items_v1", "old_items"] def __init__(self, file_path: str): super().__init__(file_path, _Item.from_dict) # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture def store_file(tmp_path) -> Path: return tmp_path / "test_store.json" @pytest.fixture def store(store_file) -> _TestStore: return _TestStore(str(store_file)) # --------------------------------------------------------------------------- # Initialization # --------------------------------------------------------------------------- class TestInit: def test_empty_init(self, store): assert store.count() == 0 assert store.get_all() == [] def test_file_not_found_starts_empty(self, tmp_path): s = _TestStore(str(tmp_path / "missing.json")) assert s.count() == 0 def test_load_from_existing_file(self, store_file): data = { "version": "1.0.0", "items": { "a": {"id": "a", "name": "Alpha", "value": 1}, "b": {"id": "b", "name": "Beta", "value": 2}, }, } store_file.write_text(json.dumps(data), encoding="utf-8") s = _TestStore(str(store_file)) assert s.count() == 2 assert s.get("a").name == "Alpha" assert s.get("b").value == 2 def test_load_skips_corrupt_items(self, store_file): """Items that fail deserialization are skipped, not fatal.""" data = { "version": "1.0.0", "items": { "good": {"id": "good", "name": "OK"}, "bad": {"missing_required": True}, }, } store_file.write_text(json.dumps(data), encoding="utf-8") s = _TestStore(str(store_file)) assert s.count() == 1 assert s.get("good").name == "OK" def test_load_corrupt_json_raises(self, store_file): """Completely invalid JSON file raises on load.""" store_file.write_text("{bad json", encoding="utf-8") with pytest.raises(Exception): _TestStore(str(store_file)) # --------------------------------------------------------------------------- # CRUD operations # --------------------------------------------------------------------------- class TestCRUD: def test_get_all_returns_list(self, store): store.add(_Item(id="x", name="X")) items = store.get_all() assert isinstance(items, list) assert len(items) == 1 def test_get_existing(self, store): store.add(_Item(id="x", name="X", value=42)) item = store.get("x") assert item.id == "x" assert item.value == 42 def test_get_not_found_raises(self, store): with pytest.raises(EntityNotFoundError, match="not found"): store.get("nonexistent") def test_delete_existing(self, store): store.add(_Item(id="x", name="X")) store.delete("x") assert store.count() == 0 def test_delete_not_found_raises(self, store): with pytest.raises(EntityNotFoundError, match="not found"): store.delete("nonexistent") def test_count(self, store): assert store.count() == 0 store.add(_Item(id="a", name="A")) assert store.count() == 1 store.add(_Item(id="b", name="B")) assert store.count() == 2 store.delete("a") assert store.count() == 1 # --------------------------------------------------------------------------- # Persistence (save/load round-trip) # --------------------------------------------------------------------------- class TestPersistence: def test_save_and_reload(self, store_file): s1 = _TestStore(str(store_file)) s1.add(_Item(id="p1", name="Persisted", value=99)) # Load fresh from the same file s2 = _TestStore(str(store_file)) assert s2.count() == 1 assert s2.get("p1").value == 99 def test_delete_persists(self, store_file): s1 = _TestStore(str(store_file)) s1.add(_Item(id="del", name="ToDelete")) s1.delete("del") s2 = _TestStore(str(store_file)) assert s2.count() == 0 def test_json_file_structure(self, store, store_file): store.add(_Item(id="s1", name="Struct", value=7)) raw = json.loads(store_file.read_text(encoding="utf-8")) assert "version" in raw assert "items" in raw assert raw["items"]["s1"]["name"] == "Struct" # --------------------------------------------------------------------------- # Name uniqueness # --------------------------------------------------------------------------- class TestNameUniqueness: def test_duplicate_name_raises(self, store): store.add(_Item(id="a", name="Unique")) with pytest.raises(ValueError, match="already exists"): store.add(_Item(id="b", name="Unique")) def test_different_names_ok(self, store): store.add(_Item(id="a", name="Alpha")) store.add(_Item(id="b", name="Beta")) assert store.count() == 2 def test_empty_name_raises(self, store): with pytest.raises(ValueError, match="required"): store._check_name_unique("") def test_whitespace_name_raises(self, store): with pytest.raises(ValueError, match="required"): store._check_name_unique(" ") def test_exclude_id_allows_self(self, store): store.add(_Item(id="a", name="Alpha")) # Checking uniqueness for a rename of item "a" — should not conflict with itself store._check_name_unique("Alpha", exclude_id="a") # should not raise # --------------------------------------------------------------------------- # Thread safety # --------------------------------------------------------------------------- class TestThreadSafety: def test_concurrent_reads(self, store): for i in range(20): store.add(_Item(id=f"t{i}", name=f"Thread {i}")) results = [] def _read(): return store.count() with ThreadPoolExecutor(max_workers=8) as pool: futures = [pool.submit(_read) for _ in range(50)] results = [f.result() for f in as_completed(futures)] assert all(r == 20 for r in results) def test_concurrent_add_and_read(self, tmp_path): """Concurrent adds should not lose items or corrupt state.""" s = _TestStore(str(tmp_path / "concurrent.json")) errors = [] def _add(index): try: s.add(_Item(id=f"c{index}", name=f"Conc {index}")) except Exception as e: errors.append(e) with ThreadPoolExecutor(max_workers=8) as pool: futures = [pool.submit(_add, i) for i in range(30)] for f in as_completed(futures): f.result() assert len(errors) == 0 assert s.count() == 30 # --------------------------------------------------------------------------- # Legacy key migration # --------------------------------------------------------------------------- class TestLegacyKeyMigration: def test_loads_from_legacy_key(self, store_file): data = { "version": "1.0.0", "items_v1": { "old1": {"id": "old1", "name": "Legacy"}, }, } store_file.write_text(json.dumps(data), encoding="utf-8") s = _LegacyStore(str(store_file)) assert s.count() == 1 assert s.get("old1").name == "Legacy" def test_primary_key_takes_precedence(self, store_file): data = { "version": "1.0.0", "items_v2": {"new": {"id": "new", "name": "Primary"}}, "items_v1": {"old": {"id": "old", "name": "Legacy"}}, } store_file.write_text(json.dumps(data), encoding="utf-8") s = _LegacyStore(str(store_file)) assert s.count() == 1 assert s.get("new").name == "Primary" # --------------------------------------------------------------------------- # Async delete # --------------------------------------------------------------------------- class TestAsyncDelete: @pytest.mark.asyncio async def test_async_delete(self, store): store.add(_Item(id="ad", name="AsyncDel")) assert store.count() == 1 await store.async_delete("ad") assert store.count() == 0 @pytest.mark.asyncio async def test_async_delete_not_found(self, store): with pytest.raises(EntityNotFoundError, match="not found"): await store.async_delete("nope")