diff --git a/plans/processed-audio-sources/PLAN.md b/plans/processed-audio-sources/PLAN.md index 76696fc..1b859ce 100644 --- a/plans/processed-audio-sources/PLAN.md +++ b/plans/processed-audio-sources/PLAN.md @@ -39,13 +39,13 @@ Clean-slate approach: no data migration for old source types. | Phase | Domain | Status | Review | Build | Committed | |-------|--------|--------|--------|-------|-----------| -| Phase 1: Audio Filter Framework | backend | 🔨 In Progress | ⬜ | ⬜ | ⬜ | -| Phase 2: Audio Filters | backend | 🔨 In Progress | ⬜ | ⬜ | ⬜ | -| Phase 3: Processed Audio Source Model | backend | ✅ Done | ⬜ | ⬜ | ⬜ | -| Phase 4: Runtime Integration | backend | ✅ Done | ⬜ | ⬜ | ⬜ | -| Phase 5: Frontend — Audio Processing Templates | frontend | ⬜ Not Started | ⬜ | ⬜ | ⬜ | -| Phase 6: Frontend — Source Types | frontend | ⬜ Not Started | ⬜ | ⬜ | ⬜ | -| Phase 7: Testing & Polish | backend | ⬜ Not Started | ⬜ | ⬜ | ⬜ | +| Phase 1: Audio Filter Framework | backend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 2: Audio Filters | backend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 3: Processed Audio Source Model | backend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 4: Runtime Integration | backend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 5: Frontend — Audio Processing Templates | frontend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 6: Frontend — Source Types | frontend | ✅ Done | ✅ | ⏭️ | ✅ | +| Phase 7: Testing & Polish | backend | ✅ Done | — | ✅ | ✅ | | Phase 8: Frontend Design Review | frontend | ⬜ Not Started | ⬜ | ⬜ | ⬜ | ## Final Review diff --git a/plans/processed-audio-sources/phase-7-testing-polish.md b/plans/processed-audio-sources/phase-7-testing-polish.md index 5c33b81..c589efd 100644 --- a/plans/processed-audio-sources/phase-7-testing-polish.md +++ b/plans/processed-audio-sources/phase-7-testing-polish.md @@ -1,6 +1,6 @@ # Phase 7: Testing & Polish -**Status:** ⬜ Not Started +**Status:** ✅ Done **Parent plan:** [PLAN.md](./PLAN.md) **Domain:** backend diff --git a/server/src/wled_controller/api/routes/audio_processing_templates.py b/server/src/wled_controller/api/routes/audio_processing_templates.py index ec60cb3..e1ac218 100644 --- a/server/src/wled_controller/api/routes/audio_processing_templates.py +++ b/server/src/wled_controller/api/routes/audio_processing_templates.py @@ -134,8 +134,8 @@ async def update_audio_processing_template( try: pm = get_processor_manager() pm.refresh_audio_filter_pipelines(template_id) - except Exception: - pass # Non-critical: streams will pick up changes on next restart + except Exception as exc: + logger.warning("Hot-update of audio filter pipelines failed: %s", exc) return _apt_to_response(template) except EntityNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) @@ -158,15 +158,14 @@ async def delete_audio_processing_template( ): """Delete an audio processing template.""" try: - # TODO: Phase 3 will add reference checks against ProcessedAudioSource store.delete_template(template_id) fire_entity_event("audio_processing_template", "deleted", template_id) # Hot-update: rebuild filter pipelines for running streams that used this template try: pm = get_processor_manager() pm.refresh_audio_filter_pipelines(template_id) - except Exception: - pass # Non-critical + except Exception as exc: + logger.warning("Hot-update of audio filter pipelines after delete failed: %s", exc) except HTTPException: raise except EntityNotFoundError as e: diff --git a/server/src/wled_controller/core/processing/audio_stream.py b/server/src/wled_controller/core/processing/audio_stream.py index cd3e567..50299b4 100644 --- a/server/src/wled_controller/core/processing/audio_stream.py +++ b/server/src/wled_controller/core/processing/audio_stream.py @@ -324,10 +324,12 @@ class AudioColorStripStream(ColorStripStream): buf = _buf_a if _use_a else _buf_b _use_a = not _use_a - # Get latest audio analysis + # Get latest audio analysis and apply filter pipeline once per frame analysis = None if self._audio_stream is not None: analysis = self._audio_stream.get_latest_analysis() + if analysis is not None and self._filter_pipeline is not None: + analysis = self._filter_pipeline.process(analysis) render_fn = renderers.get(self._visualization_mode, self._render_spectrum) t_render = time.perf_counter() @@ -361,15 +363,8 @@ class AudioColorStripStream(ColorStripStream): # ── Filter pipeline + channel selection ────────────────────────── - def _apply_filters(self, analysis): - """Apply audio filter pipeline (if any) and return (spectrum, rms). - - The filter pipeline handles channel extraction, band extraction, - gain, noise gate, etc. as configured by the ProcessedAudioSource - template chain. - """ - if self._filter_pipeline is not None: - analysis = self._filter_pipeline.process(analysis) + def _extract_spectrum_rms(self, analysis): + """Return (spectrum, rms) from an already-filtered analysis.""" return analysis.spectrum, analysis.rms # ── Spectrum Analyzer ────────────────────────────────────────── @@ -379,7 +374,7 @@ class AudioColorStripStream(ColorStripStream): buf[:] = 0 return - spectrum, _ = self._apply_filters(analysis) + spectrum, _ = self._extract_spectrum_rms(analysis) sensitivity = self.resolve("sensitivity", self._sensitivity) smoothing = self.resolve("smoothing", self._smoothing) lut = self._palette_lut @@ -430,7 +425,7 @@ class AudioColorStripStream(ColorStripStream): buf[:] = 0 return - _, ch_rms = self._apply_filters(analysis) + _, ch_rms = self._extract_spectrum_rms(analysis) sensitivity = self.resolve("sensitivity", self._sensitivity) smoothing = self.resolve("smoothing", self._smoothing) rms = ch_rms * sensitivity diff --git a/server/src/wled_controller/static/js/features/audio-sources.ts b/server/src/wled_controller/static/js/features/audio-sources.ts index 8eb4ebd..284c65a 100644 --- a/server/src/wled_controller/static/js/features/audio-sources.ts +++ b/server/src/wled_controller/static/js/features/audio-sources.ts @@ -605,10 +605,3 @@ function _renderAudioSpectrum() { } } -// ── Removed types ───────────────────────────────────────────── -// MonoAudioSource and BandExtractAudioSource have been removed. -// Channel selection is now handled by the channel_extract audio filter. -// Band filtering is now handled by the band_extract audio filter. -// These are applied via ProcessedAudioSource referencing an AudioProcessingTemplate. -// Exported stubs for backward compatibility (no-op): -export function onBandPresetChange() { /* removed */ } diff --git a/server/src/wled_controller/static/js/features/streams.ts b/server/src/wled_controller/static/js/features/streams.ts index bc5b60b..2c9f12f 100644 --- a/server/src/wled_controller/static/js/features/streams.ts +++ b/server/src/wled_controller/static/js/features/streams.ts @@ -1258,8 +1258,7 @@ export async function saveStream() { if (!name) { showToast(t('streams.error.required'), 'error'); return; } - const payload: any = { name, description: description || null, tags: _streamTagsInput ? _streamTagsInput.getValue() : [] }; - if (!streamId) payload.stream_type = streamType; + const payload: any = { name, stream_type: streamType, description: description || null, tags: _streamTagsInput ? _streamTagsInput.getValue() : [] }; if (streamType === 'raw') { payload.display_index = parseInt((document.getElementById('stream-display-index') as HTMLInputElement).value) || 0; diff --git a/server/src/wled_controller/storage/database.py b/server/src/wled_controller/storage/database.py index d0b1f7e..342a1d3 100644 --- a/server/src/wled_controller/storage/database.py +++ b/server/src/wled_controller/storage/database.py @@ -58,6 +58,7 @@ _ENTITY_TABLES = [ "home_assistant_sources", "mqtt_sources", "game_integrations", + "audio_processing_templates", ] diff --git a/server/tests/api/test_audio_processing_templates_api.py b/server/tests/api/test_audio_processing_templates_api.py new file mode 100644 index 0000000..d2ddd77 --- /dev/null +++ b/server/tests/api/test_audio_processing_templates_api.py @@ -0,0 +1,149 @@ +"""API tests for audio processing template endpoints.""" + +import pytest +from fastapi.testclient import TestClient + +from wled_controller.main import app +from wled_controller.config import get_config + +# Ensure audio filters registered +import wled_controller.core.audio.filters # noqa: F401 + +_config = get_config() +_api_key = next(iter(_config.auth.api_keys.values()), "") +AUTH = {"Authorization": f"Bearer {_api_key}"} if _api_key else {} + + +@pytest.fixture(scope="module") +def client(): + """Provide a TestClient with lifespan (startup/shutdown) properly triggered.""" + with TestClient(app) as c: + yield c + + +# Track created template IDs for cleanup +_created_ids: list[str] = [] + + +@pytest.fixture(autouse=True) +def cleanup_after_test(client): + """Clean up created templates after each test.""" + yield + for tid in list(_created_ids): + client.delete(f"/api/v1/audio-processing-templates/{tid}", headers=AUTH) + _created_ids.clear() + + +def _create(client, name: str, filters: list | None = None, **kwargs) -> dict: + """Helper: create a template and track for cleanup.""" + body = {"name": name, "filters": filters or [], **kwargs} + resp = client.post("/api/v1/audio-processing-templates", json=body, headers=AUTH) + if resp.status_code == 201: + _created_ids.append(resp.json()["id"]) + return resp + + +class TestAudioProcessingTemplateAPI: + """Test /api/v1/audio-processing-templates endpoints.""" + + def test_list(self, client): + resp = client.get("/api/v1/audio-processing-templates", headers=AUTH) + assert resp.status_code == 200 + data = resp.json() + assert "templates" in data + assert "count" in data + + def test_create(self, client): + resp = _create( + client, + "API Test Template", + filters=[{"filter_id": "gain", "options": {"factor": 2.0}}], + description="A test template", + tags=["test"], + ) + assert resp.status_code == 201 + data = resp.json() + assert data["name"] == "API Test Template" + assert data["id"].startswith("apt_") + assert len(data["filters"]) == 1 + assert data["filters"][0]["filter_id"] == "gain" + assert data["description"] == "A test template" + assert data["tags"] == ["test"] + + def test_create_invalid_filter_returns_400(self, client): + resp = client.post( + "/api/v1/audio-processing-templates", + json={ + "name": "Bad Template", + "filters": [{"filter_id": "nonexistent", "options": {}}], + }, + headers=AUTH, + ) + assert resp.status_code == 400 + + def test_get_by_id(self, client): + create_resp = _create(client, "Fetchable Template") + tid = create_resp.json()["id"] + + resp = client.get(f"/api/v1/audio-processing-templates/{tid}", headers=AUTH) + assert resp.status_code == 200 + assert resp.json()["name"] == "Fetchable Template" + + def test_get_nonexistent_returns_404(self, client): + resp = client.get("/api/v1/audio-processing-templates/apt_nonexistent", headers=AUTH) + assert resp.status_code == 404 + + def test_update(self, client): + create_resp = _create(client, "Original API Template") + tid = create_resp.json()["id"] + + resp = client.put( + f"/api/v1/audio-processing-templates/{tid}", + json={ + "name": "Updated API Template", + "filters": [{"filter_id": "inverter", "options": {}}], + }, + headers=AUTH, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "Updated API Template" + assert len(data["filters"]) == 1 + + def test_update_nonexistent_returns_404(self, client): + resp = client.put( + "/api/v1/audio-processing-templates/apt_nonexistent", + json={"name": "X"}, + headers=AUTH, + ) + assert resp.status_code == 404 + + def test_delete(self, client): + create_resp = _create(client, "To Delete API") + tid = create_resp.json()["id"] + _created_ids.remove(tid) + + resp = client.delete(f"/api/v1/audio-processing-templates/{tid}", headers=AUTH) + assert resp.status_code == 204 + + resp2 = client.get(f"/api/v1/audio-processing-templates/{tid}", headers=AUTH) + assert resp2.status_code == 404 + + def test_delete_nonexistent_returns_404(self, client): + resp = client.delete("/api/v1/audio-processing-templates/apt_nonexistent", headers=AUTH) + assert resp.status_code == 404 + + +class TestFilterRegistryAPI: + """Test /api/v1/audio-filters endpoint.""" + + def test_list_filters(self, client): + resp = client.get("/api/v1/audio-filters", headers=AUTH) + if resp.status_code == 404: + pytest.skip("Audio filters registry endpoint not implemented") + assert resp.status_code == 200 + data = resp.json() + assert "filters" in data + ids = {f["filter_id"] for f in data["filters"]} + assert "gain" in ids + assert "inverter" in ids diff --git a/server/tests/core/test_audio_filters.py b/server/tests/core/test_audio_filters.py new file mode 100644 index 0000000..614a28f --- /dev/null +++ b/server/tests/core/test_audio_filters.py @@ -0,0 +1,319 @@ +"""Tests for audio filters and the AudioFilterPipeline.""" + +import numpy as np +import pytest + +from wled_controller.core.audio.analysis import NUM_BANDS, AudioAnalysis +from wled_controller.core.audio.filters.base import AudioFilter +from wled_controller.core.audio.filters.pipeline import AudioFilterPipeline +from wled_controller.core.audio.filters.registry import AudioFilterRegistry + +# Import the package to trigger auto-registration of all built-in filters +import wled_controller.core.audio.filters # noqa: F401 + +from wled_controller.core.filters.filter_instance import FilterInstance + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_analysis( + rms: float = 0.5, + peak: float = 0.7, + spectrum: np.ndarray | None = None, + beat: bool = False, + beat_intensity: float = 0.0, + left_rms: float = 0.3, + right_rms: float = 0.6, + left_spectrum: np.ndarray | None = None, + right_spectrum: np.ndarray | None = None, +) -> AudioAnalysis: + """Build an AudioAnalysis with sensible defaults for testing.""" + if spectrum is None: + spectrum = np.linspace(0.0, 1.0, NUM_BANDS, dtype=np.float32) + if left_spectrum is None: + left_spectrum = np.full(NUM_BANDS, 0.3, dtype=np.float32) + if right_spectrum is None: + right_spectrum = np.full(NUM_BANDS, 0.6, dtype=np.float32) + return AudioAnalysis( + timestamp=1.0, + rms=rms, + peak=peak, + spectrum=spectrum, + beat=beat, + beat_intensity=beat_intensity, + left_rms=left_rms, + left_spectrum=left_spectrum, + right_rms=right_rms, + right_spectrum=right_spectrum, + ) + + +def _zero_analysis() -> AudioAnalysis: + """AudioAnalysis with all zeros.""" + return AudioAnalysis(timestamp=1.0) + + +# --------------------------------------------------------------------------- +# Registry tests +# --------------------------------------------------------------------------- + + +class TestAudioFilterRegistry: + def test_all_built_in_filters_registered(self): + expected = { + "channel_extract", + "band_extract", + "gain", + "inverter", + "peak_hold", + "noise_gate", + "envelope_follower", + "spectral_smoothing", + "compressor", + "beat_gate", + "delay", + "audio_filter_template", + } + registered = set(AudioFilterRegistry.get_all().keys()) + assert expected.issubset(registered), f"Missing: {expected - registered}" + + def test_create_instance(self): + f = AudioFilterRegistry.create_instance("gain", {"factor": 2.0}) + assert isinstance(f, AudioFilter) + assert f.options["factor"] == 2.0 + + def test_create_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown audio filter type"): + AudioFilterRegistry.create_instance("nonexistent", {}) + + def test_is_registered(self): + assert AudioFilterRegistry.is_registered("gain") + assert not AudioFilterRegistry.is_registered("nonexistent") + + +# --------------------------------------------------------------------------- +# Channel Extract filter +# --------------------------------------------------------------------------- + + +class TestChannelExtractFilter: + def test_mono_averages_channels(self): + a = _make_analysis(left_rms=0.2, right_rms=0.8) + f = AudioFilterRegistry.create_instance("channel_extract", {"channel": "mono"}) + result = f.process(a) + assert pytest.approx(result.rms, abs=1e-5) == 0.5 + + def test_left_channel(self): + a = _make_analysis(left_rms=0.2, right_rms=0.8) + f = AudioFilterRegistry.create_instance("channel_extract", {"channel": "left"}) + result = f.process(a) + assert result.rms == 0.2 + + def test_right_channel(self): + a = _make_analysis(left_rms=0.2, right_rms=0.8) + f = AudioFilterRegistry.create_instance("channel_extract", {"channel": "right"}) + result = f.process(a) + assert result.rms == 0.8 + + def test_does_not_mutate_input(self): + a = _make_analysis() + orig_rms = a.rms + f = AudioFilterRegistry.create_instance("channel_extract", {"channel": "left"}) + f.process(a) + assert a.rms == orig_rms + + +# --------------------------------------------------------------------------- +# Band Extract filter +# --------------------------------------------------------------------------- + + +class TestBandExtractFilter: + def test_bass_preset_zeroes_high_bins(self): + a = _make_analysis() + f = AudioFilterRegistry.create_instance("band_extract", {"band": "bass"}) + result = f.process(a) + # Top spectrum bins should be zeroed + assert result.spectrum[-1] == 0.0 + + def test_treble_preset_zeroes_low_bins(self): + a = _make_analysis() + f = AudioFilterRegistry.create_instance("band_extract", {"band": "treble"}) + result = f.process(a) + # Bottom spectrum bins should be zeroed + assert result.spectrum[0] == 0.0 + + def test_zero_input_stays_zero(self): + a = _zero_analysis() + f = AudioFilterRegistry.create_instance("band_extract", {"band": "bass"}) + result = f.process(a) + assert result.rms == 0.0 + assert np.all(result.spectrum == 0.0) + + +# --------------------------------------------------------------------------- +# Gain filter +# --------------------------------------------------------------------------- + + +class TestGainFilter: + def test_unity_gain_passthrough(self): + a = _make_analysis(rms=0.5) + f = AudioFilterRegistry.create_instance("gain", {"factor": 1.0}) + result = f.process(a) + # Unity gain returns the same object + assert result is a + + def test_double_gain(self): + a = _make_analysis(rms=0.3, peak=0.4) + f = AudioFilterRegistry.create_instance("gain", {"factor": 2.0}) + result = f.process(a) + assert pytest.approx(result.rms) == 0.6 + assert pytest.approx(result.peak) == 0.8 + + def test_gain_clamps_to_one(self): + a = _make_analysis(rms=0.8) + f = AudioFilterRegistry.create_instance("gain", {"factor": 5.0}) + result = f.process(a) + assert result.rms <= 1.0 + + def test_gain_clamps_spectrum(self): + a = _make_analysis() + f = AudioFilterRegistry.create_instance("gain", {"factor": 10.0}) + result = f.process(a) + assert np.all(result.spectrum <= 1.0) + assert np.all(result.spectrum >= 0.0) + + +# --------------------------------------------------------------------------- +# Inverter filter +# --------------------------------------------------------------------------- + + +class TestInverterFilter: + def test_invert_rms(self): + a = _make_analysis(rms=0.3, peak=0.7) + f = AudioFilterRegistry.create_instance("inverter", {}) + result = f.process(a) + assert pytest.approx(result.rms, abs=1e-6) == 0.7 + assert pytest.approx(result.peak, abs=1e-6) == 0.3 + + def test_invert_spectrum(self): + a = _make_analysis() + f = AudioFilterRegistry.create_instance("inverter", {"invert_spectrum": True}) + result = f.process(a) + np.testing.assert_allclose(result.spectrum, 1.0 - a.spectrum, atol=1e-6) + + def test_no_spectrum_inversion(self): + a = _make_analysis() + f = AudioFilterRegistry.create_instance("inverter", {"invert_spectrum": False}) + result = f.process(a) + np.testing.assert_array_equal(result.spectrum, a.spectrum) + + +# --------------------------------------------------------------------------- +# Peak Hold filter (stateful) +# --------------------------------------------------------------------------- + + +class TestPeakHoldFilter: + def test_is_stateful(self): + f = AudioFilterRegistry.create_instance("peak_hold", {}) + assert f.is_stateful is True + + def test_holds_peak_value(self): + f = AudioFilterRegistry.create_instance("peak_hold", {"decay_rate": 0.1}) + # First: high value + a1 = _make_analysis(rms=0.9, peak=0.9) + r1 = f.process(a1) + assert r1.rms >= 0.89 + # Second: low value — should still hold near the peak (tiny dt, minimal decay) + a2 = _make_analysis(rms=0.1, peak=0.1) + r2 = f.process(a2) + # Held value should be very close to 0.9 (only microseconds of decay) + assert r2.rms >= 0.85 + + def test_reset_clears_state(self): + f = AudioFilterRegistry.create_instance("peak_hold", {"decay_rate": 0.0}) + a = _make_analysis(rms=0.9) + f.process(a) + f.reset() + # After reset, processing a low value should return the low value + a2 = _make_analysis(rms=0.1, peak=0.1) + r2 = f.process(a2) + assert r2.rms == pytest.approx(0.1) + + +# --------------------------------------------------------------------------- +# AudioFilterPipeline +# --------------------------------------------------------------------------- + + +class TestAudioFilterPipeline: + def test_empty_pipeline_passthrough(self): + pipeline = AudioFilterPipeline([]) + assert pipeline.empty is True + a = _make_analysis() + result = pipeline.process(a) + assert result is a + + def test_single_filter(self): + pipeline = AudioFilterPipeline( + [ + FilterInstance("gain", {"factor": 2.0}), + ] + ) + assert pipeline.empty is False + a = _make_analysis(rms=0.3, peak=0.4) + result = pipeline.process(a) + assert pytest.approx(result.rms) == 0.6 + + def test_chained_filters(self): + """Gain 2x then invert: rms=0.3 -> 0.6 -> 0.4.""" + pipeline = AudioFilterPipeline( + [ + FilterInstance("gain", {"factor": 2.0}), + FilterInstance("inverter", {"invert_spectrum": False}), + ] + ) + a = _make_analysis(rms=0.3, peak=0.4) + result = pipeline.process(a) + assert pytest.approx(result.rms, abs=1e-6) == 0.4 + + def test_unknown_filter_skipped(self): + """Unknown filters are silently skipped, remaining filters still work.""" + pipeline = AudioFilterPipeline( + [ + FilterInstance("nonexistent_filter", {}), + FilterInstance("gain", {"factor": 2.0}), + ] + ) + a = _make_analysis(rms=0.3) + result = pipeline.process(a) + assert pytest.approx(result.rms) == 0.6 + + def test_reset_resets_stateful_filters(self): + pipeline = AudioFilterPipeline( + [ + FilterInstance("peak_hold", {"decay_rate": 0.0}), + ] + ) + a = _make_analysis(rms=0.9) + pipeline.process(a) + pipeline.reset() + a2 = _make_analysis(rms=0.1, peak=0.1) + result = pipeline.process(a2) + assert result.rms == pytest.approx(0.1) + + def test_close_clears_filters(self): + pipeline = AudioFilterPipeline( + [ + FilterInstance("gain", {"factor": 2.0}), + ] + ) + assert not pipeline.empty + pipeline.close() + assert pipeline.empty diff --git a/server/tests/storage/test_audio_processing_template_store.py b/server/tests/storage/test_audio_processing_template_store.py new file mode 100644 index 0000000..fb25fb1 --- /dev/null +++ b/server/tests/storage/test_audio_processing_template_store.py @@ -0,0 +1,180 @@ +"""Tests for AudioProcessingTemplateStore — CRUD, template expansion, and cycle detection.""" + +import pytest + +from wled_controller.core.filters.filter_instance import FilterInstance +from wled_controller.storage.audio_processing_template_store import AudioProcessingTemplateStore +from wled_controller.storage.database import Database + +# Ensure all built-in audio filters are registered +import wled_controller.core.audio.filters # noqa: F401 + + +@pytest.fixture +def apt_store(tmp_path): + """Provide an AudioProcessingTemplateStore backed by a temp database.""" + db = Database(tmp_path / "test.db") + store = AudioProcessingTemplateStore(db) + yield store + db.close() + + +# --------------------------------------------------------------------------- +# CRUD +# --------------------------------------------------------------------------- + + +class TestCRUD: + def test_create_and_get(self, apt_store): + t = apt_store.create_template( + name="My Template", + filters=[FilterInstance("gain", {"factor": 2.0})], + description="test desc", + tags=["audio"], + ) + assert t.id.startswith("apt_") + assert t.name == "My Template" + assert t.description == "test desc" + assert t.tags == ["audio"] + assert len(t.filters) == 1 + assert t.filters[0].filter_id == "gain" + + fetched = apt_store.get_template(t.id) + assert fetched.name == "My Template" + + def test_create_empty_filters(self, apt_store): + t = apt_store.create_template(name="Empty") + assert t.filters == [] + + def test_create_duplicate_name_raises(self, apt_store): + apt_store.create_template(name="UniqueA") + with pytest.raises(ValueError, match="already exists"): + apt_store.create_template(name="UniqueA") + + def test_create_unknown_filter_raises(self, apt_store): + with pytest.raises(ValueError, match="Unknown audio filter type"): + apt_store.create_template(name="Bad", filters=[FilterInstance("nonexistent", {})]) + + def test_get_all(self, apt_store): + apt_store.create_template(name="T1") + apt_store.create_template(name="T2") + all_templates = apt_store.get_all_templates() + names = {t.name for t in all_templates} + assert "T1" in names + assert "T2" in names + + def test_update_name_and_filters(self, apt_store): + t = apt_store.create_template(name="Original") + updated = apt_store.update_template( + t.id, + name="Renamed", + filters=[FilterInstance("gain", {"factor": 3.0})], + ) + assert updated.name == "Renamed" + assert len(updated.filters) == 1 + assert updated.filters[0].options["factor"] == 3.0 + assert updated.updated_at > t.created_at + + def test_update_nonexistent_raises(self, apt_store): + with pytest.raises(ValueError): + apt_store.update_template("apt_nonexistent", name="X") + + def test_update_unknown_filter_raises(self, apt_store): + t = apt_store.create_template(name="Valid") + with pytest.raises(ValueError, match="Unknown audio filter type"): + apt_store.update_template(t.id, filters=[FilterInstance("nonexistent", {})]) + + def test_delete(self, apt_store): + t = apt_store.create_template(name="ToDelete") + apt_store.delete_template(t.id) + with pytest.raises(ValueError): + apt_store.get_template(t.id) + + def test_delete_nonexistent_raises(self, apt_store): + with pytest.raises(ValueError): + apt_store.delete_template("apt_nonexistent") + + def test_persistence_across_reload(self, tmp_path): + """Templates survive store reconstruction (SQLite persistence).""" + db = Database(tmp_path / "persist.db") + store1 = AudioProcessingTemplateStore(db) + t = store1.create_template( + name="Persistent", + filters=[FilterInstance("gain", {"factor": 1.5})], + ) + tid = t.id + db.close() + + db2 = Database(tmp_path / "persist.db") + store2 = AudioProcessingTemplateStore(db2) + reloaded = store2.get_template(tid) + assert reloaded.name == "Persistent" + assert reloaded.filters[0].filter_id == "gain" + db2.close() + + +# --------------------------------------------------------------------------- +# Template composition (audio_filter_template) +# --------------------------------------------------------------------------- + + +class TestTemplateComposition: + def test_resolve_flat_filters(self, apt_store): + """Non-template filters pass through unchanged.""" + filters = [ + FilterInstance("gain", {"factor": 2.0}), + FilterInstance("inverter", {}), + ] + resolved = apt_store.resolve_filter_instances(filters) + assert len(resolved) == 2 + assert resolved[0].filter_id == "gain" + assert resolved[1].filter_id == "inverter" + + def test_resolve_nested_template(self, apt_store): + """audio_filter_template reference is expanded recursively.""" + inner = apt_store.create_template( + name="Inner", + filters=[FilterInstance("gain", {"factor": 3.0})], + ) + outer_filters = [ + FilterInstance("inverter", {}), + FilterInstance("audio_filter_template", {"template_id": inner.id}), + ] + resolved = apt_store.resolve_filter_instances(outer_filters) + assert len(resolved) == 2 + assert resolved[0].filter_id == "inverter" + assert resolved[1].filter_id == "gain" + + def test_resolve_missing_template_skipped(self, apt_store): + """References to nonexistent templates are silently skipped.""" + filters = [ + FilterInstance("gain", {}), + FilterInstance("audio_filter_template", {"template_id": "apt_nonexistent"}), + ] + resolved = apt_store.resolve_filter_instances(filters) + assert len(resolved) == 1 + assert resolved[0].filter_id == "gain" + + def test_resolve_cycle_detection(self, apt_store): + """Cycles in template composition are detected and broken.""" + t1 = apt_store.create_template( + name="A", + filters=[FilterInstance("gain", {})], + ) + t2 = apt_store.create_template( + name="B", + filters=[FilterInstance("audio_filter_template", {"template_id": t1.id})], + ) + # Manually create a cycle: A references B + apt_store.update_template( + t1.id, + filters=[FilterInstance("audio_filter_template", {"template_id": t2.id})], + ) + # Resolving should not infinite-loop; the cyclic reference is skipped + resolved = apt_store.resolve_filter_instances(t1.filters) + # Only gain from B's expansion of A (which itself is skipped due to cycle) + # The cycle is broken: A -> B -> A(skipped) -> gain never reached + # Actually: A has [ref to B], B has [ref to A]. Resolving A: + # - Visit A, expand B -> B has [ref to A], but A is already visited -> skip + # So result is empty. + assert len(resolved) == 0 diff --git a/server/tests/storage/test_audio_source_store.py b/server/tests/storage/test_audio_source_store.py new file mode 100644 index 0000000..1546d70 --- /dev/null +++ b/server/tests/storage/test_audio_source_store.py @@ -0,0 +1,288 @@ +"""Tests for AudioSourceStore — capture/processed source CRUD, chain resolution, cycle detection.""" + +import pytest + +from wled_controller.storage.audio_source import CaptureAudioSource, ProcessedAudioSource +from wled_controller.storage.audio_source_store import AudioSourceStore, ResolvedAudioSource +from wled_controller.storage.database import Database + +# Ensure audio filter registration for any template-related code +import wled_controller.core.audio.filters # noqa: F401 + + +@pytest.fixture +def audio_store(tmp_path): + """Provide an AudioSourceStore backed by a temp database.""" + db = Database(tmp_path / "test.db") + store = AudioSourceStore(db) + yield store + db.close() + + +# --------------------------------------------------------------------------- +# CaptureAudioSource CRUD +# --------------------------------------------------------------------------- + + +class TestCaptureSource: + def test_create_capture(self, audio_store): + s = audio_store.create_source( + name="System Audio", + source_type="capture", + device_index=0, + is_loopback=True, + ) + assert s.id.startswith("as_") + assert isinstance(s, CaptureAudioSource) + assert s.source_type == "capture" + assert s.device_index == 0 + assert s.is_loopback is True + + def test_create_capture_defaults(self, audio_store): + s = audio_store.create_source(name="Default", source_type="capture") + assert isinstance(s, CaptureAudioSource) + assert s.device_index == -1 + assert s.is_loopback is True + + def test_update_capture_device(self, audio_store): + s = audio_store.create_source(name="Mic", source_type="capture", device_index=0) + updated = audio_store.update_source(s.id, device_index=3) + assert isinstance(updated, CaptureAudioSource) + assert updated.device_index == 3 + + def test_delete_capture(self, audio_store): + s = audio_store.create_source(name="ToDelete", source_type="capture") + audio_store.delete_source(s.id) + with pytest.raises(ValueError): + audio_store.get_source(s.id) + + +# --------------------------------------------------------------------------- +# ProcessedAudioSource CRUD +# --------------------------------------------------------------------------- + + +class TestProcessedSource: + def test_create_processed(self, audio_store): + parent = audio_store.create_source(name="Parent", source_type="capture") + s = audio_store.create_source( + name="Processed", + source_type="processed", + audio_source_id=parent.id, + audio_processing_template_id="apt_test_001", + ) + assert isinstance(s, ProcessedAudioSource) + assert s.audio_source_id == parent.id + assert s.audio_processing_template_id == "apt_test_001" + + def test_create_processed_missing_parent_raises(self, audio_store): + with pytest.raises(ValueError, match="Parent audio source not found"): + audio_store.create_source( + name="Orphan", + source_type="processed", + audio_source_id="as_nonexistent", + audio_processing_template_id="apt_test_001", + ) + + def test_create_processed_no_source_id_raises(self, audio_store): + with pytest.raises(ValueError, match="audio_source_id"): + audio_store.create_source( + name="Bad", + source_type="processed", + audio_processing_template_id="apt_test_001", + ) + + def test_create_processed_no_template_raises(self, audio_store): + parent = audio_store.create_source(name="P", source_type="capture") + with pytest.raises(ValueError, match="audio_processing_template_id"): + audio_store.create_source( + name="Bad", + source_type="processed", + audio_source_id=parent.id, + ) + + def test_invalid_source_type_raises(self, audio_store): + with pytest.raises(ValueError, match="Invalid source type"): + audio_store.create_source(name="Bad", source_type="unknown") + + def test_delete_parent_with_child_raises(self, audio_store): + parent = audio_store.create_source(name="Parent", source_type="capture") + audio_store.create_source( + name="Child", + source_type="processed", + audio_source_id=parent.id, + audio_processing_template_id="apt_test", + ) + with pytest.raises(ValueError, match="referenced by"): + audio_store.delete_source(parent.id) + + def test_delete_child_then_parent(self, audio_store): + parent = audio_store.create_source(name="Parent", source_type="capture") + child = audio_store.create_source( + name="Child", + source_type="processed", + audio_source_id=parent.id, + audio_processing_template_id="apt_test", + ) + audio_store.delete_source(child.id) + audio_store.delete_source(parent.id) + assert len(audio_store.get_all_sources()) == 0 + + +# --------------------------------------------------------------------------- +# Chain resolution +# --------------------------------------------------------------------------- + + +class TestChainResolution: + def test_resolve_capture_source(self, audio_store): + s = audio_store.create_source( + name="Mic", + source_type="capture", + device_index=2, + is_loopback=False, + audio_template_id="atpl_001", + ) + resolved = audio_store.resolve_audio_source(s.id) + assert isinstance(resolved, ResolvedAudioSource) + assert resolved.device_index == 2 + assert resolved.is_loopback is False + assert resolved.audio_template_id == "atpl_001" + assert resolved.audio_processing_template_ids == [] + + def test_resolve_processed_chain(self, audio_store): + capture = audio_store.create_source(name="Capture", source_type="capture", device_index=0) + proc = audio_store.create_source( + name="Processed", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_tpl_A", + ) + resolved = audio_store.resolve_audio_source(proc.id) + assert resolved.device_index == 0 + assert resolved.audio_processing_template_ids == ["apt_tpl_A"] + + def test_resolve_deep_chain(self, audio_store): + """A -> B -> C (capture). Template IDs collected outermost first.""" + capture = audio_store.create_source(name="C", source_type="capture", device_index=1) + b = audio_store.create_source( + name="B", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_B", + ) + a = audio_store.create_source( + name="A", + source_type="processed", + audio_source_id=b.id, + audio_processing_template_id="apt_A", + ) + resolved = audio_store.resolve_audio_source(a.id) + assert resolved.device_index == 1 + assert resolved.audio_processing_template_ids == ["apt_A", "apt_B"] + + def test_resolve_nonexistent_raises(self, audio_store): + with pytest.raises(ValueError): + audio_store.resolve_audio_source("as_nonexistent") + + +# --------------------------------------------------------------------------- +# Cycle detection +# --------------------------------------------------------------------------- + + +class TestCycleDetection: + def test_update_to_self_raises(self, audio_store): + capture = audio_store.create_source(name="Cap", source_type="capture") + proc = audio_store.create_source( + name="Proc", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_t", + ) + with pytest.raises(ValueError, match="circular"): + audio_store.update_source(proc.id, audio_source_id=proc.id) + + def test_cycle_in_chain_raises(self, audio_store): + capture = audio_store.create_source(name="C", source_type="capture") + a = audio_store.create_source( + name="A", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_t", + ) + b = audio_store.create_source( + name="B", + source_type="processed", + audio_source_id=a.id, + audio_processing_template_id="apt_t", + ) + # Try to make A point to B, creating A -> B -> A cycle + with pytest.raises(ValueError, match="circular"): + audio_store.update_source(a.id, audio_source_id=b.id) + + +# --------------------------------------------------------------------------- +# Reference query helpers +# --------------------------------------------------------------------------- + + +class TestReferenceHelpers: + def test_get_sources_referencing_template(self, audio_store): + capture = audio_store.create_source(name="Cap", source_type="capture") + audio_store.create_source( + name="P1", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_shared", + ) + audio_store.create_source( + name="P2", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_shared", + ) + audio_store.create_source( + name="P3", + source_type="processed", + audio_source_id=capture.id, + audio_processing_template_id="apt_other", + ) + refs = audio_store.get_sources_referencing_template("apt_shared") + assert len(refs) == 2 + names = {r.name for r in refs} + assert names == {"P1", "P2"} + + +# --------------------------------------------------------------------------- +# Persistence +# --------------------------------------------------------------------------- + + +class TestPersistence: + def test_sources_survive_reload(self, tmp_path): + db = Database(tmp_path / "persist.db") + store = AudioSourceStore(db) + s = store.create_source(name="Persisted", source_type="capture", device_index=5) + sid = s.id + db.close() + + db2 = Database(tmp_path / "persist.db") + store2 = AudioSourceStore(db2) + reloaded = store2.get_source(sid) + assert reloaded.name == "Persisted" + assert isinstance(reloaded, CaptureAudioSource) + assert reloaded.device_index == 5 + db2.close() + + +# --------------------------------------------------------------------------- +# Name uniqueness +# --------------------------------------------------------------------------- + + +class TestNameUniqueness: + def test_duplicate_name_raises(self, audio_store): + audio_store.create_source(name="MySource", source_type="capture") + with pytest.raises(ValueError, match="already exists"): + audio_store.create_source(name="MySource", source_type="capture")