ab43578049
Add AudioFilterPipeline for chained filter execution on AudioAnalysis. Wire filter pipelines into AudioColorStripStream, AudioValueStream, and WebSocket test endpoint. Add hot-update support via ProcessorManager.refresh_audio_filter_pipelines(). Thread AudioProcessingTemplateStore through dependency injection hierarchy.
305 lines
10 KiB
Python
305 lines
10 KiB
Python
"""Audio source routes: CRUD for audio sources + real-time test WebSocket."""
|
|
|
|
import asyncio
|
|
from typing import Annotated, Optional
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
|
from starlette.websockets import WebSocket, WebSocketDisconnect
|
|
|
|
from wled_controller.api.auth import AuthRequired
|
|
from wled_controller.api.dependencies import (
|
|
fire_entity_event,
|
|
get_audio_processing_template_store,
|
|
get_audio_source_store,
|
|
get_audio_template_store,
|
|
get_color_strip_store,
|
|
get_processor_manager,
|
|
)
|
|
from wled_controller.api.schemas.audio_sources import (
|
|
AudioSourceCreate,
|
|
AudioSourceListResponse,
|
|
AudioSourceResponse,
|
|
AudioSourceUpdate,
|
|
CaptureAudioSourceResponse,
|
|
ProcessedAudioSourceResponse,
|
|
)
|
|
from wled_controller.storage.audio_source import (
|
|
AudioSource,
|
|
CaptureAudioSource,
|
|
ProcessedAudioSource,
|
|
)
|
|
from wled_controller.storage.audio_source_store import AudioSourceStore
|
|
from wled_controller.storage.color_strip_store import ColorStripStore
|
|
from wled_controller.utils import get_logger
|
|
from wled_controller.storage.base_store import EntityNotFoundError
|
|
|
|
# Module-level logger used by the route handlers and the test WebSocket below.
logger = get_logger(__name__)

# Router that all /api/v1/audio-sources endpoints in this module register on.
router = APIRouter()
|
|
|
|
|
|
def _base_response_fields(s: AudioSource) -> dict:
    """Collect the fields every audio-source response schema shares."""
    return {
        "id": s.id,
        "name": s.name,
        "description": s.description,
        "tags": s.tags,
        "created_at": s.created_at,
        "updated_at": s.updated_at,
    }


def _capture_response(s: CaptureAudioSource) -> CaptureAudioSourceResponse:
    """Render a capture source: shared fields plus device/template selection."""
    return CaptureAudioSourceResponse(
        **_base_response_fields(s),
        device_index=s.device_index,
        is_loopback=s.is_loopback,
        audio_template_id=s.audio_template_id,
    )


def _processed_response(s: ProcessedAudioSource) -> ProcessedAudioSourceResponse:
    """Render a processed source: shared fields plus its parent and processing template."""
    return ProcessedAudioSourceResponse(
        **_base_response_fields(s),
        audio_source_id=s.audio_source_id,
        audio_processing_template_id=s.audio_processing_template_id,
    )


# Exact storage class -> callable producing the matching response schema.
_RESPONSE_MAP = {
    CaptureAudioSource: _capture_response,
    ProcessedAudioSource: _processed_response,
}
|
|
|
|
|
|
def _to_response(source: AudioSource) -> AudioSourceResponse:
    """Build the API response model for *source*.

    The builder is chosen by the exact runtime class of *source* via
    ``_RESPONSE_MAP``. A class with no registered builder is rendered as a
    capture-source response, with ``getattr`` defaults standing in for the
    capture-specific fields it may lack.
    """
    try:
        build = _RESPONSE_MAP[type(source)]
    except KeyError:
        # Unregistered type: degrade to the capture-shaped response.
        return CaptureAudioSourceResponse(
            id=source.id,
            name=source.name,
            description=source.description,
            tags=source.tags,
            created_at=source.created_at,
            updated_at=source.updated_at,
            device_index=getattr(source, "device_index", -1),
            is_loopback=getattr(source, "is_loopback", True),
            audio_template_id=getattr(source, "audio_template_id", None),
        )
    return build(source)
|
|
|
|
|
|
@router.get("/api/v1/audio-sources", response_model=AudioSourceListResponse, tags=["Audio Sources"])
async def list_audio_sources(
    _auth: AuthRequired,
    source_type: Optional[str] = Query(
        None, description="Filter by source_type: capture or processed"
    ),
    store: AudioSourceStore = Depends(get_audio_source_store),
):
    """Return all audio sources, keeping only one ``source_type`` when it is given."""
    everything = store.get_all_sources()
    if source_type:
        matching = [src for src in everything if src.source_type == source_type]
    else:
        matching = everything
    return AudioSourceListResponse(
        sources=list(map(_to_response, matching)),
        count=len(matching),
    )
