Files
wled-screen-controller-mixed/server/src/wled_controller/api/routes/postprocessing.py
alexei.dolgolyov 90acae5207 Fix test endpoints reporting pre-filter image dimensions
- WebSocket test: move w,h capture to after PP filter application
  so downscaler effect is reflected in reported resolution
- HTTP test: read actual thumbnail dimensions from filtered image
  instead of using pre-computed values

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-01 01:10:23 +03:00

435 lines
16 KiB
Python

"""Postprocessing template routes."""
import base64
import io
import time
import httpx
import numpy as np
from PIL import Image
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
from wled_controller.api.auth import AuthRequired
from wled_controller.api.dependencies import (
get_picture_source_store,
get_pp_template_store,
get_template_store,
)
from wled_controller.api.schemas.common import (
CaptureImage,
PerformanceMetrics,
TemplateTestResponse,
)
from wled_controller.api.schemas.filters import FilterInstanceSchema
from wled_controller.api.schemas.postprocessing import (
PostprocessingTemplateCreate,
PostprocessingTemplateListResponse,
PostprocessingTemplateResponse,
PostprocessingTemplateUpdate,
PPTemplateTestRequest,
)
from wled_controller.core.capture_engines import EngineRegistry
from wled_controller.core.filters import FilterRegistry, FilterInstance, ImagePool
from wled_controller.storage.template_store import TemplateStore
from wled_controller.storage.postprocessing_template_store import PostprocessingTemplateStore
from wled_controller.storage.picture_source_store import PictureSourceStore
from wled_controller.storage.picture_source import ScreenCapturePictureSource, StaticImagePictureSource
from wled_controller.utils import get_logger
# Module-level logger used by every route in this file.
logger = get_logger(__name__)
# Router that the endpoint decorators in this module register on; presumably
# mounted on the FastAPI app elsewhere (not visible here).
router = APIRouter()
def _pp_template_to_response(t) -> PostprocessingTemplateResponse:
    """Build the API response model for a stored postprocessing template.

    Every entry in ``t.filters`` is mapped to a ``FilterInstanceSchema`` that
    carries its ``filter_id`` and ``options``. The remaining fields (id, name,
    timestamps, description) are copied across unchanged.
    """
    filter_schemas = [
        FilterInstanceSchema(filter_id=item.filter_id, options=item.options)
        for item in t.filters
    ]
    return PostprocessingTemplateResponse(
        id=t.id,
        name=t.name,
        filters=filter_schemas,
        created_at=t.created_at,
        updated_at=t.updated_at,
        description=t.description,
    )
