fix: comprehensive security, bug, performance, and UI/UX audit
Lint & Test / test (push) Successful in 20s

Security
- Default bind 127.0.0.1; first-run bootstrap generates random api_token
  and refuses to bind non-loopback without auth unless explicitly opted in
- Path-traversal hardened: BrowserService.validate_path rejects absolute
  paths, drive letters, UNC, NUL bytes. /api/browser/{play,metadata,
  thumbnail} now require folder_id and a folder-relative path
- Pydantic validators on links: http(s) URLs only, mdi:<slug> icons only
- Scripts/callbacks/links create/update/delete gated by *_management flags
- Strict CSP, X-Frame-Options DENY, Referrer-Policy no-referrer,
  X-Content-Type-Options nosniff
- CORS locked to localhost:<port> + 127.0.0.1:<port> by default; configurable
- config.yaml writes atomic (tmp + os.replace) and 0o600 on POSIX
- Subprocesses spawned in their own process group / new session so timeout
  kills the whole tree (Windows CREATE_NEW_PROCESS_GROUP, POSIX
  start_new_session=True)
- Frontend XSS: monitor name + details escapeHtml'd; power button moved to
  delegated data-action handler; remote MDI SVGs parsed and sanitized
  (strip script/foreignObject/on*/javascript: hrefs) before innerHTML
- All dynamic URL segments now wrapped in encodeURIComponent

Bugs
- WebSocket reconnect: close previous socket before opening new, clear
  ping interval per-socket, clear reconnectTimeout up-front, retry on
  online/visibilitychange, try/catch JSON.parse
- Artwork fetch race: AbortController + generation guard
- _broadcast_after_open: initialize status, swallow per-poll errors,
  background tasks tracked in a strong-ref set with done-callback cleanup
- Audio analyzer: sticky _unavailable flag prevents infinite start/stop
  spin when no loopback device exists; cleared by set_device()
- Volume short-circuit cache invalidated when server reports remote volume
- Browser thumbnail race: per-folder generation counter + isConnected
  checks; aborts in-flight fetches on navigation
- Track-skip uses cached title instead of full WinRT status round-trip

Performance
- Linux MPRIS/pactl and /api/display DDC-CI handlers wrapped in
  asyncio.to_thread so blocking IO never stalls the event loop
- browse_directory moved off the event loop (SMB shares could freeze it)
- Windows status poll caches one asyncio loop per worker thread via
  threading.local instead of new_event_loop/close on every 0.5s tick
- broadcast() serializes JSON once and uses send_text to all clients
- Hourly thumbnail cache cleanup scheduled in lifespan (was never invoked
  — cache grew unbounded)
- Progress drag listeners attached only while dragging

Quality
- All asyncio.get_event_loop() in coroutines → get_running_loop()
- ThreadPoolExecutors shut down cleanly during lifespan teardown
- config_manager dedup: 12 near-identical methods collapsed onto generic
  _upsert/_delete helpers (~290 lines removed)
- Service worker no longer pass-throughs every fetch
- M3U playlist written via NamedTemporaryFile (no fixed-path symlink
  clobber race)
- __version__ now prefers live pyproject.toml in dev checkouts so
  pip install -e . users see the source-of-truth version, not the stale
  package-metadata version baked in at install time

UI/UX (Studio Reference)
- Green leftover focus rings (rgba(29,185,84,...)) all replaced with
  copper accent (rgba(var(--copper-rgb),...))
- Dialogs: square corners, copper top hairline, unified with editorial
  chrome
- .browser-item: transparent with copper hover border (was filled card)
- Audio device select uses var(--sans) instead of generic system font
- Mobile container padding tuned for ≤480px screens
- Breadcrumb home is a real <button> with aria-label; aria-current on root
- i18n: filled display.msg.power_*, execution.*, scripts.params.execute,
  callbacks.empty in both en + ru
