"""Shared helpers for WebSocket-based capture test endpoints.""" import asyncio import base64 import io import secrets import threading import time from typing import Callable, List, Optional import numpy as np from PIL import Image from starlette.websockets import WebSocket from wled_controller.config import get_config from wled_controller.core.filters import FilterRegistry, ImagePool from wled_controller.utils import get_logger logger = get_logger(__name__) PREVIEW_INTERVAL = 0.1 # seconds between intermediate thumbnail sends PREVIEW_MAX_WIDTH = 640 # px for intermediate thumbnails FINAL_THUMBNAIL_WIDTH = 640 # px for the final thumbnail FINAL_JPEG_QUALITY = 90 PREVIEW_JPEG_QUALITY = 70 def authenticate_ws_token(token: str) -> bool: """Check a WebSocket query-param token against configured API keys.""" cfg = get_config() if token and cfg.auth.api_keys: for _label, api_key in cfg.auth.api_keys.items(): if secrets.compare_digest(token, api_key): return True return False def _encode_jpeg(pil_image: Image.Image, quality: int = 85) -> str: """Encode a PIL image as a JPEG base64 data URI.""" buf = io.BytesIO() pil_image.save(buf, format="JPEG", quality=quality) buf.seek(0) b64 = base64.b64encode(buf.getvalue()).decode("utf-8") return f"data:image/jpeg;base64,{b64}" def _make_thumbnail(pil_image: Image.Image, max_width: int) -> Image.Image: """Create a thumbnail copy of the image, preserving aspect ratio.""" thumb = pil_image.copy() aspect = pil_image.height / pil_image.width thumb.thumbnail((max_width, int(max_width * aspect)), Image.Resampling.LANCZOS) return thumb def _apply_pp_filters(pil_image: Image.Image, flat_filters: list) -> Image.Image: """Apply postprocessing filter instances to a PIL image.""" if not flat_filters: return pil_image pool = ImagePool() arr = np.array(pil_image) for fi in flat_filters: f = FilterRegistry.create_instance(fi.filter_id, fi.options) result = f.process_image(arr, pool) if result is not None: arr = result return Image.fromarray(arr) 
def _capture_timing_stats(frame_count: int, capture_time: float, elapsed: float) -> dict:
    """Build the timing fields shared by ``frame`` and ``result`` messages.

    Args:
        frame_count: Number of frames captured so far.
        capture_time: Cumulative seconds spent inside ``capture_frame()``.
        elapsed: Wall-clock seconds since the test started.

    Returns:
        Dict with ``frame_count``, ``elapsed_s``, ``fps`` and
        ``avg_capture_ms`` keys (in that order), rounded for display.
    """
    fps = frame_count / elapsed if elapsed > 0 else 0
    avg_ms = (capture_time / frame_count * 1000) if frame_count > 0 else 0
    return {
        "frame_count": frame_count,
        "elapsed_s": round(elapsed, 2),
        "fps": round(fps, 1),
        "avg_capture_ms": round(avg_ms, 1),
    }