|
|
|
|
|
|
@router.post(
    "/api/v1/audio-sources",
    response_model=AudioSourceResponse,
    status_code=201,
    tags=["Audio Sources"],
)
async def create_audio_source(
    data: Annotated[AudioSourceCreate, Body(discriminator="source_type")],
    _auth: AuthRequired,
    store: AudioSourceStore = Depends(get_audio_source_store),
):
    """Persist a new audio source and emit a ``created`` entity event.

    The common fields are passed to the store explicitly. Every remaining
    (type-specific) field of the request body is forwarded as a keyword
    argument. ``EntityNotFoundError`` maps to HTTP 404 and ``ValueError``
    maps to HTTP 400.
    """
    common_fields = {"source_type", "name", "description", "tags"}
    try:
        created = store.create_source(
            name=data.name,
            source_type=data.source_type,
            description=data.description,
            tags=data.tags,
            **data.model_dump(exclude=common_fields),
        )
        fire_entity_event("audio_source", "created", created.id)
        return _to_response(created)
    except EntityNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
|
|
|
|
|
|
@router.get(
    "/api/v1/audio-sources/{source_id}", response_model=AudioSourceResponse, tags=["Audio Sources"]
)
async def get_audio_source(
    source_id: str,
    _auth: AuthRequired,
    store: AudioSourceStore = Depends(get_audio_source_store),
):
    """Get an audio source by ID.

    Returns HTTP 404 when the store cannot produce ``source_id``. Both
    ``EntityNotFoundError`` (the error the create/update/delete handlers in
    this module treat as "not found") and ``ValueError`` are mapped to 404.
    Only the store lookup sits inside the ``try``, so an error while building
    the response is not misreported as a missing source.
    """
    try:
        source = store.get_source(source_id)
    except (EntityNotFoundError, ValueError) as e:
        raise HTTPException(status_code=404, detail=str(e))
    return _to_response(source)
|
|
|
|
|
|
@router.put(
    "/api/v1/audio-sources/{source_id}", response_model=AudioSourceResponse, tags=["Audio Sources"]
)
async def update_audio_source(
    source_id: str,
    data: Annotated[AudioSourceUpdate, Body(discriminator="source_type")],
    _auth: AuthRequired,
    store: AudioSourceStore = Depends(get_audio_source_store),
):
    """Apply a partial update to an audio source and emit an ``updated`` event.

    Only body fields with a non-None value are forwarded to the store.
    ``source_type`` serves solely as the body discriminator and is never
    forwarded. ``EntityNotFoundError`` maps to HTTP 404 and ``ValueError``
    maps to HTTP 400.
    """
    try:
        updated = store.update_source(
            source_id=source_id,
            **data.model_dump(exclude={"source_type"}, exclude_none=True),
        )
        fire_entity_event("audio_source", "updated", source_id)
        return _to_response(updated)
    except EntityNotFoundError as err:
        raise HTTPException(status_code=404, detail=str(err))
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
|
|
|
|
|
|
@router.delete("/api/v1/audio-sources/{source_id}", status_code=204, tags=["Audio Sources"])
async def delete_audio_source(
    source_id: str,
    _auth: AuthRequired,
    store: AudioSourceStore = Depends(get_audio_source_store),
    css_store: ColorStripStore = Depends(get_color_strip_store),
):
    """Delete an audio source.

    Deletion is refused with HTTP 400 while the source is still referenced by:

    * an audio color strip source (via its ``audio_source_id``), or
    * another processed audio source built on top of it (via its
      ``audio_source_id``).

    Either reference would be left pointing at a missing source.
    ``EntityNotFoundError`` from the store maps to HTTP 404.
    """
    try:
        # Local import kept from the original layout (presumably to avoid an
        # import cycle with the color strip storage module — TODO confirm).
        from wled_controller.storage.color_strip_source import AudioColorStripSource

        for css in css_store.get_all_sources():
            if (
                isinstance(css, AudioColorStripSource)
                and getattr(css, "audio_source_id", None) == source_id
            ):
                raise ValueError(f"Cannot delete: referenced by color strip source '{css.name}'")

        # Processed sources point at a parent source through audio_source_id;
        # removing that parent would leave them unresolvable.
        for other in store.get_all_sources():
            if (
                isinstance(other, ProcessedAudioSource)
                and other.id != source_id
                and other.audio_source_id == source_id
            ):
                raise ValueError(
                    f"Cannot delete: referenced by processed audio source '{other.name}'"
                )

        store.delete_source(source_id)
        fire_entity_event("audio_source", "deleted", source_id)
    except EntityNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
