"""Postprocessing template routes."""

import asyncio
import base64
import io
import time
from pathlib import Path

import httpx
import numpy as np
from PIL import Image
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect

from wled_controller.api.auth import AuthRequired
from wled_controller.api.dependencies import (
    get_picture_source_store,
    get_pp_template_store,
    get_template_store,
)
from wled_controller.api.schemas.common import (
    CaptureImage,
    PerformanceMetrics,
    TemplateTestResponse,
)
from wled_controller.api.schemas.filters import FilterInstanceSchema
from wled_controller.api.schemas.postprocessing import (
    PostprocessingTemplateCreate,
    PostprocessingTemplateListResponse,
    PostprocessingTemplateResponse,
    PostprocessingTemplateUpdate,
    PPTemplateTestRequest,
)
from wled_controller.core.capture_engines import EngineRegistry
from wled_controller.core.filters import FilterRegistry, FilterInstance, ImagePool
from wled_controller.storage.template_store import TemplateStore
from wled_controller.storage.postprocessing_template_store import PostprocessingTemplateStore
from wled_controller.storage.picture_source_store import PictureSourceStore
from wled_controller.storage.picture_source import ScreenCapturePictureSource, StaticImagePictureSource
from wled_controller.utils import get_logger

logger = get_logger(__name__)

router = APIRouter()

# Width (px) of the preview thumbnail returned by the test endpoint.
_THUMBNAIL_WIDTH = 640


def _pp_template_to_response(t) -> PostprocessingTemplateResponse:
    """Convert a PostprocessingTemplate to its API response."""
    return PostprocessingTemplateResponse(
        id=t.id,
        name=t.name,
        filters=[FilterInstanceSchema(filter_id=f.filter_id, options=f.options) for f in t.filters],
        created_at=t.created_at,
        updated_at=t.updated_at,
        description=t.description,
    )


