refactor: comprehensive code quality, security, and release readiness improvements
Some checks failed
Lint & Test / test (push) Failing after 48s
Some checks failed
Lint & Test / test (push) Failing after 48s
Security: tighten CORS defaults, add webhook rate limiting, fix XSS in automations, guard WebSocket JSON.parse, validate ADB address input, seal debug exception leak, URL-encode WS tokens, CSS.escape in selectors. Code quality: add Pydantic models for brightness/power endpoints, fix thread safety and name uniqueness in DeviceStore, immutable update pattern, split 6 oversized files into 16 focused modules, enable TypeScript strictNullChecks (741→102 errors), type state variables, add dom-utils helper, migrate 3 modules from inline onclick to event delegation, ProcessorDependencies dataclass. Performance: async store saves, health endpoint log level, command palette debounce, optimized entity-events comparison, fix service worker precache list. Testing: expand from 45 to 293 passing tests — add store tests (141), route tests (25), core logic tests (42), E2E flow tests (33), organize into tests/api/, tests/storage/, tests/core/, tests/e2e/. DevOps: CI test pipeline, pre-commit config, Dockerfile multi-stage build with non-root user and health check, docker-compose improvements, version bump to 0.2.0. Docs: rewrite CLAUDE.md (202→56 lines), server/CLAUDE.md (212→76), create contexts/server-operations.md, fix .js→.ts references, fix env var prefix in README, rewrite INSTALLATION.md, add CONTRIBUTING.md and .env.example.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Shared helpers for WebSocket-based capture test endpoints."""
|
||||
"""Shared helpers for WebSocket-based capture preview endpoints."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
@@ -50,7 +50,7 @@ def encode_preview_frame(image: np.ndarray, max_width: int = None, quality: int
|
||||
scale = max_width / image.shape[1]
|
||||
new_h = int(image.shape[0] * scale)
|
||||
image = cv2.resize(image, (max_width, new_h), interpolation=cv2.INTER_AREA)
|
||||
# RGB → BGR for OpenCV JPEG encoding
|
||||
# RGB -> BGR for OpenCV JPEG encoding
|
||||
bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
_, buf = cv2.imencode('.jpg', bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
|
||||
return buf.tobytes()
|
||||
@@ -124,7 +124,7 @@ async def stream_capture_test(
|
||||
continue
|
||||
total_capture_time += t1 - t0
|
||||
frame_count += 1
|
||||
# Convert numpy → PIL once in the capture thread
|
||||
# Convert numpy -> PIL once in the capture thread
|
||||
if isinstance(capture.image, np.ndarray):
|
||||
latest_frame = Image.fromarray(capture.image)
|
||||
else:
|
||||
395
server/src/wled_controller/api/routes/backup.py
Normal file
395
server/src/wled_controller/api/routes/backup.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""System routes: backup, restore, export, import, auto-backup.
|
||||
|
||||
Extracted from system.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from wled_controller import __version__
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import get_auto_backup_engine
|
||||
from wled_controller.api.schemas.system import (
|
||||
AutoBackupSettings,
|
||||
AutoBackupStatusResponse,
|
||||
BackupFileInfo,
|
||||
BackupListResponse,
|
||||
RestoreResponse,
|
||||
)
|
||||
from wled_controller.core.backup.auto_backup import AutoBackupEngine
|
||||
from wled_controller.config import get_config
|
||||
from wled_controller.utils import atomic_write_json, get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration backup / restore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Mapping: logical store name -> StorageConfig attribute name
|
||||
STORE_MAP = {
|
||||
"devices": "devices_file",
|
||||
"capture_templates": "templates_file",
|
||||
"postprocessing_templates": "postprocessing_templates_file",
|
||||
"picture_sources": "picture_sources_file",
|
||||
"output_targets": "output_targets_file",
|
||||
"pattern_templates": "pattern_templates_file",
|
||||
"color_strip_sources": "color_strip_sources_file",
|
||||
"audio_sources": "audio_sources_file",
|
||||
"audio_templates": "audio_templates_file",
|
||||
"value_sources": "value_sources_file",
|
||||
"sync_clocks": "sync_clocks_file",
|
||||
"color_strip_processing_templates": "color_strip_processing_templates_file",
|
||||
"automations": "automations_file",
|
||||
"scene_presets": "scene_presets_file",
|
||||
}
|
||||
|
||||
_SERVER_DIR = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
def _schedule_restart() -> None:
|
||||
"""Spawn a restart script after a short delay so the HTTP response completes."""
|
||||
|
||||
def _restart():
|
||||
import time
|
||||
time.sleep(1)
|
||||
if sys.platform == "win32":
|
||||
subprocess.Popen(
|
||||
["powershell", "-ExecutionPolicy", "Bypass", "-File",
|
||||
str(_SERVER_DIR / "restart.ps1")],
|
||||
creationflags=subprocess.DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||
)
|
||||
else:
|
||||
subprocess.Popen(
|
||||
["bash", str(_SERVER_DIR / "restart.sh")],
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
threading.Thread(target=_restart, daemon=True).start()
|
||||
|
||||
|
||||
@router.get("/api/v1/system/export/{store_key}", tags=["System"])
|
||||
def export_store(store_key: str, _: AuthRequired):
|
||||
"""Download a single entity store as a JSON file."""
|
||||
if store_key not in STORE_MAP:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unknown store '{store_key}'. Valid keys: {sorted(STORE_MAP.keys())}",
|
||||
)
|
||||
config = get_config()
|
||||
file_path = Path(getattr(config.storage, STORE_MAP[store_key]))
|
||||
if file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
else:
|
||||
data = {}
|
||||
|
||||
export = {
|
||||
"meta": {
|
||||
"format": "ledgrab-partial-export",
|
||||
"format_version": 1,
|
||||
"store_key": store_key,
|
||||
"app_version": __version__,
|
||||
"created_at": datetime.now(timezone.utc).isoformat() + "Z",
|
||||
},
|
||||
"store": data,
|
||||
}
|
||||
content = json.dumps(export, indent=2, ensure_ascii=False)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S")
|
||||
filename = f"ledgrab-{store_key}-{timestamp}.json"
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content.encode("utf-8")),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/import/{store_key}", tags=["System"])
|
||||
async def import_store(
|
||||
store_key: str,
|
||||
_: AuthRequired,
|
||||
file: UploadFile = File(...),
|
||||
merge: bool = Query(False, description="Merge into existing data instead of replacing"),
|
||||
):
|
||||
"""Upload a partial export file to replace or merge one entity store. Triggers server restart."""
|
||||
if store_key not in STORE_MAP:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unknown store '{store_key}'. Valid keys: {sorted(STORE_MAP.keys())}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw = await file.read()
|
||||
if len(raw) > 10 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
# Support both full-backup format and partial-export format
|
||||
if "stores" in payload and isinstance(payload.get("meta"), dict):
|
||||
# Full backup: extract the specific store
|
||||
if payload["meta"].get("format") not in ("ledgrab-backup",):
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup or partial export file")
|
||||
stores = payload.get("stores", {})
|
||||
if store_key not in stores:
|
||||
raise HTTPException(status_code=400, detail=f"Backup does not contain store '{store_key}'")
|
||||
incoming = stores[store_key]
|
||||
elif isinstance(payload.get("meta"), dict) and payload["meta"].get("format") == "ledgrab-partial-export":
|
||||
# Partial export format
|
||||
if payload["meta"].get("store_key") != store_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File is for store '{payload['meta']['store_key']}', not '{store_key}'",
|
||||
)
|
||||
incoming = payload.get("store", {})
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup or partial export file")
|
||||
|
||||
if not isinstance(incoming, dict):
|
||||
raise HTTPException(status_code=400, detail="Store data must be a JSON object")
|
||||
|
||||
config = get_config()
|
||||
file_path = Path(getattr(config.storage, STORE_MAP[store_key]))
|
||||
|
||||
def _write():
|
||||
if merge and file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
existing = json.load(f)
|
||||
if isinstance(existing, dict):
|
||||
existing.update(incoming)
|
||||
atomic_write_json(file_path, existing)
|
||||
return len(existing)
|
||||
atomic_write_json(file_path, incoming)
|
||||
return len(incoming)
|
||||
|
||||
count = await asyncio.to_thread(_write)
|
||||
logger.info(f"Imported store '{store_key}' ({count} entries, merge={merge}). Scheduling restart...")
|
||||
_schedule_restart()
|
||||
return {
|
||||
"status": "imported",
|
||||
"store_key": store_key,
|
||||
"entries": count,
|
||||
"merge": merge,
|
||||
"restart_scheduled": True,
|
||||
"message": f"Imported {count} entries for '{store_key}'. Server restarting...",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/backup", tags=["System"])
|
||||
def backup_config(_: AuthRequired):
|
||||
"""Download all configuration as a single JSON backup file."""
|
||||
config = get_config()
|
||||
stores = {}
|
||||
for store_key, config_attr in STORE_MAP.items():
|
||||
file_path = Path(getattr(config.storage, config_attr))
|
||||
if file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
stores[store_key] = json.load(f)
|
||||
else:
|
||||
stores[store_key] = {}
|
||||
|
||||
backup = {
|
||||
"meta": {
|
||||
"format": "ledgrab-backup",
|
||||
"format_version": 1,
|
||||
"app_version": __version__,
|
||||
"created_at": datetime.now(timezone.utc).isoformat() + "Z",
|
||||
"store_count": len(stores),
|
||||
},
|
||||
"stores": stores,
|
||||
}
|
||||
|
||||
content = json.dumps(backup, indent=2, ensure_ascii=False)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S")
|
||||
filename = f"ledgrab-backup-{timestamp}.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content.encode("utf-8")),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/restart", tags=["System"])
|
||||
def restart_server(_: AuthRequired):
|
||||
"""Schedule a server restart and return immediately."""
|
||||
_schedule_restart()
|
||||
return {"status": "restarting"}
|
||||
|
||||
|
||||
@router.post("/api/v1/system/restore", response_model=RestoreResponse, tags=["System"])
|
||||
async def restore_config(
|
||||
_: AuthRequired,
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""Upload a backup file to restore all configuration. Triggers server restart."""
|
||||
# Read and parse
|
||||
try:
|
||||
raw = await file.read()
|
||||
if len(raw) > 10 * 1024 * 1024: # 10 MB limit
|
||||
raise HTTPException(status_code=400, detail="Backup file too large (max 10 MB)")
|
||||
backup = json.loads(raw)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON file: {e}")
|
||||
|
||||
# Validate envelope
|
||||
meta = backup.get("meta")
|
||||
if not isinstance(meta, dict) or meta.get("format") != "ledgrab-backup":
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup file")
|
||||
|
||||
fmt_version = meta.get("format_version", 0)
|
||||
if fmt_version > 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Backup format version {fmt_version} is not supported by this server version",
|
||||
)
|
||||
|
||||
stores = backup.get("stores")
|
||||
if not isinstance(stores, dict):
|
||||
raise HTTPException(status_code=400, detail="Backup file missing 'stores' section")
|
||||
|
||||
known_keys = set(STORE_MAP.keys())
|
||||
present_keys = known_keys & set(stores.keys())
|
||||
if not present_keys:
|
||||
raise HTTPException(status_code=400, detail="Backup contains no recognized store data")
|
||||
|
||||
for key in present_keys:
|
||||
if not isinstance(stores[key], dict):
|
||||
raise HTTPException(status_code=400, detail=f"Store '{key}' in backup is not a valid JSON object")
|
||||
|
||||
# Write store files atomically (in thread to avoid blocking event loop)
|
||||
config = get_config()
|
||||
|
||||
def _write_stores():
|
||||
count = 0
|
||||
for store_key, config_attr in STORE_MAP.items():
|
||||
if store_key in stores:
|
||||
file_path = Path(getattr(config.storage, config_attr))
|
||||
atomic_write_json(file_path, stores[store_key])
|
||||
count += 1
|
||||
logger.info(f"Restored store: {store_key} -> {file_path}")
|
||||
return count
|
||||
|
||||
written = await asyncio.to_thread(_write_stores)
|
||||
|
||||
logger.info(f"Restore complete: {written}/{len(STORE_MAP)} stores written. Scheduling restart...")
|
||||
_schedule_restart()
|
||||
|
||||
missing = known_keys - present_keys
|
||||
return RestoreResponse(
|
||||
status="restored",
|
||||
stores_written=written,
|
||||
stores_total=len(STORE_MAP),
|
||||
missing_stores=sorted(missing) if missing else [],
|
||||
restart_scheduled=True,
|
||||
message=f"Restored {written} stores. Server restarting...",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-backup settings & saved backups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/auto-backup/settings",
|
||||
response_model=AutoBackupStatusResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_auto_backup_settings(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Get auto-backup settings and status."""
|
||||
return engine.get_settings()
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/auto-backup/settings",
|
||||
response_model=AutoBackupStatusResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_auto_backup_settings(
|
||||
_: AuthRequired,
|
||||
body: AutoBackupSettings,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Update auto-backup settings (enable/disable, interval, max backups)."""
|
||||
return await engine.update_settings(
|
||||
enabled=body.enabled,
|
||||
interval_hours=body.interval_hours,
|
||||
max_backups=body.max_backups,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/auto-backup/trigger", tags=["System"])
|
||||
async def trigger_backup(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Manually trigger a backup now."""
|
||||
backup = await engine.trigger_backup()
|
||||
return {"status": "ok", "backup": backup}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/backups",
|
||||
response_model=BackupListResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def list_backups(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""List all saved backup files."""
|
||||
backups = engine.list_backups()
|
||||
return BackupListResponse(
|
||||
backups=[BackupFileInfo(**b) for b in backups],
|
||||
count=len(backups),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/system/backups/{filename}", tags=["System"])
|
||||
def download_saved_backup(
|
||||
filename: str,
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Download a specific saved backup file."""
|
||||
try:
|
||||
path = engine.get_backup_path(filename)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
content = path.read_bytes()
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/api/v1/system/backups/{filename}", tags=["System"])
|
||||
async def delete_saved_backup(
|
||||
filename: str,
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Delete a specific saved backup file."""
|
||||
try:
|
||||
engine.delete_backup(filename)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return {"status": "deleted", "filename": filename}
|
||||
@@ -16,6 +16,7 @@ from wled_controller.api.dependencies import (
|
||||
get_processor_manager,
|
||||
)
|
||||
from wled_controller.api.schemas.devices import (
|
||||
BrightnessRequest,
|
||||
DeviceCreate,
|
||||
DeviceListResponse,
|
||||
DeviceResponse,
|
||||
@@ -25,6 +26,7 @@ from wled_controller.api.schemas.devices import (
|
||||
DiscoverDevicesResponse,
|
||||
OpenRGBZoneResponse,
|
||||
OpenRGBZonesResponse,
|
||||
PowerRequest,
|
||||
)
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.storage import DeviceStore
|
||||
@@ -53,18 +55,19 @@ def _device_to_response(device) -> DeviceResponse:
|
||||
zone_mode=device.zone_mode,
|
||||
capabilities=sorted(get_device_capabilities(device.device_type)),
|
||||
tags=device.tags,
|
||||
dmx_protocol=getattr(device, 'dmx_protocol', 'artnet'),
|
||||
dmx_start_universe=getattr(device, 'dmx_start_universe', 0),
|
||||
dmx_start_channel=getattr(device, 'dmx_start_channel', 1),
|
||||
espnow_peer_mac=getattr(device, 'espnow_peer_mac', ''),
|
||||
espnow_channel=getattr(device, 'espnow_channel', 1),
|
||||
hue_username=getattr(device, 'hue_username', ''),
|
||||
hue_client_key=getattr(device, 'hue_client_key', ''),
|
||||
hue_entertainment_group_id=getattr(device, 'hue_entertainment_group_id', ''),
|
||||
spi_speed_hz=getattr(device, 'spi_speed_hz', 800000),
|
||||
spi_led_type=getattr(device, 'spi_led_type', 'WS2812B'),
|
||||
chroma_device_type=getattr(device, 'chroma_device_type', 'chromalink'),
|
||||
gamesense_device_type=getattr(device, 'gamesense_device_type', 'keyboard'),
|
||||
dmx_protocol=device.dmx_protocol,
|
||||
dmx_start_universe=device.dmx_start_universe,
|
||||
dmx_start_channel=device.dmx_start_channel,
|
||||
espnow_peer_mac=device.espnow_peer_mac,
|
||||
espnow_channel=device.espnow_channel,
|
||||
hue_username=device.hue_username,
|
||||
hue_client_key=device.hue_client_key,
|
||||
hue_entertainment_group_id=device.hue_entertainment_group_id,
|
||||
spi_speed_hz=device.spi_speed_hz,
|
||||
spi_led_type=device.spi_led_type,
|
||||
chroma_device_type=device.chroma_device_type,
|
||||
gamesense_device_type=device.gamesense_device_type,
|
||||
default_css_processing_template_id=device.default_css_processing_template_id,
|
||||
created_at=device.created_at,
|
||||
updated_at=device.updated_at,
|
||||
)
|
||||
@@ -508,7 +511,7 @@ async def get_device_brightness(
|
||||
@router.put("/api/v1/devices/{device_id}/brightness", tags=["Settings"])
|
||||
async def set_device_brightness(
|
||||
device_id: str,
|
||||
body: dict,
|
||||
body: BrightnessRequest,
|
||||
_auth: AuthRequired,
|
||||
store: DeviceStore = Depends(get_device_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
@@ -521,9 +524,7 @@ async def set_device_brightness(
|
||||
if "brightness_control" not in get_device_capabilities(device.device_type):
|
||||
raise HTTPException(status_code=400, detail=f"Brightness control is not supported for {device.device_type} devices")
|
||||
|
||||
bri = body.get("brightness")
|
||||
if bri is None or not isinstance(bri, int) or not 0 <= bri <= 255:
|
||||
raise HTTPException(status_code=400, detail="brightness must be an integer 0-255")
|
||||
bri = body.brightness
|
||||
|
||||
try:
|
||||
try:
|
||||
@@ -581,7 +582,7 @@ async def get_device_power(
|
||||
@router.put("/api/v1/devices/{device_id}/power", tags=["Settings"])
|
||||
async def set_device_power(
|
||||
device_id: str,
|
||||
body: dict,
|
||||
body: PowerRequest,
|
||||
_auth: AuthRequired,
|
||||
store: DeviceStore = Depends(get_device_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
@@ -594,9 +595,7 @@ async def set_device_power(
|
||||
if "power_control" not in get_device_capabilities(device.device_type):
|
||||
raise HTTPException(status_code=400, detail=f"Power control is not supported for {device.device_type} devices")
|
||||
|
||||
on = body.get("on")
|
||||
if on is None or not isinstance(on, bool):
|
||||
raise HTTPException(status_code=400, detail="'on' must be a boolean")
|
||||
on = body.power
|
||||
|
||||
try:
|
||||
# For serial devices, use the cached idle client to avoid port conflicts
|
||||
|
||||
@@ -1,57 +1,26 @@
|
||||
"""Output target routes: CRUD, processing control, settings, state, metrics."""
|
||||
"""Output target routes: CRUD endpoints and batch state/metrics queries."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from PIL import Image
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from fastapi import Query as QueryParam
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import (
|
||||
fire_entity_event,
|
||||
get_color_strip_store,
|
||||
get_device_store,
|
||||
get_pattern_template_store,
|
||||
get_picture_source_store,
|
||||
get_output_target_store,
|
||||
get_pp_template_store,
|
||||
get_processor_manager,
|
||||
get_template_store,
|
||||
)
|
||||
from wled_controller.api.schemas.output_targets import (
|
||||
BulkTargetRequest,
|
||||
BulkTargetResponse,
|
||||
ExtractedColorResponse,
|
||||
KCTestRectangleResponse,
|
||||
KCTestResponse,
|
||||
KeyColorsResponse,
|
||||
KeyColorsSettingsSchema,
|
||||
OutputTargetCreate,
|
||||
OutputTargetListResponse,
|
||||
OutputTargetResponse,
|
||||
OutputTargetUpdate,
|
||||
TargetMetricsResponse,
|
||||
TargetProcessingState,
|
||||
)
|
||||
from wled_controller.core.capture_engines import EngineRegistry
|
||||
from wled_controller.core.filters import FilterRegistry, ImagePool
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.core.capture.screen_capture import (
|
||||
calculate_average_color,
|
||||
calculate_dominant_color,
|
||||
calculate_median_color,
|
||||
get_available_displays,
|
||||
)
|
||||
from wled_controller.storage.color_strip_store import ColorStripStore
|
||||
from wled_controller.storage.color_strip_source import AdvancedPictureColorStripSource, PictureColorStripSource
|
||||
from wled_controller.storage import DeviceStore
|
||||
from wled_controller.storage.pattern_template_store import PatternTemplateStore
|
||||
from wled_controller.storage.picture_source import ScreenCapturePictureSource, StaticImagePictureSource
|
||||
from wled_controller.storage.picture_source_store import PictureSourceStore
|
||||
from wled_controller.storage.template_store import TemplateStore
|
||||
from wled_controller.storage.wled_output_target import WledOutputTarget
|
||||
from wled_controller.storage.key_colors_output_target import (
|
||||
KeyColorsSettings,
|
||||
@@ -326,7 +295,7 @@ async def update_target(
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Device change requires async stop → swap → start cycle
|
||||
# Device change requires async stop -> swap -> start cycle
|
||||
if data.device_id is not None:
|
||||
try:
|
||||
await manager.update_target_device(target_id, target.device_id)
|
||||
@@ -377,795 +346,3 @@ async def delete_target(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete target: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ===== BULK PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/start", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
async def bulk_start_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Start processing for multiple output targets. Returns lists of started IDs and per-ID errors."""
|
||||
started: list[str] = []
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
for target_id in body.ids:
|
||||
try:
|
||||
target_store.get_target(target_id)
|
||||
await manager.start_processing(target_id)
|
||||
started.append(target_id)
|
||||
logger.info(f"Bulk start: started processing for target {target_id}")
|
||||
except ValueError as e:
|
||||
errors[target_id] = str(e)
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
for t in target_store.get_all_targets():
|
||||
if t.id in msg:
|
||||
msg = msg.replace(t.id, f"'{t.name}'")
|
||||
errors[target_id] = msg
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk start: failed to start target {target_id}: {e}")
|
||||
errors[target_id] = str(e)
|
||||
|
||||
return BulkTargetResponse(started=started, errors=errors)
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/stop", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
async def bulk_stop_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop processing for multiple output targets. Returns lists of stopped IDs and per-ID errors."""
|
||||
stopped: list[str] = []
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
for target_id in body.ids:
|
||||
try:
|
||||
await manager.stop_processing(target_id)
|
||||
stopped.append(target_id)
|
||||
logger.info(f"Bulk stop: stopped processing for target {target_id}")
|
||||
except ValueError as e:
|
||||
errors[target_id] = str(e)
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk stop: failed to stop target {target_id}: {e}")
|
||||
errors[target_id] = str(e)
|
||||
|
||||
return BulkTargetResponse(stopped=stopped, errors=errors)
|
||||
|
||||
|
||||
# ===== PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/start", tags=["Processing"])
|
||||
async def start_processing(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Start processing for a output target."""
|
||||
try:
|
||||
# Verify target exists in store
|
||||
target_store.get_target(target_id)
|
||||
|
||||
await manager.start_processing(target_id)
|
||||
|
||||
logger.info(f"Started processing for target {target_id}")
|
||||
return {"status": "started", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
# Resolve target IDs to human-readable names in error messages
|
||||
msg = str(e)
|
||||
for t in target_store.get_all_targets():
|
||||
if t.id in msg:
|
||||
msg = msg.replace(t.id, f"'{t.name}'")
|
||||
raise HTTPException(status_code=409, detail=msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/stop", tags=["Processing"])
|
||||
async def stop_processing(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop processing for a output target."""
|
||||
try:
|
||||
await manager.stop_processing(target_id)
|
||||
|
||||
logger.info(f"Stopped processing for target {target_id}")
|
||||
return {"status": "stopped", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ===== STATE & METRICS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/state", response_model=TargetProcessingState, tags=["Processing"])
|
||||
async def get_target_state(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get current processing state for a target."""
|
||||
try:
|
||||
state = manager.get_target_state(target_id)
|
||||
return TargetProcessingState(**state)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get target state: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/metrics", response_model=TargetMetricsResponse, tags=["Metrics"])
|
||||
async def get_target_metrics(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get processing metrics for a target."""
|
||||
try:
|
||||
metrics = manager.get_target_metrics(target_id)
|
||||
return TargetMetricsResponse(**metrics)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get target metrics: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ===== KEY COLORS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/colors", response_model=KeyColorsResponse, tags=["Key Colors"])
|
||||
async def get_target_colors(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get latest extracted colors for a key-colors target (polling)."""
|
||||
try:
|
||||
raw_colors = manager.get_kc_latest_colors(target_id)
|
||||
colors = {}
|
||||
for name, (r, g, b) in raw_colors.items():
|
||||
colors[name] = ExtractedColorResponse(
|
||||
r=r, g=g, b=b,
|
||||
hex=f"#{r:02x}{g:02x}{b:02x}",
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
return KeyColorsResponse(
|
||||
target_id=target_id,
|
||||
colors=colors,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/test", response_model=KCTestResponse, tags=["Key Colors"])
|
||||
async def test_kc_target(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
source_store: PictureSourceStore = Depends(get_picture_source_store),
|
||||
template_store: TemplateStore = Depends(get_template_store),
|
||||
pattern_store: PatternTemplateStore = Depends(get_pattern_template_store),
|
||||
processor_manager: ProcessorManager = Depends(get_processor_manager),
|
||||
device_store: DeviceStore = Depends(get_device_store),
|
||||
pp_template_store=Depends(get_pp_template_store),
|
||||
):
|
||||
"""Test a key-colors target: capture a frame, extract colors from each rectangle."""
|
||||
import httpx
|
||||
|
||||
stream = None
|
||||
try:
|
||||
# 1. Load and validate KC target
|
||||
try:
|
||||
target = target_store.get_target(target_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
raise HTTPException(status_code=400, detail="Target is not a key_colors target")
|
||||
|
||||
settings = target.settings
|
||||
|
||||
# 2. Resolve pattern template
|
||||
if not settings.pattern_template_id:
|
||||
raise HTTPException(status_code=400, detail="No pattern template configured")
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
raise HTTPException(status_code=400, detail="Pattern template has no rectangles")
|
||||
|
||||
# 3. Resolve picture source and capture a frame
|
||||
if not target.picture_source_id:
|
||||
raise HTTPException(status_code=400, detail="No picture source configured")
|
||||
|
||||
try:
|
||||
chain = source_store.resolve_stream_chain(target.picture_source_id)
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
if isinstance(raw_stream, StaticImagePictureSource):
|
||||
source = raw_stream.image_source
|
||||
if source.startswith(("http://", "https://")):
|
||||
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
||||
resp = await client.get(source)
|
||||
resp.raise_for_status()
|
||||
pil_image = Image.open(io.BytesIO(resp.content)).convert("RGB")
|
||||
else:
|
||||
from pathlib import Path
|
||||
path = Path(source)
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=400, detail=f"Image file not found: {source}")
|
||||
pil_image = Image.open(path).convert("RGB")
|
||||
|
||||
elif isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
try:
|
||||
capture_template = template_store.get_template(raw_stream.capture_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Capture template not found: {raw_stream.capture_template_id}",
|
||||
)
|
||||
|
||||
display_index = raw_stream.display_index
|
||||
|
||||
if capture_template.engine_type not in EngineRegistry.get_available_engines():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Engine '{capture_template.engine_type}' is not available on this system",
|
||||
)
|
||||
|
||||
locked_device_id = processor_manager.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Display {display_index} is currently being captured by device '{device_name}'. "
|
||||
f"Please stop the device processing before testing.",
|
||||
)
|
||||
|
||||
stream = EngineRegistry.create_stream(
|
||||
capture_template.engine_type, display_index, capture_template.engine_config
|
||||
)
|
||||
stream.initialize()
|
||||
|
||||
screen_capture = stream.capture_frame()
|
||||
if screen_capture is None:
|
||||
raise RuntimeError("No frame captured")
|
||||
|
||||
if isinstance(screen_capture.image, np.ndarray):
|
||||
pil_image = Image.fromarray(screen_capture.image)
|
||||
else:
|
||||
raise ValueError("Unexpected image format from engine")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported picture source type")
|
||||
|
||||
# 3b. Apply postprocessing filters (if the picture source has a filter chain)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store:
|
||||
img_array = np.array(pil_image)
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store.get_template(pp_id)
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: PP template {pp_id} not found, skipping")
|
||||
continue
|
||||
flat_filters = pp_template_store.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(img_array, image_pool)
|
||||
if result is not None:
|
||||
img_array = result
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: unknown filter '{fi.filter_id}', skipping")
|
||||
pil_image = Image.fromarray(img_array)
|
||||
|
||||
# 4. Extract colors from each rectangle
|
||||
img_array = np.array(pil_image)
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append(KCTestRectangleResponse(
|
||||
name=rect.name,
|
||||
x=rect.x,
|
||||
y=rect.y,
|
||||
width=rect.width,
|
||||
height=rect.height,
|
||||
color=ExtractedColorResponse(r=r, g=g, b=b, hex=f"#{r:02x}{g:02x}{b:02x}"),
|
||||
))
|
||||
|
||||
# 5. Encode frame as base64 JPEG
|
||||
full_buffer = io.BytesIO()
|
||||
pil_image.save(full_buffer, format='JPEG', quality=90)
|
||||
full_buffer.seek(0)
|
||||
full_b64 = base64.b64encode(full_buffer.getvalue()).decode('utf-8')
|
||||
image_data_uri = f"data:image/jpeg;base64,{full_b64}"
|
||||
|
||||
return KCTestResponse(
|
||||
image=image_data_uri,
|
||||
rectangles=result_rects,
|
||||
interpolation_mode=settings.interpolation_mode,
|
||||
pattern_template_name=pattern_tmpl.name,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Capture error: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test KC target: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
if stream:
|
||||
try:
|
||||
stream.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up test stream: {e}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/test/ws")
|
||||
async def test_kc_target_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
fps: int = Query(3),
|
||||
preview_width: int = Query(480),
|
||||
):
|
||||
"""WebSocket for real-time KC target test preview. Auth via ?token=<api_key>.
|
||||
|
||||
Streams JSON frames: {"type": "frame", "image": "data:image/jpeg;base64,...",
|
||||
"rectangles": [...], "pattern_template_name": "...", "interpolation_mode": "..."}
|
||||
"""
|
||||
import json as _json
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
# Load stores
|
||||
target_store_inst: OutputTargetStore = get_output_target_store()
|
||||
source_store_inst: PictureSourceStore = get_picture_source_store()
|
||||
template_store_inst: TemplateStore = get_template_store()
|
||||
pattern_store_inst: PatternTemplateStore = get_pattern_template_store()
|
||||
processor_manager_inst: ProcessorManager = get_processor_manager()
|
||||
device_store_inst: DeviceStore = get_device_store()
|
||||
pp_template_store_inst = get_pp_template_store()
|
||||
|
||||
# Validate target
|
||||
try:
|
||||
target = target_store_inst.get_target(target_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4004, reason=str(e))
|
||||
return
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
await websocket.close(code=4003, reason="Target is not a key_colors target")
|
||||
return
|
||||
|
||||
settings = target.settings
|
||||
|
||||
if not settings.pattern_template_id:
|
||||
await websocket.close(code=4003, reason="No pattern template configured")
|
||||
return
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store_inst.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
await websocket.close(code=4003, reason=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
return
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
await websocket.close(code=4003, reason="Pattern template has no rectangles")
|
||||
return
|
||||
|
||||
if not target.picture_source_id:
|
||||
await websocket.close(code=4003, reason="No picture source configured")
|
||||
return
|
||||
|
||||
try:
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4003, reason=str(e))
|
||||
return
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
# For screen capture sources, check display lock
|
||||
if isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
display_index = raw_stream.display_index
|
||||
locked_device_id = processor_manager_inst.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store_inst.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
await websocket.close(
|
||||
code=4003,
|
||||
reason=f"Display {display_index} is captured by '{device_name}'. Stop processing first.",
|
||||
)
|
||||
return
|
||||
|
||||
fps = max(1, min(30, fps))
|
||||
preview_width = max(120, min(1920, preview_width))
|
||||
frame_interval = 1.0 / fps
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
await websocket.accept()
|
||||
logger.info(f"KC test WS connected for {target_id} (fps={fps})")
|
||||
|
||||
# Use the shared LiveStreamManager so we share the capture stream with
|
||||
# running LED targets instead of creating a competing DXGI duplicator.
|
||||
live_stream_mgr = processor_manager_inst._live_stream_manager
|
||||
live_stream = None
|
||||
|
||||
try:
|
||||
live_stream = await asyncio.to_thread(
|
||||
live_stream_mgr.acquire, target.picture_source_id
|
||||
)
|
||||
logger.info(f"KC test WS acquired shared live stream for {target.picture_source_id}")
|
||||
|
||||
prev_frame_ref = None
|
||||
|
||||
while True:
|
||||
loop_start = time.monotonic()
|
||||
|
||||
try:
|
||||
capture = await asyncio.to_thread(live_stream.get_latest_frame)
|
||||
|
||||
if capture is None or capture.image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Skip if same frame object (no new capture yet)
|
||||
if capture is prev_frame_ref:
|
||||
await asyncio.sleep(frame_interval * 0.5)
|
||||
continue
|
||||
prev_frame_ref = capture
|
||||
|
||||
pil_image = Image.fromarray(capture.image) if isinstance(capture.image, np.ndarray) else None
|
||||
if pil_image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Apply postprocessing (if the source chain has PP templates)
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store_inst:
|
||||
img_array = np.array(pil_image)
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store_inst.get_template(pp_id)
|
||||
except ValueError:
|
||||
continue
|
||||
flat_filters = pp_template_store_inst.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(img_array, image_pool)
|
||||
if result is not None:
|
||||
img_array = result
|
||||
except ValueError:
|
||||
pass
|
||||
pil_image = Image.fromarray(img_array)
|
||||
|
||||
# Extract colors
|
||||
img_array = np.array(pil_image)
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append({
|
||||
"name": rect.name,
|
||||
"x": rect.x,
|
||||
"y": rect.y,
|
||||
"width": rect.width,
|
||||
"height": rect.height,
|
||||
"color": {"r": r, "g": g, "b": b, "hex": f"#{r:02x}{g:02x}{b:02x}"},
|
||||
})
|
||||
|
||||
# Encode frame as JPEG
|
||||
if preview_width and pil_image.width > preview_width:
|
||||
ratio = preview_width / pil_image.width
|
||||
thumb = pil_image.resize((preview_width, int(pil_image.height * ratio)), Image.LANCZOS)
|
||||
else:
|
||||
thumb = pil_image
|
||||
buf = io.BytesIO()
|
||||
thumb.save(buf, format="JPEG", quality=85)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
await websocket.send_text(_json.dumps({
|
||||
"type": "frame",
|
||||
"image": f"data:image/jpeg;base64,{b64}",
|
||||
"rectangles": result_rects,
|
||||
"pattern_template_name": pattern_tmpl.name,
|
||||
"interpolation_mode": settings.interpolation_mode,
|
||||
}))
|
||||
|
||||
except (WebSocketDisconnect, Exception) as inner_e:
|
||||
if isinstance(inner_e, WebSocketDisconnect):
|
||||
raise
|
||||
logger.warning(f"KC test WS frame error for {target_id}: {inner_e}")
|
||||
|
||||
elapsed = time.monotonic() - loop_start
|
||||
sleep_time = frame_interval - elapsed
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"KC test WS disconnected for {target_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"KC test WS error for {target_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
if live_stream is not None:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
live_stream_mgr.release, target.picture_source_id
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"KC test WS closed for {target_id}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/ws")
|
||||
async def target_colors_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time key color updates. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
|
||||
try:
|
||||
manager.add_kc_ws_client(target_id, websocket)
|
||||
except ValueError:
|
||||
await websocket.close(code=4004, reason="Target not found")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keep alive — wait for client messages (or disconnect)
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.remove_kc_ws_client(target_id, websocket)
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/led-preview/ws")
|
||||
async def led_preview_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time LED strip preview. Sends binary RGB frames. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
|
||||
try:
|
||||
manager.add_led_preview_client(target_id, websocket)
|
||||
except ValueError:
|
||||
await websocket.close(code=4004, reason="Target not found")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.remove_led_preview_client(target_id, websocket)
|
||||
|
||||
|
||||
# ===== STATE CHANGE EVENT STREAM =====
|
||||
|
||||
|
||||
@router.websocket("/api/v1/events/ws")
|
||||
async def events_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time state change events. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
queue = manager.subscribe_events()
|
||||
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
await websocket.send_json(event)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
manager.unsubscribe_events(queue)
|
||||
|
||||
|
||||
# ===== OVERLAY VISUALIZATION =====
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/overlay/start", tags=["Visualization"])
|
||||
async def start_target_overlay(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
color_strip_store: ColorStripStore = Depends(get_color_strip_store),
|
||||
picture_source_store: PictureSourceStore = Depends(get_picture_source_store),
|
||||
):
|
||||
"""Start screen overlay visualization for a target.
|
||||
|
||||
Displays a transparent overlay on the target display showing:
|
||||
- Border sampling zones (colored rectangles)
|
||||
- LED position markers (numbered dots)
|
||||
- Pixel-to-LED mapping ranges (colored segments)
|
||||
- Calibration info text
|
||||
"""
|
||||
try:
|
||||
# Get target name from store
|
||||
target = target_store.get_target(target_id)
|
||||
if not target:
|
||||
raise ValueError(f"Target {target_id} not found")
|
||||
|
||||
# Pre-load calibration and display info from the CSS store so the overlay
|
||||
# can start even when processing is not currently running.
|
||||
calibration = None
|
||||
display_info = None
|
||||
if isinstance(target, WledOutputTarget) and target.color_strip_source_id:
|
||||
first_css_id = target.color_strip_source_id
|
||||
if first_css_id:
|
||||
try:
|
||||
css = color_strip_store.get_source(first_css_id)
|
||||
if isinstance(css, (PictureColorStripSource, AdvancedPictureColorStripSource)) and css.calibration:
|
||||
calibration = css.calibration
|
||||
# Resolve the display this CSS is capturing
|
||||
from wled_controller.api.routes.color_strip_sources import _resolve_display_index
|
||||
ps_id = getattr(css, "picture_source_id", "") or ""
|
||||
display_index = _resolve_display_index(ps_id, picture_source_store)
|
||||
displays = get_available_displays()
|
||||
if displays:
|
||||
display_index = min(display_index, len(displays) - 1)
|
||||
display_info = displays[display_index]
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not pre-load CSS calibration for overlay on {target_id}: {e}")
|
||||
|
||||
await manager.start_overlay(target_id, target.name, calibration=calibration, display_info=display_info)
|
||||
return {"status": "started", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start overlay: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/overlay/stop", tags=["Visualization"])
|
||||
async def stop_target_overlay(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop screen overlay visualization for a target."""
|
||||
try:
|
||||
await manager.stop_overlay(target_id)
|
||||
return {"status": "stopped", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop overlay: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/overlay/status", tags=["Visualization"])
|
||||
async def get_overlay_status(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Check if overlay is active for a target."""
|
||||
try:
|
||||
active = manager.is_overlay_active(target_id)
|
||||
return {"target_id": target_id, "active": active}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
341
server/src/wled_controller/api/routes/output_targets_control.py
Normal file
341
server/src/wled_controller/api/routes/output_targets_control.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""Output target routes: processing control, state, metrics, events, overlay.
|
||||
|
||||
Extracted from output_targets.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import (
|
||||
fire_entity_event,
|
||||
get_color_strip_store,
|
||||
get_device_store,
|
||||
get_output_target_store,
|
||||
get_picture_source_store,
|
||||
get_processor_manager,
|
||||
)
|
||||
from wled_controller.api.schemas.output_targets import (
|
||||
BulkTargetRequest,
|
||||
BulkTargetResponse,
|
||||
TargetMetricsResponse,
|
||||
TargetProcessingState,
|
||||
)
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.core.capture.screen_capture import get_available_displays
|
||||
from wled_controller.storage.color_strip_store import ColorStripStore
|
||||
from wled_controller.storage.color_strip_source import AdvancedPictureColorStripSource, PictureColorStripSource
|
||||
from wled_controller.storage.picture_source_store import PictureSourceStore
|
||||
from wled_controller.storage.wled_output_target import WledOutputTarget
|
||||
from wled_controller.storage.output_target_store import OutputTargetStore
|
||||
from wled_controller.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ===== BULK PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/start", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
async def bulk_start_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Start processing for multiple output targets. Returns lists of started IDs and per-ID errors."""
|
||||
started: list[str] = []
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
for target_id in body.ids:
|
||||
try:
|
||||
target_store.get_target(target_id)
|
||||
await manager.start_processing(target_id)
|
||||
started.append(target_id)
|
||||
logger.info(f"Bulk start: started processing for target {target_id}")
|
||||
except ValueError as e:
|
||||
errors[target_id] = str(e)
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
for t in target_store.get_all_targets():
|
||||
if t.id in msg:
|
||||
msg = msg.replace(t.id, f"'{t.name}'")
|
||||
errors[target_id] = msg
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk start: failed to start target {target_id}: {e}")
|
||||
errors[target_id] = str(e)
|
||||
|
||||
return BulkTargetResponse(started=started, errors=errors)
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/bulk/stop", response_model=BulkTargetResponse, tags=["Processing"])
|
||||
async def bulk_stop_processing(
|
||||
body: BulkTargetRequest,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop processing for multiple output targets. Returns lists of stopped IDs and per-ID errors."""
|
||||
stopped: list[str] = []
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
for target_id in body.ids:
|
||||
try:
|
||||
await manager.stop_processing(target_id)
|
||||
stopped.append(target_id)
|
||||
logger.info(f"Bulk stop: stopped processing for target {target_id}")
|
||||
except ValueError as e:
|
||||
errors[target_id] = str(e)
|
||||
except Exception as e:
|
||||
logger.error(f"Bulk stop: failed to stop target {target_id}: {e}")
|
||||
errors[target_id] = str(e)
|
||||
|
||||
return BulkTargetResponse(stopped=stopped, errors=errors)
|
||||
|
||||
|
||||
# ===== PROCESSING CONTROL ENDPOINTS =====
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/start", tags=["Processing"])
|
||||
async def start_processing(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Start processing for a output target."""
|
||||
try:
|
||||
# Verify target exists in store
|
||||
target_store.get_target(target_id)
|
||||
|
||||
await manager.start_processing(target_id)
|
||||
|
||||
logger.info(f"Started processing for target {target_id}")
|
||||
return {"status": "started", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
# Resolve target IDs to human-readable names in error messages
|
||||
msg = str(e)
|
||||
for t in target_store.get_all_targets():
|
||||
if t.id in msg:
|
||||
msg = msg.replace(t.id, f"'{t.name}'")
|
||||
raise HTTPException(status_code=409, detail=msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/stop", tags=["Processing"])
|
||||
async def stop_processing(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop processing for a output target."""
|
||||
try:
|
||||
await manager.stop_processing(target_id)
|
||||
|
||||
logger.info(f"Stopped processing for target {target_id}")
|
||||
return {"status": "stopped", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop processing: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ===== STATE & METRICS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/state", response_model=TargetProcessingState, tags=["Processing"])
|
||||
async def get_target_state(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get current processing state for a target."""
|
||||
try:
|
||||
state = manager.get_target_state(target_id)
|
||||
return TargetProcessingState(**state)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get target state: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/metrics", response_model=TargetMetricsResponse, tags=["Metrics"])
|
||||
async def get_target_metrics(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get processing metrics for a target."""
|
||||
try:
|
||||
metrics = manager.get_target_metrics(target_id)
|
||||
return TargetMetricsResponse(**metrics)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get target metrics: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ===== STATE CHANGE EVENT STREAM =====
|
||||
|
||||
|
||||
@router.websocket("/api/v1/events/ws")
|
||||
async def events_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time state change events. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
queue = manager.subscribe_events()
|
||||
|
||||
try:
|
||||
while True:
|
||||
event = await queue.get()
|
||||
await websocket.send_json(event)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
manager.unsubscribe_events(queue)
|
||||
|
||||
|
||||
# ===== OVERLAY VISUALIZATION =====
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/overlay/start", tags=["Visualization"])
|
||||
async def start_target_overlay(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
color_strip_store: ColorStripStore = Depends(get_color_strip_store),
|
||||
picture_source_store: PictureSourceStore = Depends(get_picture_source_store),
|
||||
):
|
||||
"""Start screen overlay visualization for a target.
|
||||
|
||||
Displays a transparent overlay on the target display showing:
|
||||
- Border sampling zones (colored rectangles)
|
||||
- LED position markers (numbered dots)
|
||||
- Pixel-to-LED mapping ranges (colored segments)
|
||||
- Calibration info text
|
||||
"""
|
||||
try:
|
||||
# Get target name from store
|
||||
target = target_store.get_target(target_id)
|
||||
if not target:
|
||||
raise ValueError(f"Target {target_id} not found")
|
||||
|
||||
# Pre-load calibration and display info from the CSS store so the overlay
|
||||
# can start even when processing is not currently running.
|
||||
calibration = None
|
||||
display_info = None
|
||||
if isinstance(target, WledOutputTarget) and target.color_strip_source_id:
|
||||
first_css_id = target.color_strip_source_id
|
||||
if first_css_id:
|
||||
try:
|
||||
css = color_strip_store.get_source(first_css_id)
|
||||
if isinstance(css, (PictureColorStripSource, AdvancedPictureColorStripSource)) and css.calibration:
|
||||
calibration = css.calibration
|
||||
# Resolve the display this CSS is capturing
|
||||
from wled_controller.api.routes.color_strip_sources import _resolve_display_index
|
||||
ps_id = getattr(css, "picture_source_id", "") or ""
|
||||
display_index = _resolve_display_index(ps_id, picture_source_store)
|
||||
displays = get_available_displays()
|
||||
if displays:
|
||||
display_index = min(display_index, len(displays) - 1)
|
||||
display_info = displays[display_index]
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not pre-load CSS calibration for overlay on {target_id}: {e}")
|
||||
|
||||
await manager.start_overlay(target_id, target.name, calibration=calibration, display_info=display_info)
|
||||
return {"status": "started", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start overlay: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/overlay/stop", tags=["Visualization"])
|
||||
async def stop_target_overlay(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Stop screen overlay visualization for a target."""
|
||||
try:
|
||||
await manager.stop_overlay(target_id)
|
||||
return {"status": "stopped", "target_id": target_id}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop overlay: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/overlay/status", tags=["Visualization"])
|
||||
async def get_overlay_status(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Check if overlay is active for a target."""
|
||||
try:
|
||||
active = manager.is_overlay_active(target_id)
|
||||
return {"target_id": target_id, "active": active}
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
# ===== LED PREVIEW WEBSOCKET =====
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/led-preview/ws")
|
||||
async def led_preview_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time LED strip preview. Sends binary RGB frames. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
|
||||
try:
|
||||
manager.add_led_preview_client(target_id, websocket)
|
||||
except ValueError:
|
||||
await websocket.close(code=4004, reason="Target not found")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.remove_led_preview_client(target_id, websocket)
|
||||
@@ -0,0 +1,540 @@
|
||||
"""Output target routes: key colors endpoints, testing, and WebSocket streams.
|
||||
|
||||
Extracted from output_targets.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from PIL import Image
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import (
|
||||
get_device_store,
|
||||
get_output_target_store,
|
||||
get_pattern_template_store,
|
||||
get_picture_source_store,
|
||||
get_pp_template_store,
|
||||
get_processor_manager,
|
||||
get_template_store,
|
||||
)
|
||||
from wled_controller.api.schemas.output_targets import (
|
||||
ExtractedColorResponse,
|
||||
KCTestRectangleResponse,
|
||||
KCTestResponse,
|
||||
KeyColorsResponse,
|
||||
)
|
||||
from wled_controller.core.capture_engines import EngineRegistry
|
||||
from wled_controller.core.filters import FilterRegistry, ImagePool
|
||||
from wled_controller.core.processing.processor_manager import ProcessorManager
|
||||
from wled_controller.core.capture.screen_capture import (
|
||||
calculate_average_color,
|
||||
calculate_dominant_color,
|
||||
calculate_median_color,
|
||||
)
|
||||
from wled_controller.storage import DeviceStore
|
||||
from wled_controller.storage.pattern_template_store import PatternTemplateStore
|
||||
from wled_controller.storage.picture_source import ScreenCapturePictureSource, StaticImagePictureSource
|
||||
from wled_controller.storage.picture_source_store import PictureSourceStore
|
||||
from wled_controller.storage.template_store import TemplateStore
|
||||
from wled_controller.storage.key_colors_output_target import KeyColorsOutputTarget
|
||||
from wled_controller.storage.output_target_store import OutputTargetStore
|
||||
from wled_controller.storage.base_store import EntityNotFoundError
|
||||
from wled_controller.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ===== KEY COLORS ENDPOINTS =====
|
||||
|
||||
@router.get("/api/v1/output-targets/{target_id}/colors", response_model=KeyColorsResponse, tags=["Key Colors"])
|
||||
async def get_target_colors(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
manager: ProcessorManager = Depends(get_processor_manager),
|
||||
):
|
||||
"""Get latest extracted colors for a key-colors target (polling)."""
|
||||
try:
|
||||
raw_colors = manager.get_kc_latest_colors(target_id)
|
||||
colors = {}
|
||||
for name, (r, g, b) in raw_colors.items():
|
||||
colors[name] = ExtractedColorResponse(
|
||||
r=r, g=g, b=b,
|
||||
hex=f"#{r:02x}{g:02x}{b:02x}",
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
return KeyColorsResponse(
|
||||
target_id=target_id,
|
||||
colors=colors,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/api/v1/output-targets/{target_id}/test", response_model=KCTestResponse, tags=["Key Colors"])
|
||||
async def test_kc_target(
|
||||
target_id: str,
|
||||
_auth: AuthRequired,
|
||||
target_store: OutputTargetStore = Depends(get_output_target_store),
|
||||
source_store: PictureSourceStore = Depends(get_picture_source_store),
|
||||
template_store: TemplateStore = Depends(get_template_store),
|
||||
pattern_store: PatternTemplateStore = Depends(get_pattern_template_store),
|
||||
processor_manager: ProcessorManager = Depends(get_processor_manager),
|
||||
device_store: DeviceStore = Depends(get_device_store),
|
||||
pp_template_store=Depends(get_pp_template_store),
|
||||
):
|
||||
"""Test a key-colors target: capture a frame, extract colors from each rectangle."""
|
||||
import httpx
|
||||
|
||||
stream = None
|
||||
try:
|
||||
# 1. Load and validate KC target
|
||||
try:
|
||||
target = target_store.get_target(target_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
raise HTTPException(status_code=400, detail="Target is not a key_colors target")
|
||||
|
||||
settings = target.settings
|
||||
|
||||
# 2. Resolve pattern template
|
||||
if not settings.pattern_template_id:
|
||||
raise HTTPException(status_code=400, detail="No pattern template configured")
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
raise HTTPException(status_code=400, detail="Pattern template has no rectangles")
|
||||
|
||||
# 3. Resolve picture source and capture a frame
|
||||
if not target.picture_source_id:
|
||||
raise HTTPException(status_code=400, detail="No picture source configured")
|
||||
|
||||
try:
|
||||
chain = source_store.resolve_stream_chain(target.picture_source_id)
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
if isinstance(raw_stream, StaticImagePictureSource):
|
||||
source = raw_stream.image_source
|
||||
if source.startswith(("http://", "https://")):
|
||||
async with httpx.AsyncClient(timeout=15, follow_redirects=True) as client:
|
||||
resp = await client.get(source)
|
||||
resp.raise_for_status()
|
||||
pil_image = Image.open(io.BytesIO(resp.content)).convert("RGB")
|
||||
else:
|
||||
from pathlib import Path
|
||||
path = Path(source)
|
||||
if not path.exists():
|
||||
raise HTTPException(status_code=400, detail=f"Image file not found: {source}")
|
||||
pil_image = Image.open(path).convert("RGB")
|
||||
|
||||
elif isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
try:
|
||||
capture_template = template_store.get_template(raw_stream.capture_template_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Capture template not found: {raw_stream.capture_template_id}",
|
||||
)
|
||||
|
||||
display_index = raw_stream.display_index
|
||||
|
||||
if capture_template.engine_type not in EngineRegistry.get_available_engines():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Engine '{capture_template.engine_type}' is not available on this system",
|
||||
)
|
||||
|
||||
locked_device_id = processor_manager.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Display {display_index} is currently being captured by device '{device_name}'. "
|
||||
f"Please stop the device processing before testing.",
|
||||
)
|
||||
|
||||
stream = EngineRegistry.create_stream(
|
||||
capture_template.engine_type, display_index, capture_template.engine_config
|
||||
)
|
||||
stream.initialize()
|
||||
|
||||
screen_capture = stream.capture_frame()
|
||||
if screen_capture is None:
|
||||
raise RuntimeError("No frame captured")
|
||||
|
||||
if isinstance(screen_capture.image, np.ndarray):
|
||||
pil_image = Image.fromarray(screen_capture.image)
|
||||
else:
|
||||
raise ValueError("Unexpected image format from engine")
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Unsupported picture source type")
|
||||
|
||||
# 3b. Apply postprocessing filters (if the picture source has a filter chain)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store:
|
||||
img_array = np.array(pil_image)
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store.get_template(pp_id)
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: PP template {pp_id} not found, skipping")
|
||||
continue
|
||||
flat_filters = pp_template_store.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(img_array, image_pool)
|
||||
if result is not None:
|
||||
img_array = result
|
||||
except ValueError:
|
||||
logger.warning(f"KC test: unknown filter '{fi.filter_id}', skipping")
|
||||
pil_image = Image.fromarray(img_array)
|
||||
|
||||
# 4. Extract colors from each rectangle
|
||||
img_array = np.array(pil_image)
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append(KCTestRectangleResponse(
|
||||
name=rect.name,
|
||||
x=rect.x,
|
||||
y=rect.y,
|
||||
width=rect.width,
|
||||
height=rect.height,
|
||||
color=ExtractedColorResponse(r=r, g=g, b=b, hex=f"#{r:02x}{g:02x}{b:02x}"),
|
||||
))
|
||||
|
||||
# 5. Encode frame as base64 JPEG
|
||||
full_buffer = io.BytesIO()
|
||||
pil_image.save(full_buffer, format='JPEG', quality=90)
|
||||
full_buffer.seek(0)
|
||||
full_b64 = base64.b64encode(full_buffer.getvalue()).decode('utf-8')
|
||||
image_data_uri = f"data:image/jpeg;base64,{full_b64}"
|
||||
|
||||
return KCTestResponse(
|
||||
image=image_data_uri,
|
||||
rectangles=result_rects,
|
||||
interpolation_mode=settings.interpolation_mode,
|
||||
pattern_template_name=pattern_tmpl.name,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except EntityNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=f"Capture error: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test KC target: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
if stream:
|
||||
try:
|
||||
stream.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up test stream: {e}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/test/ws")
|
||||
async def test_kc_target_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
fps: int = Query(3),
|
||||
preview_width: int = Query(480),
|
||||
):
|
||||
"""WebSocket for real-time KC target test preview. Auth via ?token=<api_key>.
|
||||
|
||||
Streams JSON frames: {"type": "frame", "image": "data:image/jpeg;base64,...",
|
||||
"rectangles": [...], "pattern_template_name": "...", "interpolation_mode": "..."}
|
||||
"""
|
||||
import json as _json
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
# Load stores
|
||||
target_store_inst: OutputTargetStore = get_output_target_store()
|
||||
source_store_inst: PictureSourceStore = get_picture_source_store()
|
||||
template_store_inst: TemplateStore = get_template_store()
|
||||
pattern_store_inst: PatternTemplateStore = get_pattern_template_store()
|
||||
processor_manager_inst: ProcessorManager = get_processor_manager()
|
||||
device_store_inst: DeviceStore = get_device_store()
|
||||
pp_template_store_inst = get_pp_template_store()
|
||||
|
||||
# Validate target
|
||||
try:
|
||||
target = target_store_inst.get_target(target_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4004, reason=str(e))
|
||||
return
|
||||
|
||||
if not isinstance(target, KeyColorsOutputTarget):
|
||||
await websocket.close(code=4003, reason="Target is not a key_colors target")
|
||||
return
|
||||
|
||||
settings = target.settings
|
||||
|
||||
if not settings.pattern_template_id:
|
||||
await websocket.close(code=4003, reason="No pattern template configured")
|
||||
return
|
||||
|
||||
try:
|
||||
pattern_tmpl = pattern_store_inst.get_template(settings.pattern_template_id)
|
||||
except ValueError:
|
||||
await websocket.close(code=4003, reason=f"Pattern template not found: {settings.pattern_template_id}")
|
||||
return
|
||||
|
||||
rectangles = pattern_tmpl.rectangles
|
||||
if not rectangles:
|
||||
await websocket.close(code=4003, reason="Pattern template has no rectangles")
|
||||
return
|
||||
|
||||
if not target.picture_source_id:
|
||||
await websocket.close(code=4003, reason="No picture source configured")
|
||||
return
|
||||
|
||||
try:
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
except ValueError as e:
|
||||
await websocket.close(code=4003, reason=str(e))
|
||||
return
|
||||
|
||||
raw_stream = chain["raw_stream"]
|
||||
|
||||
# For screen capture sources, check display lock
|
||||
if isinstance(raw_stream, ScreenCapturePictureSource):
|
||||
display_index = raw_stream.display_index
|
||||
locked_device_id = processor_manager_inst.get_display_lock_info(display_index)
|
||||
if locked_device_id:
|
||||
try:
|
||||
device = device_store_inst.get_device(locked_device_id)
|
||||
device_name = device.name
|
||||
except Exception:
|
||||
device_name = locked_device_id
|
||||
await websocket.close(
|
||||
code=4003,
|
||||
reason=f"Display {display_index} is captured by '{device_name}'. Stop processing first.",
|
||||
)
|
||||
return
|
||||
|
||||
fps = max(1, min(30, fps))
|
||||
preview_width = max(120, min(1920, preview_width))
|
||||
frame_interval = 1.0 / fps
|
||||
|
||||
calc_fns = {
|
||||
"average": calculate_average_color,
|
||||
"median": calculate_median_color,
|
||||
"dominant": calculate_dominant_color,
|
||||
}
|
||||
calc_fn = calc_fns.get(settings.interpolation_mode, calculate_average_color)
|
||||
|
||||
await websocket.accept()
|
||||
logger.info(f"KC test WS connected for {target_id} (fps={fps})")
|
||||
|
||||
# Use the shared LiveStreamManager so we share the capture stream with
|
||||
# running LED targets instead of creating a competing DXGI duplicator.
|
||||
live_stream_mgr = processor_manager_inst._live_stream_manager
|
||||
live_stream = None
|
||||
|
||||
try:
|
||||
live_stream = await asyncio.to_thread(
|
||||
live_stream_mgr.acquire, target.picture_source_id
|
||||
)
|
||||
logger.info(f"KC test WS acquired shared live stream for {target.picture_source_id}")
|
||||
|
||||
prev_frame_ref = None
|
||||
|
||||
while True:
|
||||
loop_start = time.monotonic()
|
||||
|
||||
try:
|
||||
capture = await asyncio.to_thread(live_stream.get_latest_frame)
|
||||
|
||||
if capture is None or capture.image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Skip if same frame object (no new capture yet)
|
||||
if capture is prev_frame_ref:
|
||||
await asyncio.sleep(frame_interval * 0.5)
|
||||
continue
|
||||
prev_frame_ref = capture
|
||||
|
||||
pil_image = Image.fromarray(capture.image) if isinstance(capture.image, np.ndarray) else None
|
||||
if pil_image is None:
|
||||
await asyncio.sleep(frame_interval)
|
||||
continue
|
||||
|
||||
# Apply postprocessing (if the source chain has PP templates)
|
||||
chain = source_store_inst.resolve_stream_chain(target.picture_source_id)
|
||||
pp_template_ids = chain.get("postprocessing_template_ids", [])
|
||||
if pp_template_ids and pp_template_store_inst:
|
||||
img_array = np.array(pil_image)
|
||||
image_pool = ImagePool()
|
||||
for pp_id in pp_template_ids:
|
||||
try:
|
||||
pp_template = pp_template_store_inst.get_template(pp_id)
|
||||
except ValueError:
|
||||
continue
|
||||
flat_filters = pp_template_store_inst.resolve_filter_instances(pp_template.filters)
|
||||
for fi in flat_filters:
|
||||
try:
|
||||
f = FilterRegistry.create_instance(fi.filter_id, fi.options)
|
||||
result = f.process_image(img_array, image_pool)
|
||||
if result is not None:
|
||||
img_array = result
|
||||
except ValueError:
|
||||
pass
|
||||
pil_image = Image.fromarray(img_array)
|
||||
|
||||
# Extract colors
|
||||
img_array = np.array(pil_image)
|
||||
h, w = img_array.shape[:2]
|
||||
|
||||
result_rects = []
|
||||
for rect in rectangles:
|
||||
px_x = max(0, int(rect.x * w))
|
||||
px_y = max(0, int(rect.y * h))
|
||||
px_w = max(1, int(rect.width * w))
|
||||
px_h = max(1, int(rect.height * h))
|
||||
px_x = min(px_x, w - 1)
|
||||
px_y = min(px_y, h - 1)
|
||||
px_w = min(px_w, w - px_x)
|
||||
px_h = min(px_h, h - px_y)
|
||||
|
||||
sub_img = img_array[px_y:px_y + px_h, px_x:px_x + px_w]
|
||||
r, g, b = calc_fn(sub_img)
|
||||
|
||||
result_rects.append({
|
||||
"name": rect.name,
|
||||
"x": rect.x,
|
||||
"y": rect.y,
|
||||
"width": rect.width,
|
||||
"height": rect.height,
|
||||
"color": {"r": r, "g": g, "b": b, "hex": f"#{r:02x}{g:02x}{b:02x}"},
|
||||
})
|
||||
|
||||
# Encode frame as JPEG
|
||||
if preview_width and pil_image.width > preview_width:
|
||||
ratio = preview_width / pil_image.width
|
||||
thumb = pil_image.resize((preview_width, int(pil_image.height * ratio)), Image.LANCZOS)
|
||||
else:
|
||||
thumb = pil_image
|
||||
buf = io.BytesIO()
|
||||
thumb.save(buf, format="JPEG", quality=85)
|
||||
b64 = base64.b64encode(buf.getvalue()).decode()
|
||||
|
||||
await websocket.send_text(_json.dumps({
|
||||
"type": "frame",
|
||||
"image": f"data:image/jpeg;base64,{b64}",
|
||||
"rectangles": result_rects,
|
||||
"pattern_template_name": pattern_tmpl.name,
|
||||
"interpolation_mode": settings.interpolation_mode,
|
||||
}))
|
||||
|
||||
except (WebSocketDisconnect, Exception) as inner_e:
|
||||
if isinstance(inner_e, WebSocketDisconnect):
|
||||
raise
|
||||
logger.warning(f"KC test WS frame error for {target_id}: {inner_e}")
|
||||
|
||||
elapsed = time.monotonic() - loop_start
|
||||
sleep_time = frame_interval - elapsed
|
||||
if sleep_time > 0:
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"KC test WS disconnected for {target_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"KC test WS error for {target_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
if live_stream is not None:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
live_stream_mgr.release, target.picture_source_id
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"KC test WS closed for {target_id}")
|
||||
|
||||
|
||||
@router.websocket("/api/v1/output-targets/{target_id}/ws")
|
||||
async def target_colors_ws(
|
||||
websocket: WebSocket,
|
||||
target_id: str,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket for real-time key color updates. Auth via ?token=<api_key>."""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
manager = get_processor_manager()
|
||||
|
||||
try:
|
||||
manager.add_kc_ws_client(target_id, websocket)
|
||||
except ValueError:
|
||||
await websocket.close(code=4004, reason="Target not found")
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keep alive — wait for client messages (or disconnect)
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.remove_kc_ws_client(target_id, websocket)
|
||||
@@ -584,7 +584,7 @@ async def test_picture_source_ws(
|
||||
preview_width: int = Query(0),
|
||||
):
|
||||
"""WebSocket for picture source test with intermediate frame previews."""
|
||||
from wled_controller.api.routes._test_helpers import (
|
||||
from wled_controller.api.routes._preview_helpers import (
|
||||
authenticate_ws_token,
|
||||
stream_capture_test,
|
||||
)
|
||||
|
||||
@@ -365,7 +365,7 @@ async def test_pp_template_ws(
|
||||
preview_width: int = Query(0),
|
||||
):
|
||||
"""WebSocket for PP template test with intermediate frame previews."""
|
||||
from wled_controller.api.routes._test_helpers import (
|
||||
from wled_controller.api.routes._preview_helpers import (
|
||||
authenticate_ws_token,
|
||||
stream_capture_test,
|
||||
)
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
"""System routes: health, version, displays, performance, backup/restore, ADB."""
|
||||
"""System routes: health, version, displays, performance, tags, api-keys.
|
||||
|
||||
Backup/restore and settings routes are in backup.py and system_settings.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from wled_controller import __version__
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.dependencies import (
|
||||
get_auto_backup_engine,
|
||||
get_audio_source_store,
|
||||
get_audio_template_store,
|
||||
get_automation_store,
|
||||
@@ -37,29 +32,22 @@ from wled_controller.api.dependencies import (
|
||||
get_value_source_store,
|
||||
)
|
||||
from wled_controller.api.schemas.system import (
|
||||
AutoBackupSettings,
|
||||
AutoBackupStatusResponse,
|
||||
BackupFileInfo,
|
||||
BackupListResponse,
|
||||
DisplayInfo,
|
||||
DisplayListResponse,
|
||||
ExternalUrlRequest,
|
||||
ExternalUrlResponse,
|
||||
GpuInfo,
|
||||
HealthResponse,
|
||||
LogLevelRequest,
|
||||
LogLevelResponse,
|
||||
MQTTSettingsRequest,
|
||||
MQTTSettingsResponse,
|
||||
PerformanceResponse,
|
||||
ProcessListResponse,
|
||||
RestoreResponse,
|
||||
VersionResponse,
|
||||
)
|
||||
from wled_controller.core.backup.auto_backup import AutoBackupEngine
|
||||
from wled_controller.config import get_config, is_demo_mode
|
||||
from wled_controller.core.capture.screen_capture import get_available_displays
|
||||
from wled_controller.utils import atomic_write_json, get_logger
|
||||
from wled_controller.utils import get_logger
|
||||
from wled_controller.storage.base_store import EntityNotFoundError
|
||||
|
||||
# Re-export STORE_MAP and load_external_url so existing callers still work
|
||||
from wled_controller.api.routes.backup import STORE_MAP # noqa: F401
|
||||
from wled_controller.api.routes.system_settings import load_external_url # noqa: F401
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -68,7 +56,6 @@ psutil.cpu_percent(interval=None)
|
||||
|
||||
# GPU monitoring (initialized once in utils.gpu, shared with metrics_history)
|
||||
from wled_controller.utils.gpu import nvml_available as _nvml_available, nvml as _nvml, nvml_handle as _nvml_handle
|
||||
from wled_controller.storage.base_store import EntityNotFoundError
|
||||
|
||||
|
||||
def _get_cpu_name() -> str | None:
|
||||
@@ -113,7 +100,7 @@ async def health_check():
|
||||
|
||||
Returns basic health information including status, version, and timestamp.
|
||||
"""
|
||||
logger.info("Health check requested")
|
||||
logger.debug("Health check requested")
|
||||
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
@@ -129,7 +116,7 @@ async def get_version():
|
||||
|
||||
Returns application version, Python version, and API version.
|
||||
"""
|
||||
logger.info("Version info requested")
|
||||
logger.debug("Version info requested")
|
||||
|
||||
return VersionResponse(
|
||||
version=__version__,
|
||||
@@ -308,52 +295,6 @@ async def get_metrics_history(
|
||||
return manager.metrics_history.get_history()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration backup / restore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Mapping: logical store name → StorageConfig attribute name
|
||||
STORE_MAP = {
|
||||
"devices": "devices_file",
|
||||
"capture_templates": "templates_file",
|
||||
"postprocessing_templates": "postprocessing_templates_file",
|
||||
"picture_sources": "picture_sources_file",
|
||||
"output_targets": "output_targets_file",
|
||||
"pattern_templates": "pattern_templates_file",
|
||||
"color_strip_sources": "color_strip_sources_file",
|
||||
"audio_sources": "audio_sources_file",
|
||||
"audio_templates": "audio_templates_file",
|
||||
"value_sources": "value_sources_file",
|
||||
"sync_clocks": "sync_clocks_file",
|
||||
"color_strip_processing_templates": "color_strip_processing_templates_file",
|
||||
"automations": "automations_file",
|
||||
"scene_presets": "scene_presets_file",
|
||||
}
|
||||
|
||||
_SERVER_DIR = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
def _schedule_restart() -> None:
|
||||
"""Spawn a restart script after a short delay so the HTTP response completes."""
|
||||
|
||||
def _restart():
|
||||
import time
|
||||
time.sleep(1)
|
||||
if sys.platform == "win32":
|
||||
subprocess.Popen(
|
||||
["powershell", "-ExecutionPolicy", "Bypass", "-File",
|
||||
str(_SERVER_DIR / "restart.ps1")],
|
||||
creationflags=subprocess.DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP,
|
||||
)
|
||||
else:
|
||||
subprocess.Popen(
|
||||
["bash", str(_SERVER_DIR / "restart.sh")],
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
threading.Thread(target=_restart, daemon=True).start()
|
||||
|
||||
|
||||
@router.get("/api/v1/system/api-keys", tags=["System"])
|
||||
def list_api_keys(_: AuthRequired):
|
||||
"""List API key labels (read-only; keys are defined in the YAML config file)."""
|
||||
@@ -363,632 +304,3 @@ def list_api_keys(_: AuthRequired):
|
||||
for label, key in config.auth.api_keys.items()
|
||||
]
|
||||
return {"keys": keys, "count": len(keys)}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/export/{store_key}", tags=["System"])
|
||||
def export_store(store_key: str, _: AuthRequired):
|
||||
"""Download a single entity store as a JSON file."""
|
||||
if store_key not in STORE_MAP:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unknown store '{store_key}'. Valid keys: {sorted(STORE_MAP.keys())}",
|
||||
)
|
||||
config = get_config()
|
||||
file_path = Path(getattr(config.storage, STORE_MAP[store_key]))
|
||||
if file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
else:
|
||||
data = {}
|
||||
|
||||
export = {
|
||||
"meta": {
|
||||
"format": "ledgrab-partial-export",
|
||||
"format_version": 1,
|
||||
"store_key": store_key,
|
||||
"app_version": __version__,
|
||||
"created_at": datetime.now(timezone.utc).isoformat() + "Z",
|
||||
},
|
||||
"store": data,
|
||||
}
|
||||
content = json.dumps(export, indent=2, ensure_ascii=False)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S")
|
||||
filename = f"ledgrab-{store_key}-{timestamp}.json"
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content.encode("utf-8")),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/import/{store_key}", tags=["System"])
|
||||
async def import_store(
|
||||
store_key: str,
|
||||
_: AuthRequired,
|
||||
file: UploadFile = File(...),
|
||||
merge: bool = Query(False, description="Merge into existing data instead of replacing"),
|
||||
):
|
||||
"""Upload a partial export file to replace or merge one entity store. Triggers server restart."""
|
||||
if store_key not in STORE_MAP:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Unknown store '{store_key}'. Valid keys: {sorted(STORE_MAP.keys())}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw = await file.read()
|
||||
if len(raw) > 10 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail="File too large (max 10 MB)")
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON: {e}")
|
||||
|
||||
# Support both full-backup format and partial-export format
|
||||
if "stores" in payload and isinstance(payload.get("meta"), dict):
|
||||
# Full backup: extract the specific store
|
||||
if payload["meta"].get("format") not in ("ledgrab-backup",):
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup or partial export file")
|
||||
stores = payload.get("stores", {})
|
||||
if store_key not in stores:
|
||||
raise HTTPException(status_code=400, detail=f"Backup does not contain store '{store_key}'")
|
||||
incoming = stores[store_key]
|
||||
elif isinstance(payload.get("meta"), dict) and payload["meta"].get("format") == "ledgrab-partial-export":
|
||||
# Partial export format
|
||||
if payload["meta"].get("store_key") != store_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"File is for store '{payload['meta']['store_key']}', not '{store_key}'",
|
||||
)
|
||||
incoming = payload.get("store", {})
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup or partial export file")
|
||||
|
||||
if not isinstance(incoming, dict):
|
||||
raise HTTPException(status_code=400, detail="Store data must be a JSON object")
|
||||
|
||||
config = get_config()
|
||||
file_path = Path(getattr(config.storage, STORE_MAP[store_key]))
|
||||
|
||||
def _write():
|
||||
if merge and file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
existing = json.load(f)
|
||||
if isinstance(existing, dict):
|
||||
existing.update(incoming)
|
||||
atomic_write_json(file_path, existing)
|
||||
return len(existing)
|
||||
atomic_write_json(file_path, incoming)
|
||||
return len(incoming)
|
||||
|
||||
count = await asyncio.to_thread(_write)
|
||||
logger.info(f"Imported store '{store_key}' ({count} entries, merge={merge}). Scheduling restart...")
|
||||
_schedule_restart()
|
||||
return {
|
||||
"status": "imported",
|
||||
"store_key": store_key,
|
||||
"entries": count,
|
||||
"merge": merge,
|
||||
"restart_scheduled": True,
|
||||
"message": f"Imported {count} entries for '{store_key}'. Server restarting...",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/backup", tags=["System"])
|
||||
def backup_config(_: AuthRequired):
|
||||
"""Download all configuration as a single JSON backup file."""
|
||||
config = get_config()
|
||||
stores = {}
|
||||
for store_key, config_attr in STORE_MAP.items():
|
||||
file_path = Path(getattr(config.storage, config_attr))
|
||||
if file_path.exists():
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
stores[store_key] = json.load(f)
|
||||
else:
|
||||
stores[store_key] = {}
|
||||
|
||||
backup = {
|
||||
"meta": {
|
||||
"format": "ledgrab-backup",
|
||||
"format_version": 1,
|
||||
"app_version": __version__,
|
||||
"created_at": datetime.now(timezone.utc).isoformat() + "Z",
|
||||
"store_count": len(stores),
|
||||
},
|
||||
"stores": stores,
|
||||
}
|
||||
|
||||
content = json.dumps(backup, indent=2, ensure_ascii=False)
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H%M%S")
|
||||
filename = f"ledgrab-backup-{timestamp}.json"
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content.encode("utf-8")),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/restart", tags=["System"])
|
||||
def restart_server(_: AuthRequired):
|
||||
"""Schedule a server restart and return immediately."""
|
||||
_schedule_restart()
|
||||
return {"status": "restarting"}
|
||||
|
||||
|
||||
@router.post("/api/v1/system/restore", response_model=RestoreResponse, tags=["System"])
|
||||
async def restore_config(
|
||||
_: AuthRequired,
|
||||
file: UploadFile = File(...),
|
||||
):
|
||||
"""Upload a backup file to restore all configuration. Triggers server restart."""
|
||||
# Read and parse
|
||||
try:
|
||||
raw = await file.read()
|
||||
if len(raw) > 10 * 1024 * 1024: # 10 MB limit
|
||||
raise HTTPException(status_code=400, detail="Backup file too large (max 10 MB)")
|
||||
backup = json.loads(raw)
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid JSON file: {e}")
|
||||
|
||||
# Validate envelope
|
||||
meta = backup.get("meta")
|
||||
if not isinstance(meta, dict) or meta.get("format") != "ledgrab-backup":
|
||||
raise HTTPException(status_code=400, detail="Not a valid LED Grab backup file")
|
||||
|
||||
fmt_version = meta.get("format_version", 0)
|
||||
if fmt_version > 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Backup format version {fmt_version} is not supported by this server version",
|
||||
)
|
||||
|
||||
stores = backup.get("stores")
|
||||
if not isinstance(stores, dict):
|
||||
raise HTTPException(status_code=400, detail="Backup file missing 'stores' section")
|
||||
|
||||
known_keys = set(STORE_MAP.keys())
|
||||
present_keys = known_keys & set(stores.keys())
|
||||
if not present_keys:
|
||||
raise HTTPException(status_code=400, detail="Backup contains no recognized store data")
|
||||
|
||||
for key in present_keys:
|
||||
if not isinstance(stores[key], dict):
|
||||
raise HTTPException(status_code=400, detail=f"Store '{key}' in backup is not a valid JSON object")
|
||||
|
||||
# Write store files atomically (in thread to avoid blocking event loop)
|
||||
config = get_config()
|
||||
|
||||
def _write_stores():
|
||||
count = 0
|
||||
for store_key, config_attr in STORE_MAP.items():
|
||||
if store_key in stores:
|
||||
file_path = Path(getattr(config.storage, config_attr))
|
||||
atomic_write_json(file_path, stores[store_key])
|
||||
count += 1
|
||||
logger.info(f"Restored store: {store_key} -> {file_path}")
|
||||
return count
|
||||
|
||||
written = await asyncio.to_thread(_write_stores)
|
||||
|
||||
logger.info(f"Restore complete: {written}/{len(STORE_MAP)} stores written. Scheduling restart...")
|
||||
_schedule_restart()
|
||||
|
||||
missing = known_keys - present_keys
|
||||
return RestoreResponse(
|
||||
status="restored",
|
||||
stores_written=written,
|
||||
stores_total=len(STORE_MAP),
|
||||
missing_stores=sorted(missing) if missing else [],
|
||||
restart_scheduled=True,
|
||||
message=f"Restored {written} stores. Server restarting...",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auto-backup settings & saved backups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/auto-backup/settings",
|
||||
response_model=AutoBackupStatusResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_auto_backup_settings(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Get auto-backup settings and status."""
|
||||
return engine.get_settings()
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/auto-backup/settings",
|
||||
response_model=AutoBackupStatusResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_auto_backup_settings(
|
||||
_: AuthRequired,
|
||||
body: AutoBackupSettings,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Update auto-backup settings (enable/disable, interval, max backups)."""
|
||||
return await engine.update_settings(
|
||||
enabled=body.enabled,
|
||||
interval_hours=body.interval_hours,
|
||||
max_backups=body.max_backups,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/system/auto-backup/trigger", tags=["System"])
|
||||
async def trigger_backup(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Manually trigger a backup now."""
|
||||
backup = await engine.trigger_backup()
|
||||
return {"status": "ok", "backup": backup}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/backups",
|
||||
response_model=BackupListResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def list_backups(
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""List all saved backup files."""
|
||||
backups = engine.list_backups()
|
||||
return BackupListResponse(
|
||||
backups=[BackupFileInfo(**b) for b in backups],
|
||||
count=len(backups),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/api/v1/system/backups/{filename}", tags=["System"])
|
||||
def download_saved_backup(
|
||||
filename: str,
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Download a specific saved backup file."""
|
||||
try:
|
||||
path = engine.get_backup_path(filename)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
content = path.read_bytes()
|
||||
return StreamingResponse(
|
||||
io.BytesIO(content),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/api/v1/system/backups/{filename}", tags=["System"])
|
||||
async def delete_saved_backup(
|
||||
filename: str,
|
||||
_: AuthRequired,
|
||||
engine: AutoBackupEngine = Depends(get_auto_backup_engine),
|
||||
):
|
||||
"""Delete a specific saved backup file."""
|
||||
try:
|
||||
engine.delete_backup(filename)
|
||||
except (ValueError, FileNotFoundError) as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return {"status": "deleted", "filename": filename}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MQTT settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MQTT_SETTINGS_FILE: Path | None = None
|
||||
|
||||
|
||||
def _get_mqtt_settings_path() -> Path:
|
||||
global _MQTT_SETTINGS_FILE
|
||||
if _MQTT_SETTINGS_FILE is None:
|
||||
cfg = get_config()
|
||||
# Derive the data directory from any known storage file path
|
||||
data_dir = Path(cfg.storage.devices_file).parent
|
||||
_MQTT_SETTINGS_FILE = data_dir / "mqtt_settings.json"
|
||||
return _MQTT_SETTINGS_FILE
|
||||
|
||||
|
||||
def _load_mqtt_settings() -> dict:
|
||||
"""Load MQTT settings: YAML config defaults overridden by JSON overrides file."""
|
||||
cfg = get_config()
|
||||
defaults = {
|
||||
"enabled": cfg.mqtt.enabled,
|
||||
"broker_host": cfg.mqtt.broker_host,
|
||||
"broker_port": cfg.mqtt.broker_port,
|
||||
"username": cfg.mqtt.username,
|
||||
"password": cfg.mqtt.password,
|
||||
"client_id": cfg.mqtt.client_id,
|
||||
"base_topic": cfg.mqtt.base_topic,
|
||||
}
|
||||
path = _get_mqtt_settings_path()
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
overrides = json.load(f)
|
||||
defaults.update(overrides)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MQTT settings override file: {e}")
|
||||
return defaults
|
||||
|
||||
|
||||
def _save_mqtt_settings(settings: dict) -> None:
|
||||
"""Persist MQTT settings to the JSON override file."""
|
||||
from wled_controller.utils import atomic_write_json
|
||||
atomic_write_json(_get_mqtt_settings_path(), settings)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/mqtt/settings",
|
||||
response_model=MQTTSettingsResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_mqtt_settings(_: AuthRequired):
|
||||
"""Get current MQTT broker settings. Password is masked."""
|
||||
s = _load_mqtt_settings()
|
||||
return MQTTSettingsResponse(
|
||||
enabled=s["enabled"],
|
||||
broker_host=s["broker_host"],
|
||||
broker_port=s["broker_port"],
|
||||
username=s["username"],
|
||||
password_set=bool(s.get("password")),
|
||||
client_id=s["client_id"],
|
||||
base_topic=s["base_topic"],
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/mqtt/settings",
|
||||
response_model=MQTTSettingsResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_mqtt_settings(_: AuthRequired, body: MQTTSettingsRequest):
|
||||
"""Update MQTT broker settings. If password is empty string, the existing password is preserved."""
|
||||
current = _load_mqtt_settings()
|
||||
|
||||
# If caller sends an empty password, keep the existing one
|
||||
password = body.password if body.password else current.get("password", "")
|
||||
|
||||
new_settings = {
|
||||
"enabled": body.enabled,
|
||||
"broker_host": body.broker_host,
|
||||
"broker_port": body.broker_port,
|
||||
"username": body.username,
|
||||
"password": password,
|
||||
"client_id": body.client_id,
|
||||
"base_topic": body.base_topic,
|
||||
}
|
||||
_save_mqtt_settings(new_settings)
|
||||
logger.info("MQTT settings updated")
|
||||
|
||||
return MQTTSettingsResponse(
|
||||
enabled=new_settings["enabled"],
|
||||
broker_host=new_settings["broker_host"],
|
||||
broker_port=new_settings["broker_port"],
|
||||
username=new_settings["username"],
|
||||
password_set=bool(new_settings["password"]),
|
||||
client_id=new_settings["client_id"],
|
||||
base_topic=new_settings["base_topic"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External URL setting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXTERNAL_URL_FILE: Path | None = None
|
||||
|
||||
|
||||
def _get_external_url_path() -> Path:
|
||||
global _EXTERNAL_URL_FILE
|
||||
if _EXTERNAL_URL_FILE is None:
|
||||
cfg = get_config()
|
||||
data_dir = Path(cfg.storage.devices_file).parent
|
||||
_EXTERNAL_URL_FILE = data_dir / "external_url.json"
|
||||
return _EXTERNAL_URL_FILE
|
||||
|
||||
|
||||
def load_external_url() -> str:
|
||||
"""Load the external URL setting. Returns empty string if not set."""
|
||||
path = _get_external_url_path()
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("external_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _save_external_url(url: str) -> None:
|
||||
from wled_controller.utils import atomic_write_json
|
||||
atomic_write_json(_get_external_url_path(), {"external_url": url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/external-url",
|
||||
response_model=ExternalUrlResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_external_url(_: AuthRequired):
|
||||
"""Get the configured external base URL."""
|
||||
return ExternalUrlResponse(external_url=load_external_url())
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/external-url",
|
||||
response_model=ExternalUrlResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_external_url(_: AuthRequired, body: ExternalUrlRequest):
|
||||
"""Set the external base URL used in webhook URLs and other user-visible URLs."""
|
||||
url = body.external_url.strip().rstrip("/")
|
||||
_save_external_url(url)
|
||||
logger.info("External URL updated: %s", url or "(cleared)")
|
||||
return ExternalUrlResponse(external_url=url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live log viewer WebSocket
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.websocket("/api/v1/system/logs/ws")
|
||||
async def logs_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket that streams server log lines in real time.
|
||||
|
||||
Auth via ``?token=<api_key>``. On connect, sends the last ~500 buffered
|
||||
lines as individual text messages, then pushes new lines as they appear.
|
||||
"""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
from wled_controller.utils import log_broadcaster
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Ensure the broadcaster knows the event loop (may be first connection)
|
||||
log_broadcaster.ensure_loop()
|
||||
|
||||
# Subscribe *before* reading the backlog so no lines slip through
|
||||
queue = log_broadcaster.subscribe()
|
||||
|
||||
try:
|
||||
# Send backlog first
|
||||
for line in log_broadcaster.get_backlog():
|
||||
await websocket.send_text(line)
|
||||
|
||||
# Stream new lines
|
||||
while True:
|
||||
try:
|
||||
line = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||
await websocket.send_text(line)
|
||||
except asyncio.TimeoutError:
|
||||
# Send a keepalive ping so the connection stays alive
|
||||
try:
|
||||
await websocket.send_text("")
|
||||
except Exception:
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
log_broadcaster.unsubscribe(queue)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ADB helpers (for Android / scrcpy engine)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class AdbConnectRequest(BaseModel):
|
||||
address: str
|
||||
|
||||
|
||||
def _get_adb_path() -> str:
|
||||
"""Get the adb binary path from the scrcpy engine's resolver."""
|
||||
from wled_controller.core.capture_engines.scrcpy_engine import _get_adb
|
||||
return _get_adb()
|
||||
|
||||
|
||||
@router.post("/api/v1/adb/connect", tags=["ADB"])
|
||||
async def adb_connect(_: AuthRequired, request: AdbConnectRequest):
|
||||
"""Connect to a WiFi ADB device by IP address.
|
||||
|
||||
Appends ``:5555`` if no port is specified.
|
||||
"""
|
||||
address = request.address.strip()
|
||||
if not address:
|
||||
raise HTTPException(status_code=400, detail="Address is required")
|
||||
if ":" not in address:
|
||||
address = f"{address}:5555"
|
||||
|
||||
adb = _get_adb_path()
|
||||
logger.info(f"Connecting ADB device: {address}")
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
adb, "connect", address,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||
output = (stdout.decode() + stderr.decode()).strip()
|
||||
if "connected" in output.lower():
|
||||
return {"status": "connected", "address": address, "message": output}
|
||||
raise HTTPException(status_code=400, detail=output or "Connection failed")
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="adb not found on PATH. Install Android SDK Platform-Tools.",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=504, detail="ADB connect timed out")
|
||||
|
||||
|
||||
@router.post("/api/v1/adb/disconnect", tags=["ADB"])
|
||||
async def adb_disconnect(_: AuthRequired, request: AdbConnectRequest):
|
||||
"""Disconnect a WiFi ADB device."""
|
||||
address = request.address.strip()
|
||||
if not address:
|
||||
raise HTTPException(status_code=400, detail="Address is required")
|
||||
|
||||
adb = _get_adb_path()
|
||||
logger.info(f"Disconnecting ADB device: {address}")
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
adb, "disconnect", address,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||
return {"status": "disconnected", "message": stdout.decode().strip()}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=500, detail="adb not found on PATH")
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=504, detail="ADB disconnect timed out")
|
||||
|
||||
|
||||
# ─── Log level ─────────────────────────────────────────────────
|
||||
|
||||
_VALID_LOG_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/log-level", response_model=LogLevelResponse, tags=["System"])
|
||||
async def get_log_level(_: AuthRequired):
|
||||
"""Get the current root logger log level."""
|
||||
level_int = logging.getLogger().getEffectiveLevel()
|
||||
return LogLevelResponse(level=logging.getLevelName(level_int))
|
||||
|
||||
|
||||
@router.put("/api/v1/system/log-level", response_model=LogLevelResponse, tags=["System"])
|
||||
async def set_log_level(_: AuthRequired, body: LogLevelRequest):
|
||||
"""Change the root logger log level at runtime (no server restart required)."""
|
||||
level_name = body.level.upper()
|
||||
if level_name not in _VALID_LOG_LEVELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid log level '{body.level}'. Must be one of: {', '.join(sorted(_VALID_LOG_LEVELS))}",
|
||||
)
|
||||
level_int = getattr(logging, level_name)
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level_int)
|
||||
# Also update all handlers so they actually emit at the new level
|
||||
for handler in root.handlers:
|
||||
handler.setLevel(level_int)
|
||||
logger.info("Log level changed to %s", level_name)
|
||||
return LogLevelResponse(level=level_name)
|
||||
|
||||
377
server/src/wled_controller/api/routes/system_settings.py
Normal file
377
server/src/wled_controller/api/routes/system_settings.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""System routes: MQTT, external URL, ADB, logs WebSocket, log level.
|
||||
|
||||
Extracted from system.py to keep files under 800 lines.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from pydantic import BaseModel
|
||||
|
||||
from wled_controller.api.auth import AuthRequired
|
||||
from wled_controller.api.schemas.system import (
|
||||
ExternalUrlRequest,
|
||||
ExternalUrlResponse,
|
||||
LogLevelRequest,
|
||||
LogLevelResponse,
|
||||
MQTTSettingsRequest,
|
||||
MQTTSettingsResponse,
|
||||
)
|
||||
from wled_controller.config import get_config
|
||||
from wled_controller.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MQTT settings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MQTT_SETTINGS_FILE: Path | None = None
|
||||
|
||||
|
||||
def _get_mqtt_settings_path() -> Path:
|
||||
global _MQTT_SETTINGS_FILE
|
||||
if _MQTT_SETTINGS_FILE is None:
|
||||
cfg = get_config()
|
||||
# Derive the data directory from any known storage file path
|
||||
data_dir = Path(cfg.storage.devices_file).parent
|
||||
_MQTT_SETTINGS_FILE = data_dir / "mqtt_settings.json"
|
||||
return _MQTT_SETTINGS_FILE
|
||||
|
||||
|
||||
def _load_mqtt_settings() -> dict:
|
||||
"""Load MQTT settings: YAML config defaults overridden by JSON overrides file."""
|
||||
cfg = get_config()
|
||||
defaults = {
|
||||
"enabled": cfg.mqtt.enabled,
|
||||
"broker_host": cfg.mqtt.broker_host,
|
||||
"broker_port": cfg.mqtt.broker_port,
|
||||
"username": cfg.mqtt.username,
|
||||
"password": cfg.mqtt.password,
|
||||
"client_id": cfg.mqtt.client_id,
|
||||
"base_topic": cfg.mqtt.base_topic,
|
||||
}
|
||||
path = _get_mqtt_settings_path()
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
overrides = json.load(f)
|
||||
defaults.update(overrides)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MQTT settings override file: {e}")
|
||||
return defaults
|
||||
|
||||
|
||||
def _save_mqtt_settings(settings: dict) -> None:
|
||||
"""Persist MQTT settings to the JSON override file."""
|
||||
from wled_controller.utils import atomic_write_json
|
||||
atomic_write_json(_get_mqtt_settings_path(), settings)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/mqtt/settings",
|
||||
response_model=MQTTSettingsResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_mqtt_settings(_: AuthRequired):
|
||||
"""Get current MQTT broker settings. Password is masked."""
|
||||
s = _load_mqtt_settings()
|
||||
return MQTTSettingsResponse(
|
||||
enabled=s["enabled"],
|
||||
broker_host=s["broker_host"],
|
||||
broker_port=s["broker_port"],
|
||||
username=s["username"],
|
||||
password_set=bool(s.get("password")),
|
||||
client_id=s["client_id"],
|
||||
base_topic=s["base_topic"],
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/mqtt/settings",
|
||||
response_model=MQTTSettingsResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_mqtt_settings(_: AuthRequired, body: MQTTSettingsRequest):
|
||||
"""Update MQTT broker settings. If password is empty string, the existing password is preserved."""
|
||||
current = _load_mqtt_settings()
|
||||
|
||||
# If caller sends an empty password, keep the existing one
|
||||
password = body.password if body.password else current.get("password", "")
|
||||
|
||||
new_settings = {
|
||||
"enabled": body.enabled,
|
||||
"broker_host": body.broker_host,
|
||||
"broker_port": body.broker_port,
|
||||
"username": body.username,
|
||||
"password": password,
|
||||
"client_id": body.client_id,
|
||||
"base_topic": body.base_topic,
|
||||
}
|
||||
_save_mqtt_settings(new_settings)
|
||||
logger.info("MQTT settings updated")
|
||||
|
||||
return MQTTSettingsResponse(
|
||||
enabled=new_settings["enabled"],
|
||||
broker_host=new_settings["broker_host"],
|
||||
broker_port=new_settings["broker_port"],
|
||||
username=new_settings["username"],
|
||||
password_set=bool(new_settings["password"]),
|
||||
client_id=new_settings["client_id"],
|
||||
base_topic=new_settings["base_topic"],
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# External URL setting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXTERNAL_URL_FILE: Path | None = None
|
||||
|
||||
|
||||
def _get_external_url_path() -> Path:
|
||||
global _EXTERNAL_URL_FILE
|
||||
if _EXTERNAL_URL_FILE is None:
|
||||
cfg = get_config()
|
||||
data_dir = Path(cfg.storage.devices_file).parent
|
||||
_EXTERNAL_URL_FILE = data_dir / "external_url.json"
|
||||
return _EXTERNAL_URL_FILE
|
||||
|
||||
|
||||
def load_external_url() -> str:
|
||||
"""Load the external URL setting. Returns empty string if not set."""
|
||||
path = _get_external_url_path()
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("external_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _save_external_url(url: str) -> None:
|
||||
from wled_controller.utils import atomic_write_json
|
||||
atomic_write_json(_get_external_url_path(), {"external_url": url})
|
||||
|
||||
|
||||
@router.get(
|
||||
"/api/v1/system/external-url",
|
||||
response_model=ExternalUrlResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def get_external_url(_: AuthRequired):
|
||||
"""Get the configured external base URL."""
|
||||
return ExternalUrlResponse(external_url=load_external_url())
|
||||
|
||||
|
||||
@router.put(
|
||||
"/api/v1/system/external-url",
|
||||
response_model=ExternalUrlResponse,
|
||||
tags=["System"],
|
||||
)
|
||||
async def update_external_url(_: AuthRequired, body: ExternalUrlRequest):
|
||||
"""Set the external base URL used in webhook URLs and other user-visible URLs."""
|
||||
url = body.external_url.strip().rstrip("/")
|
||||
_save_external_url(url)
|
||||
logger.info("External URL updated: %s", url or "(cleared)")
|
||||
return ExternalUrlResponse(external_url=url)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Live log viewer WebSocket
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.websocket("/api/v1/system/logs/ws")
|
||||
async def logs_ws(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(""),
|
||||
):
|
||||
"""WebSocket that streams server log lines in real time.
|
||||
|
||||
Auth via ``?token=<api_key>``. On connect, sends the last ~500 buffered
|
||||
lines as individual text messages, then pushes new lines as they appear.
|
||||
"""
|
||||
from wled_controller.api.auth import verify_ws_token
|
||||
from wled_controller.utils import log_broadcaster
|
||||
|
||||
if not verify_ws_token(token):
|
||||
await websocket.close(code=4001, reason="Unauthorized")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# Ensure the broadcaster knows the event loop (may be first connection)
|
||||
log_broadcaster.ensure_loop()
|
||||
|
||||
# Subscribe *before* reading the backlog so no lines slip through
|
||||
queue = log_broadcaster.subscribe()
|
||||
|
||||
try:
|
||||
# Send backlog first
|
||||
for line in log_broadcaster.get_backlog():
|
||||
await websocket.send_text(line)
|
||||
|
||||
# Stream new lines
|
||||
while True:
|
||||
try:
|
||||
line = await asyncio.wait_for(queue.get(), timeout=30.0)
|
||||
await websocket.send_text(line)
|
||||
except asyncio.TimeoutError:
|
||||
# Send a keepalive ping so the connection stays alive
|
||||
try:
|
||||
await websocket.send_text("")
|
||||
except Exception:
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
log_broadcaster.unsubscribe(queue)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ADB helpers (for Android / scrcpy engine)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Regex: IPv4 address with optional port, e.g. "192.168.1.5" or "192.168.1.5:5555"
|
||||
_ADB_ADDRESS_RE = re.compile(
|
||||
r"^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}(:\d{1,5})?$"
|
||||
)
|
||||
|
||||
|
||||
class AdbConnectRequest(BaseModel):
|
||||
address: str
|
||||
|
||||
|
||||
def _validate_adb_address(address: str) -> None:
|
||||
"""Raise 400 if *address* is not a valid IP:port for ADB."""
|
||||
if not _ADB_ADDRESS_RE.match(address):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid ADB address '{address}'. "
|
||||
"Expected format: <IP> or <IP>:<port>, e.g. 192.168.1.5 or 192.168.1.5:5555"
|
||||
),
|
||||
)
|
||||
# Validate each octet is 0-255 and port is 1-65535
|
||||
parts = address.split(":")
|
||||
ip_parts = parts[0].split(".")
|
||||
for octet in ip_parts:
|
||||
if not (0 <= int(octet) <= 255):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid IP octet '{octet}' in address '{address}'. Each octet must be 0-255.",
|
||||
)
|
||||
if len(parts) == 2:
|
||||
port = int(parts[1])
|
||||
if not (1 <= port <= 65535):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid port '{parts[1]}' in address '{address}'. Port must be 1-65535.",
|
||||
)
|
||||
|
||||
|
||||
def _get_adb_path() -> str:
|
||||
"""Get the adb binary path from the scrcpy engine's resolver."""
|
||||
from wled_controller.core.capture_engines.scrcpy_engine import _get_adb
|
||||
return _get_adb()
|
||||
|
||||
|
||||
@router.post("/api/v1/adb/connect", tags=["ADB"])
|
||||
async def adb_connect(_: AuthRequired, request: AdbConnectRequest):
|
||||
"""Connect to a WiFi ADB device by IP address.
|
||||
|
||||
Appends ``:5555`` if no port is specified.
|
||||
"""
|
||||
address = request.address.strip()
|
||||
if not address:
|
||||
raise HTTPException(status_code=400, detail="Address is required")
|
||||
_validate_adb_address(address)
|
||||
if ":" not in address:
|
||||
address = f"{address}:5555"
|
||||
|
||||
adb = _get_adb_path()
|
||||
logger.info(f"Connecting ADB device: {address}")
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
adb, "connect", address,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||
output = (stdout.decode() + stderr.decode()).strip()
|
||||
if "connected" in output.lower():
|
||||
return {"status": "connected", "address": address, "message": output}
|
||||
raise HTTPException(status_code=400, detail=output or "Connection failed")
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="adb not found on PATH. Install Android SDK Platform-Tools.",
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=504, detail="ADB connect timed out")
|
||||
|
||||
|
||||
@router.post("/api/v1/adb/disconnect", tags=["ADB"])
|
||||
async def adb_disconnect(_: AuthRequired, request: AdbConnectRequest):
|
||||
"""Disconnect a WiFi ADB device."""
|
||||
address = request.address.strip()
|
||||
if not address:
|
||||
raise HTTPException(status_code=400, detail="Address is required")
|
||||
|
||||
adb = _get_adb_path()
|
||||
logger.info(f"Disconnecting ADB device: {address}")
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
adb, "disconnect", address,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=10)
|
||||
return {"status": "disconnected", "message": stdout.decode().strip()}
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=500, detail="adb not found on PATH")
|
||||
except asyncio.TimeoutError:
|
||||
raise HTTPException(status_code=504, detail="ADB disconnect timed out")
|
||||
|
||||
|
||||
# --- Log level -----
|
||||
|
||||
_VALID_LOG_LEVELS = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
|
||||
|
||||
@router.get("/api/v1/system/log-level", response_model=LogLevelResponse, tags=["System"])
|
||||
async def get_log_level(_: AuthRequired):
|
||||
"""Get the current root logger log level."""
|
||||
level_int = logging.getLogger().getEffectiveLevel()
|
||||
return LogLevelResponse(level=logging.getLevelName(level_int))
|
||||
|
||||
|
||||
@router.put("/api/v1/system/log-level", response_model=LogLevelResponse, tags=["System"])
|
||||
async def set_log_level(_: AuthRequired, body: LogLevelRequest):
|
||||
"""Change the root logger log level at runtime (no server restart required)."""
|
||||
level_name = body.level.upper()
|
||||
if level_name not in _VALID_LOG_LEVELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid log level '{body.level}'. Must be one of: {', '.join(sorted(_VALID_LOG_LEVELS))}",
|
||||
)
|
||||
level_int = getattr(logging, level_name)
|
||||
root = logging.getLogger()
|
||||
root.setLevel(level_int)
|
||||
# Also update all handlers so they actually emit at the new level
|
||||
for handler in root.handlers:
|
||||
handler.setLevel(level_int)
|
||||
logger.info("Log level changed to %s", level_name)
|
||||
return LogLevelResponse(level=level_name)
|
||||
@@ -403,7 +403,7 @@ async def test_template_ws(
|
||||
Config is sent as the first client message (JSON with engine_type,
|
||||
engine_config, display_index, capture_duration).
|
||||
"""
|
||||
from wled_controller.api.routes._test_helpers import (
|
||||
from wled_controller.api.routes._preview_helpers import (
|
||||
authenticate_ws_token,
|
||||
stream_capture_test,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,10 @@ automations that have a webhook condition. No API-key auth is required —
|
||||
the secret token itself authenticates the caller.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from wled_controller.api.dependencies import get_automation_engine, get_automation_store
|
||||
@@ -18,6 +21,28 @@ from wled_controller.utils import get_logger
|
||||
logger = get_logger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Simple in-memory rate limiter: 30 requests per 60-second window per IP
|
||||
# ---------------------------------------------------------------------------
|
||||
_RATE_LIMIT = 30
|
||||
_RATE_WINDOW = 60.0 # seconds
|
||||
_rate_hits: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
def _check_rate_limit(client_ip: str) -> None:
|
||||
"""Raise 429 if *client_ip* exceeded the webhook rate limit."""
|
||||
now = time.time()
|
||||
window_start = now - _RATE_WINDOW
|
||||
# Prune timestamps outside the window
|
||||
timestamps = _rate_hits[client_ip]
|
||||
_rate_hits[client_ip] = [t for t in timestamps if t > window_start]
|
||||
if len(_rate_hits[client_ip]) >= _RATE_LIMIT:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Rate limit exceeded. Max 30 webhook requests per minute.",
|
||||
)
|
||||
_rate_hits[client_ip].append(now)
|
||||
|
||||
|
||||
class WebhookPayload(BaseModel):
|
||||
action: str = Field(description="'activate' or 'deactivate'")
|
||||
@@ -30,10 +55,13 @@ class WebhookPayload(BaseModel):
|
||||
async def handle_webhook(
|
||||
token: str,
|
||||
body: WebhookPayload,
|
||||
request: Request,
|
||||
store: AutomationStore = Depends(get_automation_store),
|
||||
engine: AutomationEngine = Depends(get_automation_engine),
|
||||
):
|
||||
"""Receive a webhook call and set the corresponding condition state."""
|
||||
_check_rate_limit(request.client.host if request.client else "unknown")
|
||||
|
||||
if body.action not in ("activate", "deactivate"):
|
||||
raise HTTPException(status_code=400, detail="action must be 'activate' or 'deactivate'")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user