This commit is contained in:
2026-05-16 13:22:46 +03:00
parent 770bba7e60
commit bcc6d40ed7
28 changed files with 1063 additions and 876 deletions
+100 -80
View File
@@ -23,6 +23,17 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/browser", tags=["browser"])
# Strong refs to background tasks so they don't get garbage-collected mid-flight.
_background_tasks: set[asyncio.Task] = set()
def _spawn_background(coro) -> asyncio.Task:
"""Schedule a background coroutine and keep a strong ref to its Task."""
task = asyncio.create_task(coro)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
return task
def _require_folder_management() -> None:
"""Raise 403 if media folder management is disabled in config."""
@@ -38,16 +49,23 @@ async def _broadcast_after_open(controller, label: str, max_wait: float = 2.0) -
Fires as a background task so the HTTP response returns immediately.
"""
status = None
try:
interval = 0.3
elapsed = 0.0
while elapsed < max_wait:
await asyncio.sleep(interval)
elapsed += interval
status = await controller.get_status()
try:
status = await controller.get_status()
except Exception as poll_err: # noqa: BLE001 — broadcast is best-effort
logger.debug("get_status during broadcast poll failed: %s", poll_err)
continue
if status.state in ("playing", "paused"):
break
if status is None:
return
status_dict = status.model_dump()
await ws_manager.broadcast({"type": "status", "data": status_dict})
logger.info(f"Broadcasted status update after opening: {label}")
@@ -74,9 +92,14 @@ class FolderUpdateRequest(BaseModel):
class PlayRequest(BaseModel):
"""Request model for playing a media file."""
"""Request model for playing a media file.
path: str = Field(..., description="Full path to the media file")
Both ``folder_id`` and ``path`` are required so the server can validate
the file lives inside a configured media folder.
"""
folder_id: str = Field(..., description="Media folder ID")
path: str = Field(..., description="Path relative to folder root")
class PlayFolderRequest(BaseModel):
@@ -128,8 +151,10 @@ async def create_folder(
"""
_require_folder_management()
try:
# Validate folder_id format (alphanumeric and underscore only)
if not request.folder_id.replace("_", "").isalnum():
# Validate folder_id format (alphanumeric and underscore only).
# Same constraint is enforced when validating paths so traversal can't
# be smuggled through the ID itself.
if not request.folder_id or not request.folder_id.replace("_", "").isalnum():
raise HTTPException(
status_code=400,
detail="Folder ID must contain only alphanumeric characters and underscores",
@@ -277,13 +302,15 @@ async def browse(
# URL decode the path
decoded_path = unquote(path)
# Browse directory
result = BrowserService.browse_directory(
folder_id=folder_id,
path=decoded_path,
offset=offset,
limit=limit,
nocache=nocache,
# Browse directory in a thread — iterdir() + stat() can block on
# network shares for many seconds; never run on the event loop.
result = await asyncio.to_thread(
BrowserService.browse_directory,
folder_id,
decoded_path,
offset,
limit,
nocache,
)
return result
@@ -307,41 +334,40 @@ async def browse(
# Metadata Endpoint
@router.get("/metadata")
async def get_metadata(
path: str = Query(..., description="Full path to media file (URL-encoded)"),
folder_id: str = Query(..., description="Media folder ID"),
path: str = Query(..., description="Path relative to folder root (URL-encoded)"),
_: str = Depends(verify_token),
):
"""Get metadata for a media file.
"""Get metadata for a media file inside a configured media folder.
Args:
path: Full path to the media file (URL-encoded).
folder_id: ID of the media folder.
path: Path relative to folder root (URL-encoded).
Returns:
Media file metadata.
Raises:
HTTPException: If file not found or metadata extraction fails.
"""
try:
# URL decode the path
decoded_path = unquote(path)
file_path = Path(decoded_path)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
file_path = BrowserService.validate_path(folder_id, decoded_path)
if not file_path.is_file():
raise HTTPException(status_code=400, detail="Path is not a file")
if not BrowserService.is_media_file(file_path):
raise HTTPException(status_code=400, detail="File is not a media file")
# Extract metadata in executor (blocking operation)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
metadata = await loop.run_in_executor(
None,
MetadataService.extract_metadata,
file_path,
)
return metadata
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except HTTPException:
raise
except Exception as e:
@@ -352,59 +378,47 @@ async def get_metadata(
# Thumbnail Endpoint
@router.get("/thumbnail")
async def get_thumbnail(
path: str = Query(..., description="Full path to media file (URL-encoded)"),
folder_id: str = Query(..., description="Media folder ID"),
path: str = Query(..., description="Path relative to folder root (URL-encoded)"),
size: str = Query(default="medium", description='Thumbnail size: "small" or "medium"'),
_: str = Depends(verify_token),
):
"""Get thumbnail for a media file.
Args:
path: Full path to the media file (URL-encoded).
size: Thumbnail size ("small" or "medium").
Returns:
JPEG image bytes.
Raises:
HTTPException: If file not found or thumbnail generation fails.
"""
"""Get thumbnail for a media file inside a configured media folder."""
try:
# URL decode the path
decoded_path = unquote(path)
file_path = Path(decoded_path)
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
file_path = BrowserService.validate_path(folder_id, decoded_path)
if not file_path.is_file():
raise HTTPException(status_code=400, detail="Path is not a file")
if not BrowserService.is_media_file(file_path):
raise HTTPException(status_code=400, detail="File is not a media file")
# Validate size
if size not in ("small", "medium"):
size = "medium"
# Get thumbnail
thumbnail_data = await ThumbnailService.get_thumbnail(file_path, size)
if thumbnail_data is None:
return Response(status_code=204)
# Calculate ETag (hash of path + mtime)
import hashlib
stat = file_path.stat()
etag_data = f"{file_path}:{stat.st_mtime}:{size}".encode()
etag = hashlib.md5(etag_data).hexdigest()
# Return image with caching headers
return Response(
content=thumbnail_data,
media_type="image/jpeg",
headers={
"ETag": f'"{etag}"',
"Cache-Control": "public, max-age=86400", # 24 hours
"Cache-Control": "public, max-age=86400",
},
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except HTTPException:
raise
except Exception as e:
@@ -420,44 +434,37 @@ async def play_file(
):
"""Open a media file with the default system player.
Args:
request: Play request with file path.
Returns:
Success message.
Raises:
HTTPException: If file not found or playback fails.
Requires both ``folder_id`` and a folder-relative ``path``; the resolved
file must live inside the configured media folder and be a recognized
media file. This prevents arbitrary OS-handler invocation (e.g.,
``os.startfile`` on Windows ``.lnk``/UNC paths).
"""
try:
file_path = Path(request.path)
# Validate file exists
if not file_path.exists():
raise HTTPException(status_code=404, detail="File not found")
decoded_path = unquote(request.path)
file_path = BrowserService.validate_path(request.folder_id, decoded_path)
if not file_path.is_file():
raise HTTPException(status_code=400, detail="Path is not a file")
# Validate file is a media file
if not BrowserService.is_media_file(file_path):
raise HTTPException(status_code=400, detail="File is not a media file")
# Get media controller and open file
controller = get_media_controller()
success = await controller.open_file(str(file_path))
if not success:
raise HTTPException(status_code=500, detail="Failed to open file")
# Poll until player registers with media session API (up to 2s)
asyncio.create_task(_broadcast_after_open(controller, file_path.name))
_spawn_background(_broadcast_after_open(controller, file_path.name))
return {
"success": True,
"message": f"Playing {file_path.name}",
}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e))
except HTTPException:
raise
except Exception as e:
@@ -489,26 +496,38 @@ async def play_folder(
if not full_path.is_dir():
raise HTTPException(status_code=400, detail="Path is not a directory")
# Collect all media files sorted by name
media_files = sorted(
[f for f in full_path.iterdir() if f.is_file() and BrowserService.is_media_file(f)],
key=lambda f: f.name.lower(),
)
def _scan(directory: Path) -> list[Path]:
return sorted(
(
f for f in directory.iterdir()
if f.is_file() and BrowserService.is_media_file(f)
),
key=lambda f: f.name.lower(),
)
media_files = await asyncio.to_thread(_scan, full_path)
if not media_files:
raise HTTPException(status_code=404, detail="No media files found in this folder")
# Generate M3U playlist with absolute paths and EXTINF entries
# Written to local temp dir to avoid extra SMB file handle on network shares
# Uses utf-8-sig (BOM) so players detect encoding properly
# Generate M3U playlist with absolute paths and EXTINF entries.
# Use NamedTemporaryFile to get a fresh per-call path — prevents
# symlink-clobber races between concurrent /play-folder requests
# and any local user pre-creating a fixed temp filename.
lines = ["#EXTM3U"]
for f in media_files:
lines.append(f"#EXTINF:-1,{f.stem}")
lines.append(str(f))
m3u_content = "\r\n".join(lines) + "\r\n"
m3u_content = ("\r\n".join(lines) + "\r\n").encode("utf-8-sig")
playlist_path = Path(tempfile.gettempdir()) / ".media_server_playlist.m3u"
playlist_path.write_text(m3u_content, encoding="utf-8-sig")
with tempfile.NamedTemporaryFile(
mode="wb",
prefix=".media_server_playlist_",
suffix=".m3u",
delete=False,
) as f:
f.write(m3u_content)
playlist_path = Path(f.name)
# Open playlist with default player
controller = get_media_controller()
@@ -517,8 +536,9 @@ async def play_folder(
if not success:
raise HTTPException(status_code=500, detail="Failed to open playlist")
# Poll until player registers with media session API (up to 2s)
asyncio.create_task(_broadcast_after_open(controller, f"playlist ({len(media_files)} files)"))
_spawn_background(
_broadcast_after_open(controller, f"playlist ({len(media_files)} files)")
)
return {
"success": True,
+27 -4
View File
@@ -3,6 +3,7 @@
import asyncio
import logging
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -21,6 +22,22 @@ logger = logging.getLogger(__name__)
_callback_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="callback")
def shutdown_callback_executor() -> None:
"""Shut down the callback executor cleanly on application teardown."""
_callback_executor.shutdown(wait=False, cancel_futures=True)
def _require_callbacks_management() -> None:
if not settings.callbacks_management:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"Callbacks management is disabled. Set callbacks_management: true"
" in config.yaml to enable."
),
)
class CallbackInfo(BaseModel):
"""Information about a configured callback."""
@@ -131,7 +148,7 @@ async def execute_callback(
try:
# Execute in dedicated thread pool to not block the default executor
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
_callback_executor,
lambda: _run_callback(
@@ -178,6 +195,11 @@ def _run_callback(
Dict with exit_code, stdout, stderr, execution_time
"""
start_time = time.time()
popen_kwargs: dict[str, Any] = {}
if sys.platform == "win32":
popen_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
else:
popen_kwargs["start_new_session"] = True
try:
result = subprocess.run(
command,
@@ -186,6 +208,7 @@ def _run_callback(
capture_output=True,
text=True,
timeout=timeout,
**popen_kwargs,
)
execution_time = time.time() - start_time
return {
@@ -230,7 +253,7 @@ async def create_callback(
Raises:
HTTPException: If callback already exists or name is invalid.
"""
# Validate name
_require_callbacks_management()
_validate_callback_name(callback_name)
# Check if callback already exists
@@ -278,7 +301,7 @@ async def update_callback(
Raises:
HTTPException: If callback does not exist.
"""
# Validate name
_require_callbacks_management()
_validate_callback_name(callback_name)
# Check if callback exists
@@ -324,7 +347,7 @@ async def delete_callback(
Raises:
HTTPException: If callback does not exist.
"""
# Validate name
_require_callbacks_management()
_validate_callback_name(callback_name)
# Check if callback exists
+16 -13
View File
@@ -1,5 +1,6 @@
"""Display brightness, power, contrast, input-source, color-preset and picture-mode API."""
import asyncio
import logging
from fastapi import APIRouter, Depends
@@ -45,19 +46,21 @@ class PictureModeRequest(BaseModel):
code: int = Field(ge=0, le=255)
# DDC/CI hardware writes open a per-monitor handle and can take seconds —
# every public endpoint dispatches into a worker thread so the event loop
# stays responsive.
@router.get("/monitors")
async def get_monitors(
refresh: bool = False,
rediscover: bool = False,
_: str = Depends(verify_token),
) -> list[dict]:
"""List all connected monitors with their reported DDC/CI capabilities.
- `refresh=true` bypasses the response TTL cache (re-reads current state).
- `rediscover=true` also drops the per-monitor capability cache, forcing
a full DDC/CI capability probe. Use after a monitor hot-swap.
"""
monitors = list_monitors(force_refresh=refresh, rediscover=rediscover)
"""List all connected monitors with their reported DDC/CI capabilities."""
monitors = await asyncio.to_thread(
list_monitors, force_refresh=refresh, rediscover=rediscover
)
logger.debug("Found %d monitors", len(monitors))
return [m.to_dict() for m in monitors]
@@ -67,7 +70,7 @@ async def set_monitor_brightness(
monitor_id: int, request: BrightnessRequest, _: str = Depends(verify_token)
) -> dict:
"""Set brightness for a specific monitor."""
success = set_brightness(monitor_id, request.brightness)
success = await asyncio.to_thread(set_brightness, monitor_id, request.brightness)
if success:
logger.info("Set monitor %d brightness to %d", monitor_id, request.brightness)
return {"success": success}
@@ -79,7 +82,7 @@ async def set_monitor_power(
) -> dict:
"""Turn a monitor on or off."""
action = "on" if request.on else "off"
success = set_power(monitor_id, request.on)
success = await asyncio.to_thread(set_power, monitor_id, request.on)
if success:
logger.info("Set monitor %d power %s", monitor_id, action)
return {"success": success}
@@ -90,7 +93,7 @@ async def set_monitor_contrast(
monitor_id: int, request: ContrastRequest, _: str = Depends(verify_token)
) -> dict:
"""Set DDC/CI contrast for a specific monitor."""
success = set_contrast(monitor_id, request.contrast)
success = await asyncio.to_thread(set_contrast, monitor_id, request.contrast)
if success:
logger.info("Set monitor %d contrast to %d", monitor_id, request.contrast)
return {"success": success}
@@ -101,7 +104,7 @@ async def set_monitor_input_source(
monitor_id: int, request: InputSourceRequest, _: str = Depends(verify_token)
) -> dict:
"""Switch a monitor's DDC/CI input source (e.g. HDMI1, DP1)."""
success = set_input_source(monitor_id, request.source)
success = await asyncio.to_thread(set_input_source, monitor_id, request.source)
if success:
logger.info("Set monitor %d input source to %s", monitor_id, request.source)
return {"success": success}
@@ -112,7 +115,7 @@ async def set_monitor_color_preset(
monitor_id: int, request: ColorPresetRequest, _: str = Depends(verify_token)
) -> dict:
"""Apply a DDC/CI color preset (color temperature) to the monitor."""
success = set_color_preset(monitor_id, request.preset)
success = await asyncio.to_thread(set_color_preset, monitor_id, request.preset)
if success:
logger.info("Set monitor %d color preset to %s", monitor_id, request.preset)
return {"success": success}
@@ -123,7 +126,7 @@ async def set_monitor_picture_mode(
monitor_id: int, request: PictureModeRequest, _: str = Depends(verify_token)
) -> dict:
"""Apply a DDC/CI picture/scene mode (VCP 0xDC) by raw code."""
success = set_picture_mode(monitor_id, request.code)
success = await asyncio.to_thread(set_picture_mode, monitor_id, request.code)
if success:
logger.info("Set monitor %d picture mode to code %d", monitor_id, request.code)
return {"success": success}
+49 -13
View File
@@ -3,9 +3,10 @@
import logging
import re
from typing import Any
from urllib.parse import urlparse
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from ..auth import verify_token
from ..config import LinkConfig, settings
@@ -15,6 +16,35 @@ from ..services.websocket_manager import ws_manager
router = APIRouter(prefix="/api/links", tags=["links"])
logger = logging.getLogger(__name__)
# Only allow MDI iconify slugs and safe `http(s)`-ish URLs through the API.
_MDI_ICON_RE = re.compile(r"^mdi:[a-z0-9][a-z0-9-]{0,63}$")
_ALLOWED_URL_SCHEMES = {"http", "https"}
def _validate_url(url: str) -> str:
"""Ensure the URL is well-formed http(s) — no ``javascript:`` etc."""
parsed = urlparse(url)
if parsed.scheme.lower() not in _ALLOWED_URL_SCHEMES:
raise ValueError("URL must start with http:// or https://")
if not parsed.netloc:
raise ValueError("URL must include a host")
return url
def _validate_icon(icon: str) -> str:
"""Restrict icon names to safe Material Design Icons slugs."""
if not _MDI_ICON_RE.match(icon):
raise ValueError("Icon must be of the form 'mdi:<lowercase-slug>'")
return icon
def _require_links_management() -> None:
if not settings.links_management:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Links management is disabled. Set links_management: true in config.yaml to enable.",
)
class LinkInfo(BaseModel):
"""Information about a configured link."""
@@ -29,22 +59,25 @@ class LinkInfo(BaseModel):
class LinkCreateRequest(BaseModel):
"""Request model for creating or updating a link."""
url: str = Field(..., description="URL to open", min_length=1)
url: str = Field(..., description="URL to open", min_length=1, max_length=2048)
icon: str = Field(default="mdi:link", description="MDI icon name (e.g., 'mdi:led-strip-variant')")
label: str = Field(default="", description="Tooltip text")
description: str = Field(default="", description="Optional description")
label: str = Field(default="", description="Tooltip text", max_length=128)
description: str = Field(default="", description="Optional description", max_length=512)
@field_validator("url")
@classmethod
def _check_url(cls, v: str) -> str:
return _validate_url(v)
@field_validator("icon")
@classmethod
def _check_icon(cls, v: str) -> str:
return _validate_icon(v)
def _validate_link_name(name: str) -> None:
"""Validate link name.
Args:
name: Link name to validate.
Raises:
HTTPException: If name is invalid.
"""
if not re.match(r'^[a-zA-Z0-9_]+$', name):
"""Validate link name."""
if not re.match(r"^[a-zA-Z0-9_]+$", name):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Link name must contain only letters, numbers, and underscores",
@@ -90,6 +123,7 @@ async def create_link(
Returns:
Success response with link name.
"""
_require_links_management()
_validate_link_name(link_name)
if link_name in settings.links:
@@ -129,6 +163,7 @@ async def update_link(
Returns:
Success response with link name.
"""
_require_links_management()
_validate_link_name(link_name)
if link_name not in settings.links:
@@ -166,6 +201,7 @@ async def delete_link(
Returns:
Success response with link name.
"""
_require_links_management()
_validate_link_name(link_name)
if link_name not in settings.links:
+2 -2
View File
@@ -27,7 +27,7 @@ def _run_callback(callback_name: str) -> None:
try:
callback = settings.callbacks[callback_name]
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None,
lambda: _run_script(
@@ -285,7 +285,7 @@ async def visualizer_devices(_: str = Depends(verify_token)) -> list[dict[str, s
"""List available loopback audio devices for the visualizer."""
from ..services.audio_analyzer import AudioAnalyzer
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, AudioAnalyzer.list_loopback_devices)
+32 -5
View File
@@ -2,8 +2,10 @@
import asyncio
import logging
import os
import re
import subprocess
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any
@@ -23,6 +25,22 @@ _script_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="script"
logger = logging.getLogger(__name__)
def shutdown_script_executor() -> None:
"""Shut down the dedicated executor cleanly on application teardown."""
_script_executor.shutdown(wait=False, cancel_futures=True)
def _require_scripts_management() -> None:
if not settings.scripts_management:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=(
"Scripts management is disabled. Set scripts_management: true"
" in config.yaml to enable."
),
)
class ScriptExecuteRequest(BaseModel):
"""Request model for script execution with optional parameters."""
@@ -233,7 +251,7 @@ async def execute_script(
try:
# Execute in dedicated thread pool to not block the default executor
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
_script_executor,
lambda: _run_script(
@@ -285,8 +303,16 @@ def _run_script(
start_time = time.time()
env = None
if extra_env:
import os
env = {**os.environ, **extra_env}
# Spawn the script in its own process group / job so a timeout kills the
# whole tree, not just the shell (POSIX) and not just the parent (Windows).
popen_kwargs: dict[str, Any] = {}
if sys.platform == "win32":
popen_kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP
else:
popen_kwargs["start_new_session"] = True
try:
result = subprocess.run(
command,
@@ -296,6 +322,7 @@ def _run_script(
text=True,
timeout=timeout,
env=env,
**popen_kwargs,
)
execution_time = time.time() - start_time
return {
@@ -455,7 +482,7 @@ async def create_script(
Raises:
HTTPException: If script already exists or name is invalid.
"""
# Validate name
_require_scripts_management()
_validate_script_name(script_name)
# Check if script already exists
@@ -511,7 +538,7 @@ async def update_script(
Raises:
HTTPException: If script does not exist.
"""
# Validate name
_require_scripts_management()
_validate_script_name(script_name)
# Check if script exists
@@ -565,7 +592,7 @@ async def delete_script(
Raises:
HTTPException: If script does not exist.
"""
# Validate name
_require_scripts_management()
_validate_script_name(script_name)
# Check if script exists