@router.get("/api/v1/postprocessing-templates", response_model=PostprocessingTemplateListResponse, tags=["Postprocessing Templates"])
async def list_pp_templates(
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """List all postprocessing templates."""
    try:
        templates = store.get_all_templates()
        responses = [_pp_template_to_response(t) for t in templates]
        return PostprocessingTemplateListResponse(templates=responses, count=len(responses))
    except Exception as e:
        logger.error(f"Failed to list postprocessing templates: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/api/v1/postprocessing-templates", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"], status_code=201)
async def create_pp_template(
    data: PostprocessingTemplateCreate,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Create a new postprocessing template."""
    # ValueError from the store is treated as a client error (400).
    try:
        filters = [FilterInstance(f.filter_id, f.options) for f in data.filters]
        template = store.create_template(
            name=data.name,
            filters=filters,
            description=data.description,
        )
        return _pp_template_to_response(template)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to create postprocessing template: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/api/v1/postprocessing-templates/{template_id}", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"])
async def get_pp_template(
    template_id: str,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Get postprocessing template by ID."""
    # The store signals a missing template with ValueError.
    try:
        template = store.get_template(template_id)
        return _pp_template_to_response(template)
    except ValueError:
        raise HTTPException(status_code=404, detail=f"Postprocessing template {template_id} not found")


@router.put("/api/v1/postprocessing-templates/{template_id}", response_model=PostprocessingTemplateResponse, tags=["Postprocessing Templates"])
async def update_pp_template(
    template_id: str,
    data: PostprocessingTemplateUpdate,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
):
    """Update a postprocessing template."""
    try:
        # filters=None means "leave the filter list unchanged".
        filters = [FilterInstance(f.filter_id, f.options) for f in data.filters] if data.filters is not None else None
        template = store.update_template(
            template_id=template_id,
            name=data.name,
            filters=filters,
            description=data.description,
        )
        return _pp_template_to_response(template)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to update postprocessing template: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@router.delete("/api/v1/postprocessing-templates/{template_id}", status_code=204, tags=["Postprocessing Templates"])
async def delete_pp_template(
    template_id: str,
    _auth: AuthRequired,
    store: PostprocessingTemplateStore = Depends(get_pp_template_store),
    stream_store: PictureSourceStore = Depends(get_picture_source_store),
):
    """Delete a postprocessing template."""
    try:
        # Check if any picture source references this template
        source_names = store.get_sources_referencing(template_id, stream_store)
        if source_names:
            names = ", ".join(source_names)
            raise HTTPException(
                status_code=409,
                detail=f"Cannot delete postprocessing template: it is referenced by picture source(s): {names}. "
                "Please reassign those streams before deleting.",
            )
        store.delete_template(template_id)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Failed to delete postprocessing template: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def _load_static_image(source: str) -> Image.Image:
    """Load a static image source as an RGB PIL image.

    Args:
        source: An ``http://`` / ``https://`` URL or a local filesystem path.

    Returns:
        The image converted to RGB.

    Raises:
        HTTPException: 400 if a local path does not exist.
        httpx.HTTPStatusError: if a URL fetch returns an error status.
    """
    if source.startswith(("http://", "https://")):
        async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
            resp = await client.get(source)
            resp.raise_for_status()
        return Image.open(io.BytesIO(resp.content)).convert("RGB")

    path = Path(source)
    if not path.exists():
        raise HTTPException(status_code=400, detail=f"Image file not found: {source}")
    # Context manager guarantees the file handle is released.
    with Image.open(path) as img:
        return img.convert("RGB")


def _capture_screen_frames(engine_type, display_index, engine_config, duration):
    """Capture frames from a screen engine for ``duration`` seconds (blocking).

    Create, initialize, capture and cleanup all run on the calling thread.
    This keeps engines that hold thread-local state working when this function
    runs in a worker thread.

    Returns:
        Tuple ``(last_frame, frame_count, total_capture_time, actual_duration)``.
        ``last_frame`` is ``None`` when no frame was captured.
    """
    stream = EngineRegistry.create_stream(engine_type, display_index, engine_config)
    try:
        stream.initialize()

        frame_count = 0
        total_capture_time = 0.0
        last_frame = None
        start_time = time.perf_counter()
        end_time = start_time + duration

        while time.perf_counter() < end_time:
            capture_start = time.perf_counter()
            screen_capture = stream.capture_frame()
            capture_elapsed = time.perf_counter() - capture_start
            # Engines may return None when no new frame is available; skip it.
            if screen_capture is None:
                continue
            total_capture_time += capture_elapsed
            frame_count += 1
            last_frame = screen_capture

        actual_duration = time.perf_counter() - start_time
        return last_frame, frame_count, total_capture_time, actual_duration
    finally:
        # Cleanup is best-effort, but log failures instead of hiding them.
        try:
            stream.cleanup()
        except Exception as e:
            logger.warning(f"Failed to clean up capture stream: {e}")


def _make_thumbnail(image: Image.Image, width: int = _THUMBNAIL_WIDTH) -> Image.Image:
    """Return a copy of ``image`` shrunk to ``width`` px wide, keeping the aspect ratio.

    ``Image.thumbnail`` never upscales, so smaller images keep their size.
    """
    thumbnail_height = int(width * (image.height / image.width))
    thumbnail = image.copy()
    thumbnail.thumbnail((width, thumbnail_height), Image.Resampling.LANCZOS)
    return thumbnail


def _apply_filters(image: Image.Image, filters, pool) -> Image.Image:
    """Run ``image`` through ``filters`` in order and return the result.

    A fresh filter instance is created per call. A filter that returns
    ``None`` leaves the current array unchanged.
    """
    arr = np.array(image)
    for fi in filters:
        result = FilterRegistry.create_instance(fi.filter_id, fi.options).process_image(arr, pool)
        if result is not None:
            arr = result
    return Image.fromarray(arr)


def _encode_jpeg_data_uri(image: Image.Image, quality: int) -> str:
    """Encode ``image`` as JPEG at ``quality`` and return a base64 ``data:`` URI."""
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG', quality=quality)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/jpeg;base64,{encoded}"


@router.post("/api/v1/postprocessing-templates/{template_id}/test", response_model=TemplateTestResponse, tags=["Postprocessing Templates"])
async def test_pp_template(
    template_id: str,
    test_request: PPTemplateTestRequest,
    _auth: AuthRequired,
    pp_store: PostprocessingTemplateStore = Depends(get_pp_template_store),
    stream_store: PictureSourceStore = Depends(get_picture_source_store),
    template_store: TemplateStore = Depends(get_template_store),
):
    """Test a postprocessing template by capturing from a source stream and applying filters."""
    try:
        # Get the PP template
        try:
            pp_template = pp_store.get_template(template_id)
        except ValueError as e:
            raise HTTPException(status_code=404, detail=str(e))

        # Resolve source stream chain to get the raw stream
        try:
            chain = stream_store.resolve_stream_chain(test_request.source_stream_id)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))

        raw_stream = chain["raw_stream"]

        if isinstance(raw_stream, StaticImagePictureSource):
            # Static image: load directly
            start_time = time.perf_counter()
            pil_image = await _load_static_image(raw_stream.image_source)
            actual_duration = time.perf_counter() - start_time
            frame_count = 1
            total_capture_time = actual_duration

        elif isinstance(raw_stream, ScreenCapturePictureSource):
            # Screen capture stream: use engine
            try:
                capture_template = template_store.get_template(raw_stream.capture_template_id)
            except ValueError:
                raise HTTPException(
                    status_code=400,
                    detail=f"Capture template not found: {raw_stream.capture_template_id}",
                )

            if capture_template.engine_type not in EngineRegistry.get_available_engines():
                raise HTTPException(
                    status_code=400,
                    detail=f"Engine '{capture_template.engine_type}' is not available on this system",
                )

            logger.info(f"Starting {test_request.capture_duration}s PP template test for {template_id} using stream {test_request.source_stream_id}")

            # The capture loop is synchronous and lasts capture_duration seconds.
            # Run it in a worker thread so it does not block the event loop.
            last_frame, frame_count, total_capture_time, actual_duration = await asyncio.to_thread(
                _capture_screen_frames,
                capture_template.engine_type,
                raw_stream.display_index,
                capture_template.engine_config,
                test_request.capture_duration,
            )

            if last_frame is None:
                raise RuntimeError("No frames captured during test")
            if not isinstance(last_frame.image, np.ndarray):
                raise ValueError("Unexpected image format from engine")
            pil_image = Image.fromarray(last_frame.image)

        else:
            # Previously fell through with pil_image unbound (UnboundLocalError -> 500).
            raise HTTPException(status_code=400, detail="Unsupported stream type for test")

        # Create thumbnail (before filtering, as the full image is filtered separately)
        thumbnail = _make_thumbnail(pil_image)

        # Apply postprocessing filters (expand filter_template references)
        flat_filters = pp_store.resolve_filter_instances(pp_template.filters)
        if flat_filters:
            pool = ImagePool()
            thumbnail = _apply_filters(thumbnail, flat_filters, pool)
            pil_image = _apply_filters(pil_image, flat_filters, pool)

        thumbnail_data_uri = _encode_jpeg_data_uri(thumbnail, quality=85)
        full_data_uri = _encode_jpeg_data_uri(pil_image, quality=90)

        actual_fps = frame_count / actual_duration if actual_duration > 0 else 0
        avg_capture_time_ms = (total_capture_time / frame_count * 1000) if frame_count > 0 else 0
        width, height = pil_image.size
        thumb_w, thumb_h = thumbnail.size

        return TemplateTestResponse(
            full_capture=CaptureImage(
                image=thumbnail_data_uri,
                full_image=full_data_uri,
                width=width,
                height=height,
                thumbnail_width=thumb_w,
                thumbnail_height=thumb_h,
            ),
            border_extraction=None,
            performance=PerformanceMetrics(
                capture_duration_s=actual_duration,
                frame_count=frame_count,
                actual_fps=actual_fps,
                avg_capture_time_ms=avg_capture_time_ms,
            ),
        )

    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Postprocessing template test failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# ===== REAL-TIME PP TEMPLATE TEST WEBSOCKET =====


@router.websocket("/api/v1/postprocessing-templates/{template_id}/test/ws")
async def test_pp_template_ws(
    websocket: WebSocket,
    template_id: str,
    token: str = Query(""),
    duration: float = Query(5.0),
    source_stream_id: str = Query(""),
    preview_width: int = Query(0),
):
    """WebSocket for PP template test with intermediate frame previews."""
    # Imported lazily, presumably to avoid an import cycle with the helpers module.
    from wled_controller.api.routes._test_helpers import (
        authenticate_ws_token,
        stream_capture_test,
    )

    # All rejections below happen before accept(), using custom 4xxx close codes.
    if not authenticate_ws_token(token):
        await websocket.close(code=4001, reason="Unauthorized")
        return

    if not source_stream_id:
        await websocket.close(code=4003, reason="source_stream_id is required")
        return

    # WebSocket routes call the dependency getters directly instead of using Depends.
    pp_store = get_pp_template_store()
    stream_store = get_picture_source_store()
    template_store = get_template_store()

    # Get PP template
    try:
        pp_template = pp_store.get_template(template_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    # Resolve source stream chain
    try:
        chain = stream_store.resolve_stream_chain(source_stream_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    raw_stream = chain["raw_stream"]

    if isinstance(raw_stream, StaticImagePictureSource):
        await websocket.close(code=4003, reason="Static image streams don't support live test")
        return

    if not isinstance(raw_stream, ScreenCapturePictureSource):
        await websocket.close(code=4003, reason="Unsupported stream type for live test")
        return

    # Create capture engine
    try:
        capture_template = template_store.get_template(raw_stream.capture_template_id)
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    if capture_template.engine_type not in EngineRegistry.get_available_engines():
        await websocket.close(code=4003, reason=f"Engine '{capture_template.engine_type}' not available")
        return

    # Resolve PP filters. An empty list becomes None, meaning "no filters".
    # Resolution failures are reported like other lookups instead of escaping unhandled.
    try:
        pp_filters = pp_store.resolve_filter_instances(pp_template.filters) or None
    except ValueError as e:
        await websocket.close(code=4004, reason=str(e))
        return

    # Engine factory — creates + initializes engine inside the capture thread
    # to avoid thread-affinity issues (e.g. MSS uses thread-local state)
    _engine_type = capture_template.engine_type
    _display_index = raw_stream.display_index
    _engine_config = capture_template.engine_config

    def engine_factory():
        s = EngineRegistry.create_stream(_engine_type, _display_index, _engine_config)
        s.initialize()
        return s

    await websocket.accept()
    logger.info(f"PP template test WS connected for {template_id} ({duration}s)")

    try:
        await stream_capture_test(
            websocket, engine_factory, duration,
            pp_filters=pp_filters,
            preview_width=preview_width or None,
        )
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"PP template test WS error for {template_id}: {e}")
    finally:
        logger.info(f"PP template test WS disconnected for {template_id}")