async def stream_capture_test(
    websocket: WebSocket,
    engine_factory: Callable,
    duration: float,
    pp_filters: Optional[list] = None,
    preview_width: Optional[int] = None,
) -> None:
    """Run a capture test, streaming intermediate thumbnails and a final full-res frame.

    The engine is created and used entirely within a background thread to
    avoid thread-affinity issues (e.g. MSS uses thread-local state). Resizing,
    filtering and JPEG encoding are also run in the default executor so that
    large frames do not block the event loop.

    Messages sent on ``websocket``:
        * ``{"type": "frame", ...}`` after each PREVIEW_INTERVAL tick on which
          a new frame is available;
        * finally either ``{"type": "result", ...}`` or
          ``{"type": "error", "detail": ...}``.

    Args:
        websocket: Accepted WebSocket connection.
        engine_factory: Zero-arg callable that returns an initialized engine
            stream (with .capture_frame() and .cleanup() methods). Called inside
            the capture thread so thread-local resources work correctly.
        duration: Test duration in seconds.
        pp_filters: Optional list of resolved filter instances to apply to frames.
        preview_width: Max width (px) of intermediate thumbnails; falls back to
            PREVIEW_MAX_WIDTH when None or 0.

    Raises:
        Any exception raised while sending (e.g. client disconnect). The
        capture thread is signalled to stop and awaited before it propagates.
    """
    thumb_width = preview_width or PREVIEW_MAX_WIDTH

    # Shared state between capture thread and async loop
    latest_frame = None  # PIL Image (converted from numpy)
    frame_count = 0
    total_capture_time = 0.0
    stop_event = threading.Event()
    done_event = threading.Event()
    capture_error = None  # message of any exception raised in the capture thread

    def _capture_loop():
        nonlocal latest_frame, frame_count, total_capture_time, capture_error
        stream = None
        try:
            stream = engine_factory()
            end = time.perf_counter() + duration
            while time.perf_counter() < end and not stop_event.is_set():
                t0 = time.perf_counter()
                capture = stream.capture_frame()
                t1 = time.perf_counter()
                if capture is None:
                    time.sleep(0.005)
                    continue
                total_capture_time += t1 - t0
                frame_count += 1
                # Convert numpy → PIL once in the capture thread
                if isinstance(capture.image, np.ndarray):
                    latest_frame = Image.fromarray(capture.image)
                else:
                    latest_frame = capture.image
        except Exception as e:
            # str(e) may be empty (e.g. ``raise RuntimeError()``); fall back to
            # the type name so the failure is never mistaken for success.
            capture_error = str(e) or type(e).__name__
            logger.error(f"Capture thread error: {e}")
        finally:
            if stream is not None:
                try:
                    stream.cleanup()
                except Exception as e:
                    # Best-effort cleanup, but don't hide the failure entirely.
                    logger.warning(f"Capture stream cleanup failed: {e}")
            done_event.set()

    def _render_preview(frame):
        """Downscale, filter and JPEG-encode one preview frame (runs in executor)."""
        thumb = _make_thumbnail(frame, thumb_width)
        if pp_filters:
            thumb = _apply_pp_filters(thumb, pp_filters)
        return _encode_jpeg(thumb, PREVIEW_JPEG_QUALITY)

    def _render_final(frame):
        """Filter the final frame and encode full-size + thumbnail (runs in executor)."""
        if pp_filters:
            frame = _apply_pp_filters(frame, pp_filters)
        full_uri = _encode_jpeg(frame, FINAL_JPEG_QUALITY)
        thumb = _make_thumbnail(frame, FINAL_THUMBNAIL_WIDTH)
        return full_uri, _encode_jpeg(thumb, 85)

    # Start capture in background thread
    loop = asyncio.get_running_loop()
    capture_future = loop.run_in_executor(None, _capture_loop)
    start_time = time.perf_counter()
    last_sent_frame = None

    try:
        # Stream intermediate previews
        while not done_event.is_set():
            await asyncio.sleep(PREVIEW_INTERVAL)

            if capture_error is not None:
                stop_event.set()
                await capture_future
                await websocket.send_json({"type": "error", "detail": capture_error})
                return

            frame = latest_frame
            if frame is None or frame is last_sent_frame:
                continue
            last_sent_frame = frame

            elapsed = time.perf_counter() - start_time
            stats = _capture_timing_stats(frame_count, total_capture_time, elapsed)
            thumb_uri = await loop.run_in_executor(None, _render_preview, frame)
            await websocket.send_json({"type": "frame", "thumbnail": thumb_uri, **stats})

        # Wait for capture thread to fully finish
        await capture_future

        if capture_error is not None:
            await websocket.send_json({"type": "error", "detail": capture_error})
            return

        final_frame = latest_frame
        if final_frame is None:
            await websocket.send_json({"type": "error", "detail": "No frames captured"})
            return

        elapsed = time.perf_counter() - start_time
        stats = _capture_timing_stats(frame_count, total_capture_time, elapsed)
        # Dimensions are reported for the raw capture, before any PP filters.
        width, height = final_frame.size
        full_uri, thumb_uri = await loop.run_in_executor(None, _render_final, final_frame)

        await websocket.send_json({
            "type": "result",
            "full_image": full_uri,
            "thumbnail": thumb_uri,
            "width": width,
            "height": height,
            **stats,
        })
    except Exception:
        # WebSocket disconnect or send error — signal capture thread to stop
        stop_event.set()
        await capture_future
        raise
    finally:
        # Also covers task cancellation (CancelledError is a BaseException and
        # bypasses the handler above), so the capture thread never keeps
        # running for the full duration after the handler has exited.
        stop_event.set()