feat: HA light target live color preview — per-entity swatches via WebSocket
Lint & Test / test (push) Successful in 1m24s
Lint & Test / test (push) Successful in 1m24s
- Cache per-entity colors in HALightTargetProcessor._update_lights()
- Broadcast colors_update to WS clients at target's update_rate
- WS endpoint: /api/v1/output-targets/{target_id}/ha-light/ws
- Frontend: connect WS when target runs, update swatch colors live
- Card shows colored boxes per mapped entity with entity name labels
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
Extracted from output_targets.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
@@ -22,7 +21,10 @@ from wled_controller.api.schemas.output_targets import (
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.core.capture.screen_capture import get_available_displays
|
||||
from wled_controller.storage.color_strip_store import ColorStripStore
|
||||
from wled_controller.storage.color_strip_source import AdvancedPictureColorStripSource, PictureColorStripSource
|
||||
from wled_controller.storage.color_strip_source import (
|
||||
AdvancedPictureColorStripSource,
|
||||
PictureColorStripSource,
|
||||
)
|
||||
from wled_controller.storage.picture_source_store import PictureSourceStore
|
||||
from wled_controller.storage.wled_output_target import WledOutputTarget
|
||||
from wled_controller.storage.output_target_store import OutputTargetStore
|
||||
@@ -35,7 +37,10 @@ router = APIRouter()
|
||||
|
||||
# ===== BULK PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/start", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
|
||||
@router.post(
|
||||
"/api/v1/output-targets/bulk/start", response_model=BulkTargetResponse, tags=["Processing"]
|
||||
)
|
||||
async def bulk_start_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
@@ -67,7 +72,9 @@ async def bulk_start_processing(
|
||||
return BulkTargetResponse(started=started, errors=errors)
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/stop", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
@router.post(
|
||||
"/api/v1/output-targets/bulk/stop", response_model=BulkTargetResponse, tags=["Processing"]
|
||||
)
|
||||
async def bulk_stop_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
@@ -93,6 +100,7 @@ async def bulk_stop_processing(
|
||||
|
||||
# ===== PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/start", tags=["Processing"])
|
||||
async def start_processing(
|
||||
target_id: str,
|
||||
@@ -146,7 +154,12 @@ async def stop_processing(
|
||||
|
||||
# ===== STATE & METRICS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/state", response_model=TargetProcessingState, tags=["Processing"])
|
||||
|
||||
@router.get(
|
||||
"/api/v1/output-targets/{target_id}/state",
|
||||
response_model=TargetProcessingState,
|
||||
tags=["Processing"],
|
||||
)
|
||||
async def get_target_state(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
@@ -164,7 +177,11 @@ async def get_target_state(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/metrics", response_model=TargetMetricsResponse, tags=["Metrics"])
|
||||
@router.get(
|
||||
"/api/v1/output-targets/{target_id}/metrics",
|
||||
response_model=TargetMetricsResponse,
|
||||
tags=["Metrics"],
|
||||
)
|
||||
async def get_target_metrics(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
@@ -192,6 +209,7 @@ async def events_ws(
|
||||
):
|
||||
"""WebSocket for real-time state change events. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
@@ -215,6 +233,7 @@ async def events_ws(
|
||||
|
||||
# ===== OVERLAY VISUALIZATION =====
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/overlay/start", tags=["Visualization"])
|
||||
async def start_target_overlay(
|
||||
target_id: str,
|
||||
@@ -247,10 +266,16 @@ async def start_target_overlay(
|
||||
if first_css_id:
|
||||
try:
|
||||
css = color_strip_store.get_source(first_css_id)
|
||||
if isinstance(css, (PictureColorStripSource, AdvancedPictureColorStripSource)) and css.calibration:
|
||||
if (
|
||||
isinstance(css, (PictureColorStripSource, AdvancedPictureColorStripSource))
|
||||
and css.calibration
|
||||
):
|
||||
calibration = css.calibration
|
||||
# Resolve the display this CSS is capturing
|
||||
from wled_controller.api.routes.color_strip_sources import _resolve_display_index
|
||||
from wled_controller.api.routes.color_strip_sources import (
|
||||
_resolve_display_index,
|
||||
)
|
||||
|
||||
ps_id = getattr(css, "picture_source_id", "") or ""
|
||||
display_index = _resolve_display_index(ps_id, picture_source_store)
|
||||
displays = get_available_displays()
|
||||
@@ -258,9 +283,13 @@ async def start_target_overlay(
|
||||
display_index = min(display_index, len(displays) - 1)
|
||||
display_info = displays[display_index]
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not pre-load CSS calibration for overlay on {target_id}: {e}")
|
||||
logger.warning(
|
||||
f"Could not pre-load CSS calibration for overlay on {target_id}: {e}"
|
||||
)
|
||||
|
||||
await manager.start_overlay(target_id, target.name, calibration=calibration, display_info=display_info)
|
||||
await manager.start_overlay(
|
||||
target_id, target.name, calibration=calibration, display_info=display_info
|
||||
)
|
||||
return {"status": "started", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
@@ -305,8 +334,55 @@ async def get_overlay_status(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# ===== HA LIGHT COLOR PREVIEW WEBSOCKET =====
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/ha-light/ws")
|
||||
async def ha_light_colors_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for live HA light entity color preview.
|
||||
|
||||
Streams: {"type": "colors_update", "colors": {entity_id: {r,g,b,hex}, ...}}
|
||||
at the target's update_rate.
|
||||
"""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
manager: ProcessorManager = get_processor_manager()
|
||||
|
||||
try:
|
||||
proc = manager._processors.get(target_id)
|
||||
if not proc or not proc.is_running:
|
||||
await websocket.close(code=4003, reason="Target not running")
|
||||
return
|
||||
except Exception as e:
|
||||
await websocket.close(code=4004, reason=str(e))
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
manager.add_ha_light_ws_client(target_id, websocket)
|
||||
while True:
|
||||
# Keep connection alive — wait for client disconnect
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
manager.remove_ha_light_ws_client(target_id, websocket)
|
||||
|
||||
|
||||
# ===== LED PREVIEW WEBSOCKET =====
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/led-preview/ws")
|
||||
async def led_preview_ws(
|
||||
websocket: WebSocket,
|
||||
@@ -315,6 +391,7 @@ async def led_preview_ws(
|
||||
):
|
||||
"""WebSocket for real-time LED strip preview. Sends binary RGB frames. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
@@ -1,528 +0,0 @@
|
||||
"""Output target routes: key colors endpoints, testing, and WebSocket streams.
|
||||
|
||||
Extracted from output_targets.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import (
|
||||
get_device_store,
|
||||
get_output_target_store,
|
||||
get_pattern_template_store,
|
||||
get_picture_source_store,
|
||||
get_pp_template_store,
|
||||
get_processor_manager,
|
||||
get_template_store,
|
||||
)
|
||||
from wled_controller.api.schemas.output_targets import (
|
||||
ExtractedColorResponse,
|
||||
KCTestRectangleResponse,
|
||||
KCTestResponse,
|
||||
KeyColorsResponse,
|
||||
)
|
||||
from wled_controller.core.capture_engines import EngineRegistry
|
||||
from wled_controller.core.filters import FilterRegistry, ImagePool
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.core.capture.screen_capture import (
|
||||
calculate_average_color,
|
||||
calculate_dominant_color,
|
||||
calculate_median_color,
|
||||
)
|
||||
from wled_controller.storage import DeviceStore
|
||||
from wled_controller.storage.pattern_template_store import PatternTemplateStore
|
||||
from wled_controller.storage.picture_source import ScreenCapturePictureSource, StaticImagePictureSource
|
||||
from wled_controller.storage.picture_source_store import PictureSourceStore
|
||||
from wled_controller.storage.template_store import TemplateStore
|
||||
from wled_controller.storage.key_colors_output_target import KeyColorsOutputTarget
|
||||
from wled_controller.storage.output_target_store import OutputTargetStore
|
||||
from wled_controller.storage.base_store import EntityNotFoundError
|
||||
from wled_controller.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ===== KEY COLORS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/colors", response_model=KeyColorsResponse, tags=["Key Colors"])
|
||||
async def get_target_colors(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get latest extracted colors for a key-colors target (polling)."""
|
||||
try:
|
||||
raw_colors = manager.get_kc_latest_colors(target_id)
|
||||
colors = {}
|
||||
for name, (r, g, b) in raw_colors.items():
|
||||
colors[name] = ExtractedColorResponse(
|
||||
r=r, g=g, b=b,
|
||||
hex=f"#{r:02x}{g:02x}{b:02x}",
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
return KeyColorsResponse(
|
||||
target_id=target_id,
|
||||
colors=colors,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/test", response_model=KCTestResponse, tags=["Key Colors"])
|
||||
async def test_kc_target(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
source_store: PictureSourceStore = Depends(get_picture_source_store),
|
||||
template_store: TemplateStore = Depends(get_template_store),
|
||||
pattern_store: PatternTemplateStore = Depends(get_pattern_template_store),
|
||||
processor_manager: ProcessorManager = Depends(get_processor_manager),
|
||||
device_store: DeviceStore = Depends(get_device_store),
|
||||
pp_template_store=Depends(get_pp_template_store),
|
||||
):
|
||||
"""Test a key-colors target: capture a frame, extract colors from each rectangle."""
|
||||
|
||||
stream = None
|
||||
try:
|
||||
# 1. Load and validate KC target
|
||||
try:
|
||||
target = target_store.get_target(target_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
raise HTTPException(status_code=400, detail="Target is not a key_colors target")
|
||||
|
||||
settings = target.settings
|
||||
|
||||
# 2. Resolve pattern template
|
||||
if not settings.pattern_template_id:
|
||||
raise HTTPException(status_code=400, detail="No pattern template configured")
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
raise HTTPException(status_code=400, detail="Pattern template has no rectangles")
|
||||
|
||||
# 3. Resolve picture source and capture a frame
|
||||
if not target.picture_source_id:
|
||||
raise HTTPException(status_code=400, detail="No picture source configured")
|
||||
|
||||
try:
|
||||
chain = source_store.resolve_stream_chain(target.picture_source_id)
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
from wled_controller.utils.image_codec import load_image_file
|
||||
|
||||
if isinstance(raw_stream, StaticImagePictureSource):
|
||||
from wled_controller.api.dependencies import get_asset_store as _get_asset_store
|
||||
|
||||
asset_store = _get_asset_store()
|
||||
image_path = asset_store.get_file_path(raw_stream.image_asset_id) if raw_stream.image_asset_id else None
|
||||
if not image_path:
|
||||
raise HTTPException(status_code=400, detail="Image asset not found or missing file")
|
||||
image = load_image_file(image_path)
|
||||
|
||||
elif isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
try:
|
||||
capture_template = template_store.get_template(raw_stream.capture_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Capture template not found: {raw_stream.capture_template_id}",
|
||||
)
|
||||
|
||||
display_index = raw_stream.display_index
|
||||
|
||||
if capture_template.engine_type not in EngineRegistry.get_available_engines():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Engine '{capture_template.engine_type}' is not available on this system",
|
||||
)
|
||||
|
||||
locked_device_id = processor_manager.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Display {display_index} is currently being captured by device '{device_name}'. "
|
||||
f"Please stop the device processing before testing.",
|
||||
)
|
||||
|
||||
stream = EngineRegistry.create_stream(
|
||||
capture_template.engine_type, display_index, capture_template.engine_config
|
||||
)
|
||||
stream.initialize()
|
||||
|
||||
screen_capture = stream.capture_frame()
|
||||
if screen_capture is None:
|
||||
raise RuntimeError("No frame captured")
|
||||
|
||||
if not isinstance(screen_capture.image, np.ndarray):
|
||||
raise ValueError("Unexpected image format from engine")
|
||||
image = screen_capture.image
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported picture source type")
|
||||
|
||||
# 3b. Apply postprocessing filters (if the picture source has a filter chain)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store:
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store.get_template(pp_id)
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: PP template {pp_id} not found, skipping")
|
||||
continue
|
||||
flat_filters = pp_template_store.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(image, image_pool)
|
||||
if result is not None:
|
||||
image = result
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: unknown filter '{fi.filter_id}', skipping")
|
||||
|
||||
# 4. Extract colors from each rectangle
|
||||
img_array = image
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append(KCTestRectangleResponse(
|
||||
name=rect.name,
|
||||
x=rect.x,
|
||||
y=rect.y,
|
||||
width=rect.width,
|
||||
height=rect.height,
|
||||
color=ExtractedColorResponse(r=r, g=g, b=b, hex=f"#{r:02x}{g:02x}{b:02x}"),
|
||||
))
|
||||
|
||||
# 5. Encode frame as base64 JPEG
|
||||
from wled_controller.utils.image_codec import encode_jpeg_data_uri
|
||||
image_data_uri = encode_jpeg_data_uri(image, quality=90)
|
||||
|
||||
return KCTestResponse(
|
||||
image=image_data_uri,
|
||||
rectangles=result_rects,
|
||||
interpolation_mode=settings.interpolation_mode,
|
||||
pattern_template_name=pattern_tmpl.name,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error("Capture error during KC target test: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
except Exception as e:
|
||||
logger.error("Failed to test KC target: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
finally:
|
||||
if stream:
|
||||
try:
|
||||
stream.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up test stream: {e}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/test/ws")
|
||||
async def test_kc_target_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
fps: int = Query(3),
|
||||
preview_width: int = Query(480),
|
||||
):
|
||||
"""WebSocket for real-time KC target test preview. Auth via ?token=<api_key>.
|
||||
|
||||
Streams JSON frames: {"type": "frame", "image": "data:image/jpeg;base64,...",
|
||||
"rectangles": [...], "pattern_template_name": "...", "interpolation_mode": "..."}
|
||||
"""
|
||||
import json as _json
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
# Load stores
|
||||
target_store_inst: OutputTargetStore = get_output_target_store()
|
||||
source_store_inst: PictureSourceStore = get_picture_source_store()
|
||||
get_template_store()
|
||||
pattern_store_inst: PatternTemplateStore = get_pattern_template_store()
|
||||
processor_manager_inst: ProcessorManager = get_processor_manager()
|
||||
device_store_inst: DeviceStore = get_device_store()
|
||||
pp_template_store_inst = get_pp_template_store()
|
||||
|
||||
# Validate target
|
||||
try:
|
||||
target = target_store_inst.get_target(target_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4004, reason=str(e))
|
||||
return
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
await websocket.close(code=4003, reason="Target is not a key_colors target")
|
||||
return
|
||||
|
||||
settings = target.settings
|
||||
|
||||
if not settings.pattern_template_id:
|
||||
await websocket.close(code=4003, reason="No pattern template configured")
|
||||
return
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store_inst.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
await websocket.close(code=4003, reason=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
return
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
await websocket.close(code=4003, reason="Pattern template has no rectangles")
|
||||
return
|
||||
|
||||
if not target.picture_source_id:
|
||||
await websocket.close(code=4003, reason="No picture source configured")
|
||||
return
|
||||
|
||||
try:
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4003, reason=str(e))
|
||||
return
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
# For screen capture sources, check display lock
|
||||
if isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
display_index = raw_stream.display_index
|
||||
locked_device_id = processor_manager_inst.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store_inst.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
await websocket.close(
|
||||
code=4003,
|
||||
reason=f"Display {display_index} is captured by '{device_name}'. Stop processing first.",
|
||||
)
|
||||
return
|
||||
|
||||
fps = max(1, min(30, fps))
|
||||
preview_width = max(120, min(1920, preview_width))
|
||||
frame_interval = 1.0 / fps
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
await websocket.accept()
|
||||
logger.info(f"KC test WS connected for {target_id} (fps={fps})")
|
||||
|
||||
# Use the shared LiveStreamManager so we share the capture stream with
|
||||
# running LED targets instead of creating a competing DXGI duplicator.
|
||||
live_stream_mgr = processor_manager_inst._live_stream_manager
|
||||
live_stream = None
|
||||
|
||||
try:
|
||||
live_stream = await asyncio.to_thread(
|
||||
live_stream_mgr.acquire, target.picture_source_id
|
||||
)
|
||||
logger.info(f"KC test WS acquired shared live stream for {target.picture_source_id}")
|
||||
|
||||
prev_frame_ref = None
|
||||
|
||||
while True:
|
||||
loop_start = time.monotonic()
|
||||
|
||||
try:
|
||||
capture = await asyncio.to_thread(live_stream.get_latest_frame)
|
||||
|
||||
if capture is None or capture.image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Skip if same frame object (no new capture yet)
|
||||
if capture is prev_frame_ref:
|
||||
await asyncio.sleep(frame_interval * 0.5)
|
||||
continue
|
||||
prev_frame_ref = capture
|
||||
|
||||
if not isinstance(capture.image, np.ndarray):
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
cur_image = capture.image
|
||||
if cur_image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Apply postprocessing (if the source chain has PP templates)
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store_inst:
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store_inst.get_template(pp_id)
|
||||
except ValueError as e:
|
||||
logger.debug("PP template %s not found during KC test: %s", pp_id, e)
|
||||
continue
|
||||
flat_filters = pp_template_store_inst.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(cur_image, image_pool)
|
||||
if result is not None:
|
||||
cur_image = result
|
||||
except ValueError as e:
|
||||
logger.debug("Filter processing error during KC test: %s", e)
|
||||
pass
|
||||
|
||||
# Extract colors
|
||||
img_array = cur_image
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append({
|
||||
"name": rect.name,
|
||||
"x": rect.x,
|
||||
"y": rect.y,
|
||||
"width": rect.width,
|
||||
"height": rect.height,
|
||||
"color": {"r": r, "g": g, "b": b, "hex": f"#{r:02x}{g:02x}{b:02x}"},
|
||||
})
|
||||
|
||||
# Encode frame as JPEG
|
||||
from wled_controller.utils.image_codec import encode_jpeg_data_uri, resize_down
|
||||
frame_to_encode = resize_down(cur_image, preview_width) if preview_width else cur_image
|
||||
frame_uri = encode_jpeg_data_uri(frame_to_encode, quality=85)
|
||||
|
||||
await websocket.send_text(_json.dumps({
|
||||
"type": "frame",
|
||||
"image": frame_uri,
|
||||
"rectangles": result_rects,
|
||||
"pattern_template_name": pattern_tmpl.name,
|
||||
"interpolation_mode": settings.interpolation_mode,
|
||||
}))
|
||||
|
||||
except (WebSocketDisconnect, Exception) as inner_e:
|
||||
if isinstance(inner_e, WebSocketDisconnect):
|
||||
raise
|
||||
logger.warning(f"KC test WS frame error for {target_id}: {inner_e}")
|
||||
|
||||
elapsed = time.monotonic() - loop_start
|
||||
sleep_time = frame_interval - elapsed
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"KC test WS disconnected for {target_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"KC test WS error for {target_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
if live_stream is not None:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
live_stream_mgr.release, target.picture_source_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Live stream release during KC test cleanup: %s", e)
|
||||
pass
|
||||
logger.info(f"KC test WS closed for {target_id}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/ws")
|
||||
async def target_colors_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time key color updates. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
|
||||
try:
|
||||
manager.add_kc_ws_client(target_id, websocket)
|
||||
except ValueError:
|
||||
await websocket.close(code=4004, reason="Target not found")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keep alive — wait for client messages (or disconnect)
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("KC live WebSocket disconnected for target %s", target_id)
|
||||
pass
|
||||
finally:
|
||||
manager.remove_kc_ws_client(target_id, websocket)
|
||||
@@ -15,7 +15,7 @@ from wled_controller.api.schemas.pattern_templates import (
|
||||
PatternTemplateUpdate,
|
||||
)
|
||||
from wled_controller.api.schemas.output_targets import KeyColorRectangleSchema
|
||||
from wled_controller.storage.key_colors_output_target import KeyColorRectangle
|
||||
from wled_controller.storage.pattern_template import KeyColorRectangle
|
||||
from wled_controller.storage.pattern_template_store import PatternTemplateStore
|
||||
from wled_controller.storage.output_target_store import OutputTargetStore
|
||||
from wled_controller.utils import get_logger
|
||||
@@ -42,7 +42,11 @@ def _pat_template_to_response(t) -> PatternTemplateResponse:
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/pattern-templates", response_model=PatternTemplateListResponse, tags=["Pattern Templates"])
|
||||
@router.get(
|
||||
"/api/v1/pattern-templates",
|
||||
response_model=PatternTemplateListResponse,
|
||||
tags=["Pattern Templates"],
|
||||
)
|
||||
async def list_pattern_templates(
|
||||
_auth: AuthRequired,
|
||||
store: PatternTemplateStore = Depends(get_pattern_template_store),
|
||||
@@ -57,7 +61,12 @@ async def list_pattern_templates(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post("/api/v1/pattern-templates", response_model=PatternTemplateResponse, tags=["Pattern Templates"], status_code=201)
|
||||
@router.post(
|
||||
"/api/v1/pattern-templates",
|
||||
response_model=PatternTemplateResponse,
|
||||
tags=["Pattern Templates"],
|
||||
status_code=201,
|
||||
)
|
||||
async def create_pattern_template(
|
||||
data: PatternTemplateCreate,
|
||||
_auth: AuthRequired,
|
||||
@@ -87,7 +96,11 @@ async def create_pattern_template(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.get("/api/v1/pattern-templates/{template_id}", response_model=PatternTemplateResponse, tags=["Pattern Templates"])
|
||||
@router.get(
|
||||
"/api/v1/pattern-templates/{template_id}",
|
||||
response_model=PatternTemplateResponse,
|
||||
tags=["Pattern Templates"],
|
||||
)
|
||||
async def get_pattern_template(
|
||||
template_id: str,
|
||||
_auth: AuthRequired,
|
||||
@@ -101,7 +114,11 @@ async def get_pattern_template(
|
||||
raise HTTPException(status_code=404, detail=f"Pattern template {template_id} not found")
|
||||
|
||||
|
||||
@router.put("/api/v1/pattern-templates/{template_id}", response_model=PatternTemplateResponse, tags=["Pattern Templates"])
|
||||
@router.put(
|
||||
"/api/v1/pattern-templates/{template_id}",
|
||||
response_model=PatternTemplateResponse,
|
||||
tags=["Pattern Templates"],
|
||||
)
|
||||
async def update_pattern_template(
|
||||
template_id: str,
|
||||
data: PatternTemplateUpdate,
|
||||
@@ -135,7 +152,9 @@ async def update_pattern_template(
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.delete("/api/v1/pattern-templates/{template_id}", status_code=204, tags=["Pattern Templates"])
|
||||
@router.delete(
|
||||
"/api/v1/pattern-templates/{template_id}", status_code=204, tags=["Pattern Templates"]
|
||||
)
|
||||
async def delete_pattern_template(
|
||||
template_id: str,
|
||||
_auth: AuthRequired,
|
||||
@@ -150,7 +169,7 @@ async def delete_pattern_template(
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Cannot delete pattern template: it is referenced by target(s): {names}. "
|
||||
"Please reassign those targets before deleting.",
|
||||
"Please reassign those targets before deleting.",
|
||||
)
|
||||
store.delete_template(template_id)
|
||||
fire_entity_event("pattern_template", "deleted", template_id)
|
||||
|
||||
@@ -18,42 +18,6 @@ class KeyColorRectangleSchema(BaseModel):
|
||||
height: float = Field(default=1.0, description="Height (0.0-1.0)", gt=0.0, le=1.0)
|
||||
|
||||
|
||||
class KeyColorsSettingsSchema(BaseModel):
|
||||
"""Settings for key colors extraction."""
|
||||
|
||||
fps: int = Field(default=10, description="Extraction rate (1-60)", ge=1, le=60)
|
||||
interpolation_mode: str = Field(
|
||||
default="average", description="Color mode (average, median, dominant)"
|
||||
)
|
||||
smoothing: float = Field(
|
||||
default=0.3, description="Temporal smoothing (0.0-1.0)", ge=0.0, le=1.0
|
||||
)
|
||||
pattern_template_id: str = Field(
|
||||
default="", description="Pattern template ID for rectangle layout"
|
||||
)
|
||||
brightness: float = Field(
|
||||
default=1.0, description="Output brightness (0.0-1.0)", ge=0.0, le=1.0
|
||||
)
|
||||
brightness_value_source_id: str = Field(default="", description="Brightness value source ID")
|
||||
|
||||
|
||||
class ExtractedColorResponse(BaseModel):
|
||||
"""A single extracted color."""
|
||||
|
||||
r: int = Field(description="Red (0-255)")
|
||||
g: int = Field(description="Green (0-255)")
|
||||
b: int = Field(description="Blue (0-255)")
|
||||
hex: str = Field(description="Hex color (#rrggbb)")
|
||||
|
||||
|
||||
class KeyColorsResponse(BaseModel):
|
||||
"""Extracted key colors for a target."""
|
||||
|
||||
target_id: str = Field(description="Target ID")
|
||||
colors: Dict[str, ExtractedColorResponse] = Field(description="Rectangle name -> color")
|
||||
timestamp: Optional[datetime] = Field(None, description="Extraction timestamp")
|
||||
|
||||
|
||||
class HALightMappingSchema(BaseModel):
|
||||
"""Maps an LED range to one HA light entity."""
|
||||
|
||||
@@ -69,7 +33,7 @@ class OutputTargetCreate(BaseModel):
|
||||
"""Request to create an output target."""
|
||||
|
||||
name: str = Field(description="Target name", min_length=1, max_length=100)
|
||||
target_type: str = Field(default="led", description="Target type (led, key_colors, ha_light)")
|
||||
target_type: str = Field(default="led", description="Target type (led, ha_light)")
|
||||
# LED target fields
|
||||
device_id: str = Field(default="", description="LED device ID")
|
||||
color_strip_source_id: str = Field(default="", description="Color strip source ID")
|
||||
@@ -101,13 +65,6 @@ class OutputTargetCreate(BaseModel):
|
||||
pattern="^(ddp|http)$",
|
||||
description="Send protocol: ddp (UDP) or http (JSON API)",
|
||||
)
|
||||
# KC target fields
|
||||
picture_source_id: str = Field(
|
||||
default="", description="Picture source ID (for key_colors targets)"
|
||||
)
|
||||
key_colors_settings: Optional[KeyColorsSettingsSchema] = Field(
|
||||
None, description="Key colors settings (for key_colors targets)"
|
||||
)
|
||||
# HA light target fields
|
||||
ha_source_id: str = Field(
|
||||
default="", description="Home Assistant source ID (for ha_light targets)"
|
||||
@@ -157,13 +114,6 @@ class OutputTargetUpdate(BaseModel):
|
||||
protocol: Optional[str] = Field(
|
||||
None, pattern="^(ddp|http)$", description="Send protocol: ddp (UDP) or http (JSON API)"
|
||||
)
|
||||
# KC target fields
|
||||
picture_source_id: Optional[str] = Field(
|
||||
None, description="Picture source ID (for key_colors targets)"
|
||||
)
|
||||
key_colors_settings: Optional[KeyColorsSettingsSchema] = Field(
|
||||
None, description="Key colors settings (for key_colors targets)"
|
||||
)
|
||||
# HA light target fields
|
||||
ha_source_id: Optional[str] = Field(
|
||||
None, description="Home Assistant source ID (for ha_light targets)"
|
||||
@@ -206,11 +156,6 @@ class OutputTargetResponse(BaseModel):
|
||||
default=False, description="Auto-reduce FPS when device is unresponsive"
|
||||
)
|
||||
protocol: str = Field(default="ddp", description="Send protocol (ddp or http)")
|
||||
# KC target fields
|
||||
picture_source_id: str = Field(default="", description="Picture source ID (key_colors)")
|
||||
key_colors_settings: Optional[KeyColorsSettingsSchema] = Field(
|
||||
None, description="Key colors settings"
|
||||
)
|
||||
# HA light target fields
|
||||
ha_source_id: str = Field(default="", description="Home Assistant source ID (ha_light)")
|
||||
ha_light_mappings: Optional[List[HALightMappingSchema]] = Field(
|
||||
@@ -263,12 +208,6 @@ class TargetProcessingState(BaseModel):
|
||||
timing_audio_render_ms: Optional[float] = Field(
|
||||
None, description="Audio visualization render time (ms)"
|
||||
)
|
||||
timing_calc_colors_ms: Optional[float] = Field(
|
||||
None, description="Color calculation time (ms, KC targets)"
|
||||
)
|
||||
timing_broadcast_ms: Optional[float] = Field(
|
||||
None, description="WebSocket broadcast time (ms, KC targets)"
|
||||
)
|
||||
display_index: Optional[int] = Field(None, description="Current display index")
|
||||
overlay_active: bool = Field(
|
||||
default=False, description="Whether visualization overlay is active"
|
||||
@@ -328,25 +267,3 @@ class BulkTargetResponse(BaseModel):
|
||||
errors: Dict[str, str] = Field(
|
||||
default_factory=dict, description="Map of target ID to error message for failures"
|
||||
)
|
||||
|
||||
|
||||
class KCTestRectangleResponse(BaseModel):
|
||||
"""A rectangle with its extracted color from a KC test."""
|
||||
|
||||
name: str = Field(description="Rectangle name")
|
||||
x: float = Field(description="Left edge (0.0-1.0)")
|
||||
y: float = Field(description="Top edge (0.0-1.0)")
|
||||
width: float = Field(description="Width (0.0-1.0)")
|
||||
height: float = Field(description="Height (0.0-1.0)")
|
||||
color: ExtractedColorResponse = Field(description="Extracted color for this rectangle")
|
||||
|
||||
|
||||
class KCTestResponse(BaseModel):
|
||||
"""Response from testing a KC target."""
|
||||
|
||||
image: str = Field(description="Base64 data URI of the captured frame")
|
||||
rectangles: List[KCTestRectangleResponse] = Field(
|
||||
description="Rectangles with extracted colors"
|
||||
)
|
||||
interpolation_mode: str = Field(description="Color extraction mode used")
|
||||
pattern_template_name: str = Field(description="Pattern template name")
|
||||
|
||||
Reference in New Issue
Block a user