fix: isolate tests from production database

Tests that imported wled_controller.main at module level caused the real
production database (data/ledgrab.db) to be opened before test fixtures
could patch the config. This led to silent data loss.

Patch the global config singleton at conftest module level (before any
test imports main.py) to redirect all DB access to a temp directory.
This commit is contained in:
2026-04-01 19:01:56 +03:00
parent 6b0e4e5539
commit 992495e2e4
3 changed files with 79 additions and 24 deletions
@@ -1,9 +1,7 @@
"""API tests for audio processing template endpoints.""" """API tests for audio processing template endpoints."""
import pytest import pytest
from fastapi.testclient import TestClient
from wled_controller.main import app
from wled_controller.config import get_config from wled_controller.config import get_config
# Ensure audio filters registered # Ensure audio filters registered
@@ -17,6 +15,9 @@ AUTH = {"Authorization": f"Bearer {_api_key}"} if _api_key else {}
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def client(): def client():
"""Provide a TestClient with lifespan (startup/shutdown) properly triggered.""" """Provide a TestClient with lifespan (startup/shutdown) properly triggered."""
from fastapi.testclient import TestClient
from wled_controller.main import app
with TestClient(app) as c: with TestClient(app) as c:
yield c yield c
+57 -11
View File
@@ -1,20 +1,52 @@
"""Pytest configuration and shared fixtures.""" """Pytest configuration and shared fixtures.
IMPORTANT: This conftest patches the global config singleton BEFORE any test
module can import ``wled_controller.main``. ``main.py`` reads ``get_config()``
at module level to open the database — if the singleton is not patched first,
the REAL production database (``data/ledgrab.db``) is opened and tests
read/write/delete production data.
"""
import tempfile
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path
import pytest import pytest
from wled_controller.config import Config, StorageConfig, ServerConfig, AuthConfig # ---------------------------------------------------------------------------
from wled_controller.storage.database import Database # ISOLATE ALL TESTS FROM PRODUCTION DATA — must happen before any test module
from wled_controller.storage.device_store import Device, DeviceStore # imports ``wled_controller.main``.
from wled_controller.storage.sync_clock import SyncClock # ---------------------------------------------------------------------------
from wled_controller.storage.sync_clock_store import SyncClockStore
from wled_controller.storage.output_target_store import OutputTargetStore import wled_controller.config as _config_mod # noqa: E402
from wled_controller.storage.automation import (
Automation, _test_tmp = Path(tempfile.mkdtemp(prefix="wled_test_"))
_test_db_path = str(_test_tmp / "test_ledgrab.db")
_test_assets_dir = str(_test_tmp / "test_assets")
_original_config = _config_mod.Config.load()
_test_config = _original_config.model_copy(
update={
"storage": _config_mod.StorageConfig(database_file=_test_db_path),
"assets": _config_mod.AssetsConfig(
assets_dir=_test_assets_dir,
max_file_size_mb=_original_config.assets.max_file_size_mb,
),
},
) )
from wled_controller.storage.automation_store import AutomationStore _config_mod.config = _test_config
from wled_controller.storage.value_source_store import ValueSourceStore
# ---------------------------------------------------------------------------
from wled_controller.config import Config, StorageConfig, ServerConfig, AuthConfig # noqa: E402
from wled_controller.storage.database import Database # noqa: E402
from wled_controller.storage.device_store import Device, DeviceStore # noqa: E402
from wled_controller.storage.sync_clock import SyncClock # noqa: E402
from wled_controller.storage.sync_clock_store import SyncClockStore # noqa: E402
from wled_controller.storage.output_target_store import OutputTargetStore # noqa: E402
from wled_controller.storage.automation import Automation # noqa: E402
from wled_controller.storage.automation_store import AutomationStore # noqa: E402
from wled_controller.storage.value_source_store import ValueSourceStore # noqa: E402
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -242,3 +274,17 @@ def sample_calibration():
{"edge": "left", "led_start": 110, "led_count": 40, "reverse": True}, {"edge": "left", "led_start": 110, "led_count": 40, "reverse": True},
], ],
} }
# ---------------------------------------------------------------------------
# Session cleanup — remove temporary test directory
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session", autouse=True)
def _cleanup_test_tmp():
"""Remove the temporary test directory after all tests complete."""
import shutil
yield
shutil.rmtree(_test_tmp, ignore_errors=True)
+19 -11
View File
@@ -4,31 +4,39 @@ import os
import sys import sys
import pytest import pytest
from fastapi.testclient import TestClient
from wled_controller.main import app
from wled_controller import __version__ from wled_controller import __version__
from wled_controller.config import get_config from wled_controller.config import get_config
_has_display = bool(os.environ.get("DISPLAY") or sys.platform == "win32" or sys.platform == "darwin") _has_display = bool(
os.environ.get("DISPLAY") or sys.platform == "win32" or sys.platform == "darwin"
)
requires_display = pytest.mark.skipif(not _has_display, reason="No display available (headless CI)") requires_display = pytest.mark.skipif(not _has_display, reason="No display available (headless CI)")
client = TestClient(app)
# Build auth header from the first configured API key # Build auth header from the first configured API key
_config = get_config() _config = get_config()
_api_key = next(iter(_config.auth.api_keys.values()), "") _api_key = next(iter(_config.auth.api_keys.values()), "")
AUTH_HEADERS = {"Authorization": f"Bearer {_api_key}"} if _api_key else {} AUTH_HEADERS = {"Authorization": f"Bearer {_api_key}"} if _api_key else {}
def test_root_endpoint(): @pytest.fixture(scope="module")
def client():
"""Provide a TestClient backed by the isolated test database."""
from fastapi.testclient import TestClient
from wled_controller.main import app
with TestClient(app, raise_server_exceptions=False) as c:
yield c
def test_root_endpoint(client):
"""Test root endpoint returns the HTML dashboard.""" """Test root endpoint returns the HTML dashboard."""
response = client.get("/") response = client.get("/")
assert response.status_code == 200 assert response.status_code == 200
assert "text/html" in response.headers["content-type"] assert "text/html" in response.headers["content-type"]
def test_health_check(): def test_health_check(client):
"""Test health check endpoint.""" """Test health check endpoint."""
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
@@ -38,7 +46,7 @@ def test_health_check():
assert "timestamp" in data assert "timestamp" in data
def test_version_endpoint(): def test_version_endpoint(client):
"""Test version endpoint.""" """Test version endpoint."""
response = client.get("/api/v1/version") response = client.get("/api/v1/version")
assert response.status_code == 200 assert response.status_code == 200
@@ -49,7 +57,7 @@ def test_version_endpoint():
@requires_display @requires_display
def test_get_displays(): def test_get_displays(client):
"""Test get displays endpoint (requires auth and a real display).""" """Test get displays endpoint (requires auth and a real display)."""
response = client.get("/api/v1/config/displays", headers=AUTH_HEADERS) response = client.get("/api/v1/config/displays", headers=AUTH_HEADERS)
assert response.status_code == 200 assert response.status_code == 200
@@ -69,7 +77,7 @@ def test_get_displays():
assert "is_primary" in display assert "is_primary" in display
def test_openapi_docs(): def test_openapi_docs(client):
"""Test OpenAPI documentation is available.""" """Test OpenAPI documentation is available."""
response = client.get("/openapi.json") response = client.get("/openapi.json")
assert response.status_code == 200 assert response.status_code == 200
@@ -77,7 +85,7 @@ def test_openapi_docs():
assert data["info"]["version"] == __version__ assert data["info"]["version"] == __version__
def test_swagger_ui(): def test_swagger_ui(client):
"""Test Swagger UI is available.""" """Test Swagger UI is available."""
response = client.get("/docs") response = client.get("/docs")
assert response.status_code == 200 assert response.status_code == 200