@router.get("/api/v1/postprocessing-templates", response_model=PostprocessingTemplateListResponse, tags=["Postprocessing Templates"])
async def list_pp_templates(
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Return every postprocessing template together with the total count.

    Any exception raised while loading or converting templates is logged and
    reported to the client as HTTP 500.
    """
    try:
        payload = list(map(_pp_template_to_response, store.get_all_templates()))
        return PostprocessingTemplateListResponse(templates=payload, count=len(payload))
    except Exception as exc:
        logger.error(f"Failed to list postprocessing templates: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/api/v1/postprocessing-templates", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"], status_code=201)
async def create_pp_template(
    data: PostprocessingTemplateCreate,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Create and persist a new postprocessing template.

    Each filter in the request body is converted to a ``FilterInstance``
    before being handed to the store.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised; 500 (after
            logging) for any other exception.
    """
    try:
        created = store.create_template(
            name=data.name,
            filters=[FilterInstance(spec.filter_id, spec.options) for spec in data.filters],
            description=data.description,
        )
        return _pp_template_to_response(created)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        logger.error(f"Failed to create postprocessing template: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/api/v1/postprocessing-templates/{template_id}", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"])
async def get_pp_template(
    template_id: str,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Fetch a single postprocessing template by its ID.

    Any ``ValueError`` raised during the lookup or the conversion is reported
    as HTTP 404.
    """
    try:
        return _pp_template_to_response(store.get_template(template_id))
    except ValueError:
        raise HTTPException(status_code=404, detail=f"Postprocessing template {template_id} not found")
@router.put("/api/v1/postprocessing-templates/{template_id}", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"])
async def update_pp_template(
    template_id: str,
    data: PostprocessingTemplateUpdate,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Update an existing postprocessing template.

    When ``data.filters`` is ``None`` it is forwarded to the store as ``None``
    (presumably meaning "leave filters unchanged"; confirm in the store).
    Otherwise every entry is converted to a ``FilterInstance``.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised; 500 (after
            logging) for any other exception.
    """
    try:
        new_filters = None
        if data.filters is not None:
            new_filters = [FilterInstance(spec.filter_id, spec.options) for spec in data.filters]
        updated = store.update_template(
            template_id=template_id,
            name=data.name,
            filters=new_filters,
            description=data.description,
        )
        return _pp_template_to_response(updated)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        logger.error(f"Failed to update postprocessing template: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.delete("/api/v1/postprocessing-templates/{template_id}", status_code=204, tags=["Postprocessing Templates"])
async def delete_pp_template(
    template_id: str,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
    stream_store: PictureSourceStore = Depends(get_picture_source_store),
):
    """Delete a postprocessing template that no picture source references.

    Raises:
        HTTPException: 409 listing the referencing picture sources when the
            template is still in use; 400 on ``ValueError``; 500 (after
            logging) for any other exception.
    """
    try:
        referencing = store.get_sources_referencing(template_id, stream_store)
        # Refuse to delete while picture sources still point at this template.
        if referencing:
            joined = ", ".join(referencing)
            raise HTTPException(
                status_code=409,
                detail=f"Cannot delete postprocessing template: it is referenced by picture source(s): {joined}. "
                "Please reassign those streams before deleting.",
            )
        store.delete_template(template_id)
    except HTTPException:
        # Pass the 409 above through untouched instead of turning it into a 500.
        raise
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    except Exception as exc:
        logger.error(f"Failed to delete postprocessing template: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
def _apply_filter_chain(img, flat_filters, pool):
    """Run *img* through every filter in *flat_filters*, in order.

    A fresh filter object is created per entry via
    ``FilterRegistry.create_instance``. A filter that returns ``None`` leaves
    the current array unchanged.

    Args:
        img: PIL image to process.
        flat_filters: resolved filter instances exposing ``filter_id`` / ``options``.
        pool: ``ImagePool`` forwarded to each filter's ``process_image``.

    Returns:
        A new PIL image built from the final array.
    """
    arr = np.array(img)
    for fi in flat_filters:
        f = FilterRegistry.create_instance(fi.filter_id, fi.options)
        result = f.process_image(arr, pool)
        if result is not None:
            arr = result
    return Image.fromarray(arr)


def _encode_jpeg_data_uri(img, quality: int) -> str:
    """Encode *img* as JPEG at *quality* and return it as a ``data:`` URI."""
    buffer = io.BytesIO()
    img.save(buffer, format='JPEG', quality=quality)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"


def _capture_screen_frames(engine_type, display_index, engine_config, duration):
    """Capture frames for *duration* seconds and return the last one as a PIL image.

    This call is blocking; the route runs it in a worker thread. The engine is
    created, initialized, used and cleaned up entirely inside this call, so
    every engine call happens on the same thread. That matches the
    WebSocket test, which builds its engine inside the capture thread.

    Returns:
        ``(pil_image, frame_count, total_capture_time, actual_duration)``.
        Times are in seconds, and ``total_capture_time`` sums only successful
        ``capture_frame`` calls.

    Raises:
        RuntimeError: no frame was captured during the window.
        ValueError: the engine returned a non-``np.ndarray`` image.
    """
    stream = EngineRegistry.create_stream(engine_type, display_index, engine_config)
    try:
        stream.initialize()
        frame_count = 0
        total_capture_time = 0.0
        last_frame = None
        start_time = time.perf_counter()
        end_time = start_time + duration
        while time.perf_counter() < end_time:
            capture_start = time.perf_counter()
            screen_capture = stream.capture_frame()
            capture_elapsed = time.perf_counter() - capture_start
            if screen_capture is None:
                continue
            total_capture_time += capture_elapsed
            frame_count += 1
            last_frame = screen_capture
        actual_duration = time.perf_counter() - start_time
        if last_frame is None:
            raise RuntimeError("No frames captured during test")
        if not isinstance(last_frame.image, np.ndarray):
            raise ValueError("Unexpected image format from engine")
        # Copy defensively so the returned image can never alias engine-owned
        # memory that cleanup() below might release.
        pil_image = Image.fromarray(last_frame.image.copy())
    finally:
        try:
            stream.cleanup()
        except Exception as cleanup_err:
            # Best-effort cleanup, but leave a trace instead of silently ignoring it.
            logger.warning(f"Capture stream cleanup failed: {cleanup_err}")
    return pil_image, frame_count, total_capture_time, actual_duration


def _render_capture_image(pil_image, flat_filters):
    """Build the ``CaptureImage`` payload: a filtered thumbnail and a filtered full image.

    The thumbnail fits a 640px-wide box that keeps the source aspect ratio
    (PIL ``thumbnail`` only ever shrinks). When *flat_filters* is non-empty,
    the filters are applied to the thumbnail and to the full image. The
    reported dimensions are those of the filtered images.
    """
    thumbnail_width = 640
    aspect_ratio = pil_image.height / pil_image.width
    # Clamp to 1: a zero height makes PIL's thumbnail() divide by zero for very wide images.
    thumbnail_height = max(1, int(thumbnail_width * aspect_ratio))
    thumbnail = pil_image.copy()
    thumbnail.thumbnail((thumbnail_width, thumbnail_height), Image.Resampling.LANCZOS)
    if flat_filters:
        pool = ImagePool()
        thumbnail = _apply_filter_chain(thumbnail, flat_filters, pool)
        pil_image = _apply_filter_chain(pil_image, flat_filters, pool)
    width, height = pil_image.size
    thumb_w, thumb_h = thumbnail.size
    return CaptureImage(
        image=_encode_jpeg_data_uri(thumbnail, 85),
        full_image=_encode_jpeg_data_uri(pil_image, 90),
        width=width,
        height=height,
        thumbnail_width=thumb_w,
        thumbnail_height=thumb_h,
    )


@router.post("/api/v1/postprocessing-templates/{template_id}/test", response_model=TemplateTestResponse, tags=["Postprocessing Templates"])
async def test_pp_template(
    template_id: str,
    test_request: PPTemplateTestRequest,
    _auth: AuthRequired,
    pp_store: PostprocessingTemplateStore = Depends(get_pp_template_store),
    stream_store: PictureSourceStore = Depends(get_picture_source_store),
    template_store: TemplateStore = Depends(get_template_store),
):
    """Test a postprocessing template by capturing from a source stream and applying filters.

    The source stream chain is resolved to its raw stream:

    * a static image source is loaded once, from an HTTP(S) URL or a local path;
    * a screen-capture source captures frames for ``test_request.capture_duration``
      seconds and uses the last frame.

    The template's filters, with filter-template references expanded, are then
    applied. The result is returned as JPEG data URIs (thumbnail and full
    image) together with capture performance metrics.

    Blocking capture and CPU-heavy image work run through ``run_in_threadpool``
    so the event loop stays responsive during the test.

    Raises:
        HTTPException: 404 for an unknown PP template. 400 for stream-chain
            errors, a missing capture template or image file, an unavailable
            engine, an unsupported stream type, or any ``ValueError``. 500 for
            anything else.
    """
    from pathlib import Path

    from fastapi.concurrency import run_in_threadpool

    try:
        # Get the PP template
        try:
            pp_template = pp_store.get_template(template_id)
        except ValueError as e:
            raise HTTPException(status_code=404, detail=str(e))
        # Resolve source stream chain to get the raw stream
        try:
            chain = stream_store.resolve_stream_chain(test_request.source_stream_id)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        raw_stream = chain["raw_stream"]

        if isinstance(raw_stream, StaticImagePictureSource):
            # Static image: load directly
            source = raw_stream.image_source
            start_time = time.perf_counter()
            if source.startswith(("http://", "https://")):
                async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
                    resp = await client.get(source)
                    resp.raise_for_status()
                    pil_image = Image.open(io.BytesIO(resp.content)).convert("RGB")
            else:
                path = Path(source)
                if not path.exists():
                    raise HTTPException(status_code=400, detail=f"Image file not found: {source}")
                pil_image = Image.open(path).convert("RGB")
            actual_duration = time.perf_counter() - start_time
            frame_count = 1
            total_capture_time = actual_duration
        elif isinstance(raw_stream, ScreenCapturePictureSource):
            # Screen capture stream: use engine
            try:
                capture_template = template_store.get_template(raw_stream.capture_template_id)
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Capture template not found: {raw_stream.capture_template_id}",
                )
            if capture_template.engine_type not in EngineRegistry.get_available_engines():
                raise HTTPException(
                    status_code=400,
                    detail=f"Engine '{capture_template.engine_type}' is not available on this system",
                )
            logger.info(f"Starting {test_request.capture_duration}s PP template test for {template_id} using stream {test_request.source_stream_id}")
            # The capture loop blocks for the whole duration, so it must not run on the event loop.
            pil_image, frame_count, total_capture_time, actual_duration = await run_in_threadpool(
                _capture_screen_frames,
                capture_template.engine_type,
                raw_stream.display_index,
                capture_template.engine_config,
                test_request.capture_duration,
            )
        else:
            # Previously fell through with pil_image unbound, which surfaced as an opaque 500.
            raise HTTPException(status_code=400, detail="Unsupported stream type for test")

        # Apply postprocessing filters (expand filter_template references)
        flat_filters = pp_store.resolve_filter_instances(pp_template.filters)
        capture_image = await run_in_threadpool(_render_capture_image, pil_image, flat_filters)

        actual_fps = frame_count / actual_duration if actual_duration > 0 else 0
        avg_capture_time_ms = (total_capture_time / frame_count * 1000) if frame_count > 0 else 0
        return TemplateTestResponse(
            full_capture=capture_image,
            border_extraction=None,
            performance=PerformanceMetrics(
                capture_duration_s=actual_duration,
                frame_count=frame_count,
                actual_fps=actual_fps,
                avg_capture_time_ms=avg_capture_time_ms,
            ),
        )
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Postprocessing template test failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ===== REAL-TIME PP TEMPLATE TEST WEBSOCKET =====
@router.websocket("/api/v1/postprocessing-templates/{template_id}/test/ws")
async def test_pp_template_ws(
    websocket: WebSocket,
    template_id: str,
    token: str = Query(""),
    duration: float = Query(5.0),
    source_stream_id: str = Query(""),
    preview_width: int = Query(0),
):
    """WebSocket for PP template test with intermediate frame previews.

    All validation happens before the socket is accepted. Each failure closes
    the socket with one of these codes:

    * 4001 - ``token`` rejected by ``authenticate_ws_token``;
    * 4003 - missing ``source_stream_id``, a static-image or unsupported stream
      type, or a capture engine that is not available;
    * 4004 - a ``ValueError`` while looking up the PP template, stream chain or
      capture template, or while resolving the template's filters.

    After acceptance, capturing and previewing are delegated to
    ``stream_capture_test``; its message protocol is defined in
    ``_test_helpers`` and is not visible here. A ``preview_width`` of 0 is
    passed on as ``None``.
    """
    from wled_controller.api.routes._test_helpers import (
        authenticate_ws_token,
        stream_capture_test,
    )
    from wled_controller.api.dependencies import (
        get_picture_source_store as _get_ps_store,
        get_template_store as _get_t_store,
        get_pp_template_store as _get_pp_store,
    )

    if not authenticate_ws_token(token):
        await websocket.close(code=4001, reason="Unauthorized")
        return
    if not source_stream_id:
        await websocket.close(code=4003, reason="source_stream_id is required")
        return

    pp_store = _get_pp_store()
    stream_store = _get_ps_store()
    template_store = _get_t_store()

    # Get PP template
    try:
        pp_template = pp_store.get_template(template_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    # Resolve source stream chain
    try:
        chain = stream_store.resolve_stream_chain(source_stream_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return
    raw_stream = chain["raw_stream"]
    if isinstance(raw_stream, StaticImagePictureSource):
        await websocket.close(code=4003, reason="Static image streams don't support live test")
        return
    if not isinstance(raw_stream, ScreenCapturePictureSource):
        await websocket.close(code=4003, reason="Unsupported stream type for live test")
        return

    # Create capture engine
    try:
        capture_template = template_store.get_template(raw_stream.capture_template_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return
    if capture_template.engine_type not in EngineRegistry.get_available_engines():
        await websocket.close(code=4003, reason=f"Engine '{capture_template.engine_type}' not available")
        return

    # Resolve PP filters (expand filter_template references). Handle failures
    # like the other lookups; previously a ValueError here escaped unhandled
    # before the socket was accepted. The HTTP test maps the same error to 400.
    try:
        pp_filters = pp_store.resolve_filter_instances(pp_template.filters) or None
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    # Engine factory — creates + initializes engine inside the capture thread
    # to avoid thread-affinity issues (e.g. MSS uses thread-local state)
    _engine_type = capture_template.engine_type
    _display_index = raw_stream.display_index
    _engine_config = capture_template.engine_config

    def engine_factory():
        s = EngineRegistry.create_stream(_engine_type, _display_index, _engine_config)
        s.initialize()
        return s

    await websocket.accept()
    logger.info(f"PP template test WS connected for {template_id} ({duration}s)")
    try:
        await stream_capture_test(
            websocket, engine_factory, duration,
            pp_filters=pp_filters,
            preview_width=preview_width or None,
        )
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"PP template test WS error for {template_id}: {e}")
    finally:
        logger.info(f"PP template test WS disconnected for {template_id}")