# ===== REAL-TIME AUDIO TEST WEBSOCKET =====
|
|
|
|
|
|
async def _wait_for_client_disconnect(websocket: WebSocket) -> None:
    """Read incoming WebSocket messages until the client disconnects.

    The audio test handler below only sends, and only when a new analysis
    frame arrives. Without a reader, a disconnect is noticed only when a
    send fails. If the stream stops producing new frames, the handler would
    otherwise loop forever while holding the acquired capture stream.
    Any text/bytes messages from the client are ignored.
    """
    while True:
        message = await websocket.receive()
        if message["type"] == "websocket.disconnect":
            return


@router.websocket("/api/v1/audio-sources/{source_id}/test/ws")
async def test_audio_source_ws(
    websocket: WebSocket,
    source_id: str,
    token: str = Query(""),
):
    """WebSocket for real-time audio spectrum analysis. Auth via ?token=<api_key>.

    Resolves the audio source to its device and template chain, acquires a
    ManagedAudioStream (ref-counted — shares with running targets), and streams
    AudioAnalysis snapshots as JSON at ~20 Hz.

    Audio processing filters from the template chain are applied to the
    analysis before sending, so the WebSocket output matches what running
    streams see.

    Close codes used before the connection is accepted:
    4001 = invalid token, 4004 = source could not be resolved,
    4003 = audio stream could not be acquired.

    Once the stream has been acquired, it is always released and any filter
    pipeline closed. This holds on every exit path: a failed pipeline build,
    a failed accept, a send error, or a client disconnect detected while no
    new frames are arriving.
    """
    from wled_controller.api.auth import verify_ws_token
    from wled_controller.core.audio.filters.pipeline import build_pipeline_from_template_ids

    if not verify_ws_token(token):
        await websocket.close(code=4001, reason="Unauthorized")
        return

    # Resolve source → device info + processing template chain
    store = get_audio_source_store()
    template_store = get_audio_template_store()
    apt_store = get_audio_processing_template_store()
    manager = get_processor_manager()

    try:
        resolved = store.resolve_audio_source(source_id)
    except (EntityNotFoundError, ValueError) as e:
        await websocket.close(code=4004, reason=str(e))
        return

    device_index = resolved.device_index
    is_loopback = resolved.is_loopback
    audio_template_id = resolved.audio_template_id

    # Resolve capture template → engine_type + config (None = best available engine)
    engine_type = None
    engine_config = None
    if audio_template_id:
        try:
            template = template_store.get_template(audio_template_id)
        except (EntityNotFoundError, ValueError) as e:
            logger.debug("Audio template not found, falling back to best available engine: %s", e)
        else:
            engine_type = template.engine_type
            engine_config = template.engine_config

    # Acquire shared audio stream; everything after this point must release it.
    audio_mgr = manager.audio_capture_manager
    try:
        stream = audio_mgr.acquire(device_index, is_loopback, engine_type, engine_config)
    except RuntimeError as e:
        await websocket.close(code=4003, reason=str(e))
        return

    pipeline = None
    disconnect_task = None
    try:
        # Built after acquire (inside the try) so that no early return can
        # skip pipeline.close().
        if resolved.audio_processing_template_ids and apt_store:
            pipeline = build_pipeline_from_template_ids(
                resolved.audio_processing_template_ids, apt_store
            )
            if pipeline.empty:
                pipeline = None

        await websocket.accept()
        logger.info("Audio test WebSocket connected for source %s", source_id)
        disconnect_task = asyncio.create_task(_wait_for_client_disconnect(websocket))

        last_ts = 0.0
        while not disconnect_task.done():
            analysis = stream.get_latest_analysis()
            if analysis is not None and analysis.timestamp != last_ts:
                last_ts = analysis.timestamp

                # Apply filter pipeline (channel extract, band extract, gain, etc.)
                if pipeline is not None:
                    analysis = pipeline.process(analysis)

                await websocket.send_json(
                    {
                        "spectrum": analysis.spectrum.tolist(),
                        "rms": round(analysis.rms, 4),
                        "peak": round(analysis.peak, 4),
                        "beat": analysis.beat,
                        "beat_intensity": round(analysis.beat_intensity, 4),
                    }
                )

            await asyncio.sleep(0.05)
        logger.debug("Audio test WebSocket disconnected for source %s", source_id)
    except WebSocketDisconnect:
        logger.debug("Audio test WebSocket disconnected for source %s", source_id)
    except Exception:
        logger.exception("Audio test WebSocket error for %s", source_id)
    finally:
        if disconnect_task is not None:
            disconnect_task.cancel()
            # Retrieve the task's outcome so a stored exception is not reported
            # as "never retrieved".
            await asyncio.gather(disconnect_task, return_exceptions=True)
        if pipeline is not None:
            pipeline.close()
        audio_mgr.release(device_index, is_loopback, engine_type)
        logger.info("Audio test WebSocket disconnected for source %s", source_id)
|