fix: production-readiness hardening — security, perf, a11y, observability
Lint & Test / test (push) Successful in 20s
Lint & Test / test (push) Successful in 20s
Security - Default scripts_management, callbacks_management, links_management, and media_folders_management to False so a leaked token cannot escalate to RCE through admin CRUD endpoints. - TokenSpec + scope hierarchy (read | control | admin); legacy bare-string api_tokens entries promote to admin for back-compat. Management endpoints now require admin scope. - WebSocket subprotocol auth (Sec-WebSocket-Protocol: media-server.token.<T>) preferred over ?token= query so the token no longer lands in URL/history/ Referer; query fallback retained for HA integration back-compat. - Origin allow-list check on the WS endpoint (CSWSH defence). - In-process token-bucket rate limiter: 5/min for failed auths, 10/min for /api/scripts/execute and /api/callbacks/execute. - shell=False subprocess path (shlex.split) + per-parameter regex `pattern` in ScriptParameterConfig to harden shell=true scripts against parameter injection (Windows cmd.exe env-var expansion). - CSP gains form-action, worker-src, manifest-src directives. - Refuse cors_origins=["*"] at startup; strip token=... from uvicorn access logs; validate Gitea release tag against strict SemVer regex. - noopener noreferrer + no-referrer referrerpolicy on every outbound link. - icacls hardening of config.yaml on Windows (current user + SYSTEM + Administrators only); 0600 still enforced on POSIX. - WS volume handler clamps input and never drops the socket on bad messages. Performance - Album-art read in windows_media gated by track key — was decoding the WinRT thumbnail twice per second regardless of track changes. - /api/media/artwork returns content-derived ETag + Cache-Control so the browser sends If-None-Match and gets 304s on track repeats. - Foreground-service ctypes argtypes hoisted to one-time module init (was re-declaring ~14 prototypes per probe). - display_service _static_cache keyed by (edid_hash, ...) tuple with eviction of disappeared monitors — fixes stale capabilities on hot-plug swaps where the new topology has the same monitor count. - Visualizer rAF loop paused on document.hidden, resumed on visible. Reliability / bug fixes - Lifespan rewritten as try/yield/finally so a partial-startup failure cannot orphan background tasks or executors. - _run_callback in routes/media.py keeps a strong task ref (GC-safe) and uses the dedicated callback executor instead of the default pool. - macos_media.set_volume() no longer always returns True. - TrayManager._restart_requested initialised in __init__; set before signalling exit so the main thread observes it correctly. - Missing static_dir now logs a WARNING instead of silent UI disable. UX / accessibility / PWA - manifest.json theme_color and background_color match the Studio Reference base (#0E0D0B); added id and scope for PWA installability. - ARIA on mini-player icon buttons; inner SVGs marked aria-hidden. - OS mediaSession API wired so headset / lockscreen / Bluetooth buttons drive play/pause/next/prev/seek and show track metadata + artwork. Observability - X-Request-ID middleware (accept upstream id if it matches a safe regex, otherwise UUID4); request_id_var added to ContextVars and included in every log line alongside the token label. - Audit log (append-only JSONL) for every script + callback execution, including the on_play/on_pause/etc. event callbacks. Background-thread writer; queue capped; flushed in lifespan teardown. Deployment - proxy_headers + forwarded_allow_ips plumbed through Settings → uvicorn.Config for reverse-proxy installs. - HTTPS support via ssl_certfile + ssl_keyfile (+ optional password); startup refuses to launch with only one of the pair set. - Thumbnail cache moved from project-root .cache to %LOCALAPPDATA%/media-server/cache (Windows) and $XDG_CACHE_HOME/media-server/thumbnails (POSIX). Tests - 35 new tests across auth scopes, rate limiter, browser path traversal (../ NUL UNC absolute), script-param validation incl. regex, Gitea tag whitelist, config atomic write + POSIX perms. 47 passed / 4 skipped.
This commit is contained in:
+34
-2
@@ -13,6 +13,8 @@ security = HTTPBearer(auto_error=False)
|
||||
|
||||
# Context variable to store current request's token label
|
||||
token_label_var: ContextVar[str] = ContextVar("token_label", default="unknown")
|
||||
# Per-request correlation ID — generated in middleware if upstream didn't send one.
|
||||
request_id_var: ContextVar[str] = ContextVar("request_id", default="-")
|
||||
|
||||
|
||||
def auth_enabled() -> bool:
|
||||
@@ -29,12 +31,42 @@ def get_token_label(token: str) -> Optional[str]:
|
||||
Returns:
|
||||
The label for the token, or None if invalid
|
||||
"""
|
||||
for label, stored_token in settings.api_tokens.items():
|
||||
if secrets.compare_digest(stored_token, token):
|
||||
for label, spec in settings.api_tokens.items():
|
||||
if secrets.compare_digest(spec.token, token):
|
||||
return label
|
||||
return None
|
||||
|
||||
|
||||
def token_has_scope(label: str, required: str) -> bool:
|
||||
"""Whether the token identified by `label` grants `required` scope."""
|
||||
spec = settings.api_tokens.get(label)
|
||||
if spec is None:
|
||||
# Unknown label = no auth or anonymous; treat as full access only
|
||||
# when auth is disabled entirely (matches existing behaviour).
|
||||
return not auth_enabled()
|
||||
return spec.grants(required)
|
||||
|
||||
|
||||
def require_scope(scope: str):
|
||||
"""Build a FastAPI dependency that enforces the given scope.
|
||||
|
||||
Use as ``Depends(require_scope("admin"))`` on management endpoints. When
|
||||
auth is disabled the dependency is a no-op (anonymous access).
|
||||
"""
|
||||
|
||||
async def _checker(label: str = Depends(verify_token)) -> str:
|
||||
if not auth_enabled():
|
||||
return label
|
||||
if not token_has_scope(label, scope):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Token '{label}' lacks required scope: {scope}",
|
||||
)
|
||||
return label
|
||||
|
||||
return _checker
|
||||
|
||||
|
||||
async def verify_token(
|
||||
request: Request,
|
||||
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||||
|
||||
+157
-11
@@ -7,12 +7,49 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Token scopes form a strict hierarchy: admin > control > read. Helper utility
|
||||
# used by both auth.py and the validator below.
|
||||
SCOPE_HIERARCHY: dict[str, frozenset[str]] = {
|
||||
"read": frozenset({"read"}),
|
||||
"control": frozenset({"read", "control"}),
|
||||
"admin": frozenset({"read", "control", "admin"}),
|
||||
}
|
||||
ALL_SCOPES: frozenset[str] = frozenset(SCOPE_HIERARCHY.keys())
|
||||
|
||||
|
||||
class TokenSpec(BaseModel):
|
||||
"""Per-token authentication entry with explicit scopes."""
|
||||
|
||||
token: str = Field(..., min_length=8, description="Secret token value")
|
||||
scopes: list[str] = Field(
|
||||
default_factory=lambda: ["admin"],
|
||||
description="Granted scopes (subset of read|control|admin).",
|
||||
)
|
||||
|
||||
@field_validator("scopes")
|
||||
@classmethod
|
||||
def _validate_scopes(cls, v: list[str]) -> list[str]:
|
||||
if not v:
|
||||
raise ValueError("scopes must list at least one of read|control|admin")
|
||||
unknown = set(v) - ALL_SCOPES
|
||||
if unknown:
|
||||
raise ValueError(f"unknown scopes: {sorted(unknown)}; valid={sorted(ALL_SCOPES)}")
|
||||
return v
|
||||
|
||||
def grants(self, required: str) -> bool:
|
||||
"""Whether this token grants the requested scope (with hierarchy expansion)."""
|
||||
granted: set[str] = set()
|
||||
for s in self.scopes:
|
||||
granted |= SCOPE_HIERARCHY.get(s, frozenset({s}))
|
||||
return required in granted
|
||||
|
||||
|
||||
class MediaFolderConfig(BaseModel):
|
||||
"""Configuration for a media folder."""
|
||||
|
||||
@@ -48,6 +85,13 @@ class ScriptParameterConfig(BaseModel):
|
||||
options: Optional[list[str]] = Field(
|
||||
default=None, description="Allowed values (select type only)"
|
||||
)
|
||||
pattern: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional regex (Python flavour) that string-typed values must match."
|
||||
" Use to harden parameters that flow into shell=true scripts."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ScriptConfig(BaseModel):
|
||||
@@ -108,19 +152,84 @@ class Settings(BaseSettings):
|
||||
),
|
||||
)
|
||||
|
||||
# Reverse-proxy deployment: when serving the API behind nginx/Caddy/Traefik,
|
||||
# uvicorn must trust the X-Forwarded-* headers from the proxy so that the
|
||||
# `Origin` allow-list, request URLs, and logs reflect the public-facing
|
||||
# values. Off by default — only enable when there's a real proxy in front
|
||||
# (otherwise clients can spoof their own IP).
|
||||
proxy_headers: bool = Field(
|
||||
default=False,
|
||||
description="Honor X-Forwarded-For / X-Forwarded-Proto from upstream proxy.",
|
||||
)
|
||||
forwarded_allow_ips: str = Field(
|
||||
default="127.0.0.1",
|
||||
description=(
|
||||
"Comma-separated IPs / CIDRs that uvicorn should trust X-Forwarded-* from."
|
||||
" Use '*' to trust all (only safe when bound to a private interface)."
|
||||
),
|
||||
)
|
||||
|
||||
# HTTPS / TLS. Both must be set together to enable TLS; if only one is set
|
||||
# the server refuses to start. Use `mkcert` or letsencrypt to generate the
|
||||
# pair; the server reads them at startup.
|
||||
ssl_certfile: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to TLS certificate (PEM). Pair with ssl_keyfile.",
|
||||
)
|
||||
ssl_keyfile: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to TLS private key (PEM). Pair with ssl_certfile.",
|
||||
)
|
||||
ssl_keyfile_password: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional password for the private key if encrypted.",
|
||||
)
|
||||
|
||||
# Admin-grade operations (script / callback / link / folder create/update/delete).
|
||||
# When True the same token used for read/play can also persist arbitrary shell
|
||||
# commands. Disable to make the API read+execute only.
|
||||
scripts_management: bool = Field(default=True, description="Allow scripts CRUD via API")
|
||||
callbacks_management: bool = Field(default=True, description="Allow callbacks CRUD via API")
|
||||
links_management: bool = Field(default=True, description="Allow links CRUD via API")
|
||||
# commands. Default False so a single leaked token cannot escalate to RCE; opt
|
||||
# in explicitly to manage scripts/callbacks/links via the Web UI.
|
||||
scripts_management: bool = Field(default=False, description="Allow scripts CRUD via API")
|
||||
callbacks_management: bool = Field(default=False, description="Allow callbacks CRUD via API")
|
||||
links_management: bool = Field(default=False, description="Allow links CRUD via API")
|
||||
|
||||
# Authentication (empty = auth disabled, anyone can access the API)
|
||||
api_tokens: dict[str, str] = Field(
|
||||
# Authentication (empty = auth disabled, anyone can access the API).
|
||||
#
|
||||
# Each entry can be either:
|
||||
# • a bare string (legacy form, treated as scopes = ["admin"] for back-compat), OR
|
||||
# • a mapping with explicit scopes, e.g.
|
||||
# "ha": {token: "<token>", scopes: ["read", "control"]}
|
||||
# "kiosk": {token: "<token>", scopes: ["read"]}
|
||||
# "ops": {token: "<token>", scopes: ["admin"]}
|
||||
#
|
||||
# Available scopes:
|
||||
# read — GET /api/* (status, list, browse) but no state-changing calls.
|
||||
# control — read + media transport, display/audio, script EXECUTE, callback EXECUTE.
|
||||
# admin — control + CRUD on scripts/callbacks/links/folders.
|
||||
#
|
||||
# Validation normalises both forms to TokenSpec at load time.
|
||||
api_tokens: dict[str, TokenSpec] = Field(
|
||||
default_factory=dict,
|
||||
description="Named API tokens for access control (label: token pairs). Empty = no auth.",
|
||||
description=(
|
||||
"Named API tokens. Value can be a bare token string (= admin scope) or"
|
||||
" a {token, scopes} mapping. See TokenSpec for scope definitions."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("api_tokens", mode="before")
|
||||
@classmethod
|
||||
def _normalise_tokens(cls, v):
|
||||
"""Accept legacy `label: <bare-token>` form and promote to TokenSpec."""
|
||||
if not isinstance(v, dict):
|
||||
return v
|
||||
out: dict[str, dict | TokenSpec] = {}
|
||||
for label, entry in v.items():
|
||||
if isinstance(entry, str):
|
||||
out[label] = {"token": entry, "scopes": ["admin"]}
|
||||
else:
|
||||
out[label] = entry
|
||||
return out
|
||||
|
||||
# Media controller settings
|
||||
poll_interval: float = Field(
|
||||
default=1.0, description="Media status poll interval in seconds"
|
||||
@@ -156,7 +265,7 @@ class Settings(BaseSettings):
|
||||
description="Media folders available for browsing in the media browser",
|
||||
)
|
||||
media_folders_management: bool = Field(
|
||||
default=True,
|
||||
default=False,
|
||||
description="Allow adding, editing, and deleting media folders from the Web UI",
|
||||
)
|
||||
|
||||
@@ -263,8 +372,11 @@ def generate_default_config(path: Optional[Path] = None) -> Path:
|
||||
config = {
|
||||
"host": "127.0.0.1",
|
||||
"port": 8765,
|
||||
# Default token grants "admin" scope (full access). To create a
|
||||
# read-only or control-only token, add a second entry:
|
||||
# ha_readonly: {token: "<token>", scopes: ["read"]}
|
||||
"api_tokens": {
|
||||
"default": default_token,
|
||||
"default": {"token": default_token, "scopes": ["admin"]},
|
||||
},
|
||||
"poll_interval": 1.0,
|
||||
"log_level": "INFO",
|
||||
@@ -298,8 +410,16 @@ def _write_yaml_atomic(path: Path, data: dict) -> None:
|
||||
|
||||
|
||||
def _restrict_config_perms(path: Path) -> None:
|
||||
"""On POSIX, ensure config file is readable only by owner (0600)."""
|
||||
"""Ensure config file is readable only by its owner.
|
||||
|
||||
POSIX → ``chmod 0600``. On Windows the default NTFS ACL leaves the file
|
||||
readable by every interactive user on the machine (Users group has Read),
|
||||
which is bad given the file stores plaintext API tokens. Use ``icacls`` to
|
||||
grant exclusive access to the current user + SYSTEM + Administrators and
|
||||
strip inheritance.
|
||||
"""
|
||||
if os.name == "nt":
|
||||
_restrict_config_perms_windows(path)
|
||||
return
|
||||
try:
|
||||
os.chmod(path, 0o600)
|
||||
@@ -308,5 +428,31 @@ def _restrict_config_perms(path: Path) -> None:
|
||||
logger.debug("Could not chmod %s", path, exc_info=True)
|
||||
|
||||
|
||||
def _restrict_config_perms_windows(path: Path) -> None:
|
||||
"""Apply restrictive NTFS ACL to a config file (Windows only)."""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
username = os.environ.get("USERNAME") or os.environ.get("USER")
|
||||
if not username:
|
||||
logger.debug("Cannot detect current user; skipping icacls hardening")
|
||||
return
|
||||
# Disable inheritance and remove every existing ACE, then grant access
|
||||
# only to current user, SYSTEM, and Administrators. /Q suppresses
|
||||
# progress output; /C lets per-file errors not abort the batch.
|
||||
subprocess.run(
|
||||
["icacls", str(path), "/inheritance:r"],
|
||||
check=False, capture_output=True, timeout=5,
|
||||
)
|
||||
for principal in (username, "SYSTEM", "Administrators"):
|
||||
subprocess.run(
|
||||
["icacls", str(path), "/grant:r", f"{principal}:(R,W)"],
|
||||
check=False, capture_output=True, timeout=5,
|
||||
)
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
|
||||
# `icacls` missing or sandboxed — leave the default ACL in place.
|
||||
logger.debug("icacls hardening failed for %s", path, exc_info=True)
|
||||
|
||||
|
||||
# Global settings instance
|
||||
settings = Settings.load_from_yaml()
|
||||
|
||||
+247
-90
@@ -15,7 +15,7 @@ from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from . import __version__
|
||||
from .auth import get_token_label, token_label_var
|
||||
from .auth import get_token_label, request_id_var, token_label_var
|
||||
from .config import generate_default_config, get_config_dir, settings
|
||||
from .routes import (
|
||||
audio_router,
|
||||
@@ -33,10 +33,34 @@ from .services.websocket_manager import ws_manager
|
||||
|
||||
|
||||
class TokenLabelFilter(logging.Filter):
|
||||
"""Add token label to log records."""
|
||||
"""Add token label + request_id to log records."""
|
||||
|
||||
def filter(self, record):
|
||||
record.token_label = token_label_var.get("unknown")
|
||||
record.request_id = request_id_var.get("-")
|
||||
return True
|
||||
|
||||
|
||||
class _StripTokenQueryFilter(logging.Filter):
|
||||
"""Strip `token=...` from query strings before they hit the access log.
|
||||
|
||||
uvicorn's default access log format includes the full request line, so
|
||||
`/api/media/artwork?token=SECRET` would otherwise be persisted verbatim
|
||||
in stdout/journald/file sinks.
|
||||
"""
|
||||
|
||||
import re as _re
|
||||
|
||||
_TOKEN_RE = _re.compile(r"([?&])token=[^&\s\"']+")
|
||||
|
||||
def filter(self, record): # type: ignore[override]
|
||||
if isinstance(record.args, tuple):
|
||||
record.args = tuple(
|
||||
self._TOKEN_RE.sub(r"\1token=REDACTED", a) if isinstance(a, str) else a
|
||||
for a in record.args
|
||||
)
|
||||
if isinstance(record.msg, str) and "token=" in record.msg:
|
||||
record.msg = self._TOKEN_RE.sub(r"\1token=REDACTED", record.msg)
|
||||
return True
|
||||
|
||||
|
||||
@@ -49,17 +73,34 @@ def setup_logging():
|
||||
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, settings.log_level.upper()),
|
||||
format="%(asctime)s - %(name)s - [%(token_label)s] - %(levelname)s - %(message)s",
|
||||
format=(
|
||||
"%(asctime)s - %(name)s - [%(token_label)s] [%(request_id)s]"
|
||||
" - %(levelname)s - %(message)s"
|
||||
),
|
||||
handlers=[handler],
|
||||
)
|
||||
|
||||
# Suppress noisy third-party loggers
|
||||
logging.getLogger("screen_brightness_control").setLevel(logging.ERROR)
|
||||
|
||||
# Make sure the uvicorn access log never persists tokens leaked into the
|
||||
# query string (the artwork + WS endpoints accept `?token=` for browser
|
||||
# compatibility — see verify_token_or_query).
|
||||
strip_filter = _StripTokenQueryFilter()
|
||||
for name in ("uvicorn.access", "uvicorn"):
|
||||
logging.getLogger(name).addFilter(strip_filter)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan handler."""
|
||||
"""Application lifespan handler.
|
||||
|
||||
All long-lived resources started during startup are kept in local refs and
|
||||
torn down in a `finally:` so a partial-startup failure cannot orphan tasks
|
||||
or thread pools.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Media Server starting on {settings.host}:{settings.port}")
|
||||
@@ -71,92 +112,125 @@ async def lifespan(app: FastAPI):
|
||||
else:
|
||||
logger.warning("No API tokens configured — authentication is DISABLED")
|
||||
|
||||
# Start WebSocket status monitor
|
||||
controller = get_media_controller()
|
||||
await ws_manager.start_status_monitor(controller.get_status)
|
||||
logger.info("WebSocket status monitor started")
|
||||
|
||||
# Start update checker
|
||||
update_checker = None
|
||||
if settings.update_check_enabled:
|
||||
from .services.gitea_release_provider import GiteaReleaseProvider
|
||||
from .services.update_checker import UpdateChecker
|
||||
|
||||
provider = GiteaReleaseProvider()
|
||||
update_checker = UpdateChecker(provider, __version__)
|
||||
await update_checker.start(settings.update_check_interval)
|
||||
# Store globally so health endpoint can access cached result
|
||||
app.state.update_checker = update_checker
|
||||
|
||||
# Schedule periodic thumbnail cache cleanup so the 500 MB cap is actually
|
||||
# enforced. Runs once at startup and then hourly until shutdown.
|
||||
from .services.thumbnail_service import ThumbnailService
|
||||
|
||||
async def _thumbnail_cleanup_loop() -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.to_thread(ThumbnailService.cleanup_cache)
|
||||
except Exception as e:
|
||||
logger.warning("Thumbnail cache cleanup failed: %s", e)
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
import asyncio
|
||||
cleanup_task = asyncio.create_task(_thumbnail_cleanup_loop())
|
||||
|
||||
# Register audio visualizer (capture starts on-demand when clients subscribe)
|
||||
cleanup_task: asyncio.Task | None = None
|
||||
analyzer = None
|
||||
if settings.visualizer_enabled:
|
||||
from .services.audio_analyzer import get_audio_analyzer
|
||||
status_monitor_started = False
|
||||
|
||||
analyzer = get_audio_analyzer(
|
||||
num_bins=settings.visualizer_bins,
|
||||
target_fps=settings.visualizer_fps,
|
||||
device_name=settings.visualizer_device,
|
||||
)
|
||||
if analyzer.available:
|
||||
await ws_manager.start_audio_monitor(analyzer)
|
||||
logger.info("Audio visualizer available (capture on-demand)")
|
||||
else:
|
||||
logger.info("Audio visualizer unavailable (install soundcard + numpy)")
|
||||
|
||||
yield
|
||||
|
||||
# Stop update checker
|
||||
if update_checker is not None:
|
||||
await update_checker.stop()
|
||||
|
||||
# Cancel periodic thumbnail cleanup
|
||||
cleanup_task.cancel()
|
||||
try:
|
||||
await cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Start WebSocket status monitor
|
||||
controller = get_media_controller()
|
||||
await ws_manager.start_status_monitor(controller.get_status)
|
||||
status_monitor_started = True
|
||||
logger.info("WebSocket status monitor started")
|
||||
|
||||
# Stop audio visualizer
|
||||
await ws_manager.stop_audio_monitor()
|
||||
if analyzer and analyzer.running:
|
||||
analyzer.stop()
|
||||
# Start update checker
|
||||
if settings.update_check_enabled:
|
||||
from .services.gitea_release_provider import GiteaReleaseProvider
|
||||
from .services.update_checker import UpdateChecker
|
||||
|
||||
# Stop WebSocket status monitor
|
||||
await ws_manager.stop_status_monitor()
|
||||
provider = GiteaReleaseProvider()
|
||||
update_checker = UpdateChecker(provider, __version__)
|
||||
await update_checker.start(settings.update_check_interval)
|
||||
# Store globally so health endpoint can access cached result
|
||||
app.state.update_checker = update_checker
|
||||
|
||||
# Shut down dedicated thread pools so pending scripts don't leak threads
|
||||
from .routes.callbacks import shutdown_callback_executor
|
||||
from .routes.scripts import shutdown_script_executor
|
||||
# Schedule periodic thumbnail cache cleanup so the 500 MB cap is actually
|
||||
# enforced. Runs once at startup and then hourly until shutdown.
|
||||
from .services.thumbnail_service import ThumbnailService
|
||||
|
||||
shutdown_script_executor()
|
||||
shutdown_callback_executor()
|
||||
async def _thumbnail_cleanup_loop() -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.to_thread(ThumbnailService.cleanup_cache)
|
||||
except Exception as e:
|
||||
logger.warning("Thumbnail cache cleanup failed: %s", e)
|
||||
try:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# Clean up platform-specific resources
|
||||
import platform as _platform
|
||||
if _platform.system() == "Windows":
|
||||
from .services.windows_media import shutdown_executor
|
||||
shutdown_executor()
|
||||
cleanup_task = asyncio.create_task(_thumbnail_cleanup_loop())
|
||||
|
||||
logger.info("Media Server shutting down")
|
||||
# Register audio visualizer (capture starts on-demand when clients subscribe)
|
||||
if settings.visualizer_enabled:
|
||||
from .services.audio_analyzer import get_audio_analyzer
|
||||
|
||||
analyzer = get_audio_analyzer(
|
||||
num_bins=settings.visualizer_bins,
|
||||
target_fps=settings.visualizer_fps,
|
||||
device_name=settings.visualizer_device,
|
||||
)
|
||||
if analyzer.available:
|
||||
await ws_manager.start_audio_monitor(analyzer)
|
||||
logger.info("Audio visualizer available (capture on-demand)")
|
||||
else:
|
||||
logger.info("Audio visualizer unavailable (install soundcard + numpy)")
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Stop update checker
|
||||
if update_checker is not None:
|
||||
try:
|
||||
await update_checker.stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping update checker")
|
||||
|
||||
# Cancel periodic thumbnail cleanup
|
||||
if cleanup_task is not None:
|
||||
cleanup_task.cancel()
|
||||
try:
|
||||
await cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Error awaiting thumbnail cleanup task")
|
||||
|
||||
# Stop audio visualizer
|
||||
try:
|
||||
await ws_manager.stop_audio_monitor()
|
||||
except Exception:
|
||||
logger.exception("Error stopping audio monitor")
|
||||
if analyzer and analyzer.running:
|
||||
try:
|
||||
analyzer.stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping audio analyzer")
|
||||
|
||||
# Stop WebSocket status monitor
|
||||
if status_monitor_started:
|
||||
try:
|
||||
await ws_manager.stop_status_monitor()
|
||||
except Exception:
|
||||
logger.exception("Error stopping status monitor")
|
||||
|
||||
# Shut down dedicated thread pools so pending scripts don't leak threads
|
||||
try:
|
||||
from .routes.callbacks import shutdown_callback_executor
|
||||
from .routes.scripts import shutdown_script_executor
|
||||
|
||||
shutdown_script_executor()
|
||||
shutdown_callback_executor()
|
||||
except Exception:
|
||||
logger.exception("Error shutting down script/callback executors")
|
||||
|
||||
# Flush audit log writer
|
||||
try:
|
||||
from .services.audit_log import shutdown_audit_log
|
||||
shutdown_audit_log()
|
||||
except Exception:
|
||||
logger.exception("Error flushing audit log")
|
||||
|
||||
# Clean up platform-specific resources
|
||||
import platform as _platform
|
||||
if _platform.system() == "Windows":
|
||||
try:
|
||||
from .services.windows_media import shutdown_executor
|
||||
shutdown_executor()
|
||||
except Exception:
|
||||
logger.exception("Error shutting down windows_media executor")
|
||||
|
||||
logger.info("Media Server shutting down")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -173,7 +247,15 @@ def create_app() -> FastAPI:
|
||||
|
||||
# CORS — restrict to same-origin by default; users that integrate the API
|
||||
# from another origin (e.g. Home Assistant on a different host) can set
|
||||
# cors_origins in config.yaml.
|
||||
# cors_origins in config.yaml. Refuse "*" outright: combined with the
|
||||
# admin endpoints this would let any origin in the universe run
|
||||
# arbitrary shell. If users genuinely need every origin, they can list
|
||||
# them explicitly.
|
||||
if any(o.strip() == "*" for o in settings.cors_origins):
|
||||
raise RuntimeError(
|
||||
"cors_origins must not contain '*' — list exact origins instead. "
|
||||
"This protects the script-execution endpoints from any-origin abuse."
|
||||
)
|
||||
cors_origins = settings.cors_origins or [
|
||||
f"http://localhost:{settings.port}",
|
||||
f"http://127.0.0.1:{settings.port}",
|
||||
@@ -186,6 +268,23 @@ def create_app() -> FastAPI:
|
||||
allow_headers=["Authorization", "Content-Type"],
|
||||
)
|
||||
|
||||
# Request correlation ID — accept upstream X-Request-ID if it's a sane
|
||||
# ASCII id, otherwise mint a fresh UUID4. Emitted on the response so
|
||||
# clients can quote it back in bug reports.
|
||||
import re
|
||||
import uuid as _uuid
|
||||
|
||||
_REQ_ID_RE = re.compile(r"^[A-Za-z0-9._\-]{1,128}$")
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_id_middleware(request: Request, call_next):
|
||||
incoming = request.headers.get("x-request-id", "")
|
||||
req_id = incoming if _REQ_ID_RE.match(incoming) else _uuid.uuid4().hex[:16]
|
||||
request_id_var.set(req_id)
|
||||
response = await call_next(request)
|
||||
response.headers["X-Request-ID"] = req_id
|
||||
return response
|
||||
|
||||
# Security headers — strict CSP for the bundled UI, disallow framing, hide referrer.
|
||||
@app.middleware("http")
|
||||
async def security_headers_middleware(request: Request, call_next):
|
||||
@@ -200,6 +299,9 @@ def create_app() -> FastAPI:
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"font-src 'self' data:; "
|
||||
"frame-ancestors 'none'; "
|
||||
"form-action 'self'; "
|
||||
"worker-src 'self'; "
|
||||
"manifest-src 'self'; "
|
||||
"base-uri 'self'"
|
||||
),
|
||||
)
|
||||
@@ -208,32 +310,63 @@ def create_app() -> FastAPI:
|
||||
response.headers.setdefault("Referrer-Policy", "no-referrer")
|
||||
return response
|
||||
|
||||
# Add token logging middleware
|
||||
# Add token logging middleware + auth-failure rate limit
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from .services.rate_limit import check as ratelimit_check
|
||||
from .services.rate_limit import get_peer
|
||||
|
||||
@app.middleware("http")
|
||||
async def token_logging_middleware(request: Request, call_next):
|
||||
"""Extract token label and set in context for logging."""
|
||||
"""Extract token label, set in context, and rate-limit failed auths."""
|
||||
if not settings.api_tokens:
|
||||
token_label_var.set("anonymous")
|
||||
else:
|
||||
token_label = "unknown"
|
||||
token_present = False
|
||||
token_valid = False
|
||||
|
||||
# Try Authorization header
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token_present = True
|
||||
token = auth_header[7:]
|
||||
label = get_token_label(token)
|
||||
if label:
|
||||
token_label = label
|
||||
token_valid = True
|
||||
|
||||
# Try query parameter (for artwork endpoint)
|
||||
elif "token" in request.query_params:
|
||||
token_present = True
|
||||
token = request.query_params["token"]
|
||||
label = get_token_label(token)
|
||||
if label:
|
||||
token_label = label
|
||||
token_valid = True
|
||||
|
||||
token_label_var.set(token_label)
|
||||
|
||||
# Brute-force gate: a peer that produces a wrong/missing token gets
|
||||
# 5 failures per minute before being throttled. Static-asset
|
||||
# requests (GET /static/*, /, /sw.js) and the docs endpoint are
|
||||
# exempt — they're served unauthenticated by design.
|
||||
if token_present and not token_valid:
|
||||
path = request.url.path
|
||||
if not (
|
||||
path == "/" or path == "/sw.js"
|
||||
or path.startswith("/static/")
|
||||
or path.startswith("/docs") or path.startswith("/openapi")
|
||||
or path.startswith("/redoc")
|
||||
):
|
||||
allowed, retry_after = ratelimit_check("auth", get_peer(request))
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Too many authentication failures"},
|
||||
headers={"Retry-After": str(int(retry_after or 60))},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
@@ -266,6 +399,11 @@ def create_app() -> FastAPI:
|
||||
async def serve_ui():
|
||||
"""Serve the Web UI."""
|
||||
return FileResponse(static_dir / "index.html")
|
||||
else:
|
||||
logging.getLogger(__name__).warning(
|
||||
"static_dir not found at %s — Web UI disabled (API only)",
|
||||
static_dir,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
@@ -316,8 +454,9 @@ def main():
|
||||
print(f"Config directory: {get_config_dir()}")
|
||||
if settings.api_tokens:
|
||||
print("\nAPI Tokens:")
|
||||
for label, token in settings.api_tokens.items():
|
||||
print(f" {label:20} {token}")
|
||||
for label, spec in settings.api_tokens.items():
|
||||
scope_str = ",".join(spec.scopes)
|
||||
print(f" {label:20} {spec.token} [scopes: {scope_str}]")
|
||||
else:
|
||||
print("\nAuthentication is DISABLED (no tokens configured)")
|
||||
return
|
||||
@@ -374,6 +513,27 @@ def main():
|
||||
|
||||
use_tray = PYSTRAY_AVAILABLE and not args.no_tray
|
||||
|
||||
# Validate TLS pair consistency before either path so we don't fail late.
|
||||
if bool(settings.ssl_certfile) ^ bool(settings.ssl_keyfile):
|
||||
_fatal(
|
||||
"ERROR: ssl_certfile and ssl_keyfile must both be set, or both unset."
|
||||
)
|
||||
|
||||
def _uvicorn_kwargs() -> dict:
|
||||
kw: dict = {
|
||||
"host": args.host,
|
||||
"port": args.port,
|
||||
"log_level": settings.log_level.lower(),
|
||||
"proxy_headers": settings.proxy_headers,
|
||||
"forwarded_allow_ips": settings.forwarded_allow_ips,
|
||||
}
|
||||
if settings.ssl_certfile and settings.ssl_keyfile:
|
||||
kw["ssl_certfile"] = settings.ssl_certfile
|
||||
kw["ssl_keyfile"] = settings.ssl_keyfile
|
||||
if settings.ssl_keyfile_password:
|
||||
kw["ssl_keyfile_password"] = settings.ssl_keyfile_password
|
||||
return kw
|
||||
|
||||
if use_tray:
|
||||
import asyncio
|
||||
import threading
|
||||
@@ -381,9 +541,7 @@ def main():
|
||||
# Run uvicorn in a background thread so tray owns the main thread message loop
|
||||
uv_config = uvicorn.Config(
|
||||
"media_server.main:app",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=settings.log_level.lower(),
|
||||
**_uvicorn_kwargs(),
|
||||
)
|
||||
server = uvicorn.Server(uv_config)
|
||||
|
||||
@@ -421,9 +579,8 @@ def main():
|
||||
else:
|
||||
uvicorn.run(
|
||||
"media_server.main:app",
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
reload=False,
|
||||
**_uvicorn_kwargs(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -36,12 +36,20 @@ def _spawn_background(coro) -> asyncio.Task:
|
||||
|
||||
|
||||
def _require_folder_management() -> None:
|
||||
"""Raise 403 if media folder management is disabled in config."""
|
||||
"""Raise 403 if media folder management is disabled OR caller lacks admin scope."""
|
||||
if not settings.media_folders_management:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Media folder management is disabled. Set media_folders_management: true in config.yaml to enable.",
|
||||
)
|
||||
from ..auth import auth_enabled, token_has_scope, token_label_var
|
||||
if auth_enabled():
|
||||
label = token_label_var.get("unknown")
|
||||
if not token_has_scope(label, "admin"):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Token '{label}' lacks required scope: admin",
|
||||
)
|
||||
|
||||
|
||||
async def _broadcast_after_open(controller, label: str, max_wait: float = 2.0) -> None:
|
||||
|
||||
@@ -8,12 +8,14 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..auth import verify_token
|
||||
from ..config import CallbackConfig, settings
|
||||
from ..config_manager import config_manager
|
||||
from ..services.rate_limit import check as ratelimit_check
|
||||
from ..services.rate_limit import get_peer
|
||||
|
||||
router = APIRouter(prefix="/api/callbacks", tags=["callbacks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -28,6 +30,7 @@ def shutdown_callback_executor() -> None:
|
||||
|
||||
|
||||
def _require_callbacks_management() -> None:
|
||||
"""Authorise a callbacks-CRUD operation. Operator flag + per-token admin scope."""
|
||||
if not settings.callbacks_management:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -36,6 +39,14 @@ def _require_callbacks_management() -> None:
|
||||
" in config.yaml to enable."
|
||||
),
|
||||
)
|
||||
from ..auth import auth_enabled, token_has_scope, token_label_var
|
||||
if auth_enabled():
|
||||
label = token_label_var.get("unknown")
|
||||
if not token_has_scope(label, "admin"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Token '{label}' lacks required scope: admin",
|
||||
)
|
||||
|
||||
|
||||
class CallbackInfo(BaseModel):
|
||||
@@ -122,6 +133,7 @@ async def list_callbacks(_: str = Depends(verify_token)) -> list[CallbackInfo]:
|
||||
@router.post("/execute/{callback_name}")
|
||||
async def execute_callback(
|
||||
callback_name: str,
|
||||
http_request: Request,
|
||||
_: str = Depends(verify_token),
|
||||
) -> CallbackExecuteResponse:
|
||||
"""Execute a callback for debugging purposes.
|
||||
@@ -132,6 +144,16 @@ async def execute_callback(
|
||||
Returns:
|
||||
Execution result including stdout, stderr, and exit code
|
||||
"""
|
||||
# Rate-limit callback execution per peer (10/min) — callbacks also run
|
||||
# subprocesses and need the same protection as scripts.
|
||||
allowed, retry_after = ratelimit_check("execute", get_peer(http_request))
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many callback executions, slow down",
|
||||
headers={"Retry-After": str(int(retry_after or 60))},
|
||||
)
|
||||
|
||||
# Validate callback name
|
||||
_validate_callback_name(callback_name)
|
||||
|
||||
@@ -146,6 +168,8 @@ async def execute_callback(
|
||||
|
||||
logger.info(f"Executing callback for debugging: {callback_name}")
|
||||
|
||||
from ..services.audit_log import record_script_execution
|
||||
|
||||
try:
|
||||
# Execute in dedicated thread pool to not block the default executor
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -159,6 +183,15 @@ async def execute_callback(
|
||||
),
|
||||
)
|
||||
|
||||
record_script_execution(
|
||||
kind="callback",
|
||||
name=callback_name,
|
||||
exit_code=result["exit_code"],
|
||||
duration=result.get("execution_time"),
|
||||
stdout=result.get("stdout"),
|
||||
stderr=result.get("stderr"),
|
||||
)
|
||||
|
||||
return CallbackExecuteResponse(
|
||||
success=result["exit_code"] == 0,
|
||||
callback=callback_name,
|
||||
@@ -170,6 +203,13 @@ async def execute_callback(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Callback execution error: {e}")
|
||||
record_script_execution(
|
||||
kind="callback",
|
||||
name=callback_name,
|
||||
exit_code=None,
|
||||
duration=None,
|
||||
error=str(e),
|
||||
)
|
||||
return CallbackExecuteResponse(
|
||||
success=False,
|
||||
callback=callback_name,
|
||||
|
||||
@@ -39,11 +39,20 @@ def _validate_icon(icon: str) -> str:
|
||||
|
||||
|
||||
def _require_links_management() -> None:
|
||||
"""Authorise a links-CRUD operation. Operator flag + per-token admin scope."""
|
||||
if not settings.links_management:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Links management is disabled. Set links_management: true in config.yaml to enable.",
|
||||
)
|
||||
from ..auth import auth_enabled, token_has_scope, token_label_var
|
||||
if auth_enabled():
|
||||
label = token_label_var.get("unknown")
|
||||
if not token_has_scope(label, "admin"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Token '{label}' lacks required scope: admin",
|
||||
)
|
||||
|
||||
|
||||
class LinkInfo(BaseModel):
|
||||
|
||||
+130
-25
@@ -3,7 +3,16 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Query,
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ..auth import verify_token, verify_token_or_query
|
||||
@@ -17,19 +26,28 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/media", tags=["media"])
|
||||
|
||||
|
||||
# Strong refs to background tasks so the asyncio GC can't drop them before
|
||||
# they run. Mirrors the pattern used in routes/browser.py.
|
||||
_background_callback_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def _run_callback(callback_name: str) -> None:
|
||||
"""Fire-and-forget a callback if configured. Failures are logged but don't block."""
|
||||
if not settings.callbacks or callback_name not in settings.callbacks:
|
||||
return
|
||||
|
||||
async def _execute():
|
||||
# Use the dedicated callback executor (not the default loop pool) so a
|
||||
# misbehaving callback can't starve the rest of the app's sync tasks.
|
||||
from ..services.audit_log import record_script_execution
|
||||
from .callbacks import _callback_executor
|
||||
from .scripts import _run_script
|
||||
|
||||
try:
|
||||
callback = settings.callbacks[callback_name]
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
_callback_executor,
|
||||
lambda: _run_script(
|
||||
command=callback.command,
|
||||
timeout=callback.timeout,
|
||||
@@ -37,6 +55,14 @@ def _run_callback(callback_name: str) -> None:
|
||||
working_dir=callback.working_dir,
|
||||
),
|
||||
)
|
||||
record_script_execution(
|
||||
kind="event-callback",
|
||||
name=callback_name,
|
||||
exit_code=result["exit_code"],
|
||||
duration=result.get("execution_time"),
|
||||
stdout=result.get("stdout"),
|
||||
stderr=result.get("stderr"),
|
||||
)
|
||||
if result["exit_code"] != 0:
|
||||
logger.warning(
|
||||
"Callback %s failed with exit code %s: %s",
|
||||
@@ -46,8 +72,18 @@ def _run_callback(callback_name: str) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Callback %s error: %s", callback_name, e)
|
||||
from ..services.audit_log import record_script_execution as _rec
|
||||
_rec(
|
||||
kind="event-callback",
|
||||
name=callback_name,
|
||||
exit_code=None,
|
||||
duration=None,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
asyncio.create_task(_execute())
|
||||
task = asyncio.create_task(_execute())
|
||||
_background_callback_tasks.add(task)
|
||||
task.add_done_callback(_background_callback_tasks.discard)
|
||||
|
||||
|
||||
@router.get("/status", response_model=MediaStatus)
|
||||
@@ -242,11 +278,14 @@ async def toggle(_: str = Depends(verify_token)) -> dict:
|
||||
|
||||
|
||||
@router.get("/artwork")
|
||||
async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response:
|
||||
async def get_artwork(
|
||||
request: Request,
|
||||
_: str = Depends(verify_token_or_query),
|
||||
) -> Response:
|
||||
"""Get the current album artwork.
|
||||
|
||||
Returns:
|
||||
The album art image as PNG/JPEG
|
||||
Returns the bytes with a content-derived ETag so the browser can serve a
|
||||
304 when the same track is re-requested.
|
||||
"""
|
||||
art_bytes = get_current_album_art()
|
||||
if art_bytes is None:
|
||||
@@ -255,16 +294,34 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response:
|
||||
detail="No album artwork available",
|
||||
)
|
||||
|
||||
# Try to detect image type from magic bytes
|
||||
content_type = "image/png" # Default
|
||||
# Detect image type from magic bytes
|
||||
if art_bytes[:3] == b"\xff\xd8\xff":
|
||||
content_type = "image/jpeg"
|
||||
elif art_bytes[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
content_type = "image/png"
|
||||
elif art_bytes[:4] == b"RIFF" and art_bytes[8:12] == b"WEBP":
|
||||
elif art_bytes[:4] == b"RIFF" and len(art_bytes) > 12 and art_bytes[8:12] == b"WEBP":
|
||||
content_type = "image/webp"
|
||||
elif art_bytes[:2] == b"BM":
|
||||
content_type = "image/bmp"
|
||||
else:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
return Response(content=art_bytes, media_type=content_type)
|
||||
# Content-derived ETag (blake2b-128 — non-crypto cache key, ruff S324-safe)
|
||||
import hashlib
|
||||
|
||||
etag = '"' + hashlib.blake2b(art_bytes, digest_size=16).hexdigest() + '"'
|
||||
|
||||
if request.headers.get("if-none-match") == etag:
|
||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED, headers={"ETag": etag})
|
||||
|
||||
return Response(
|
||||
content=art_bytes,
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"ETag": etag,
|
||||
"Cache-Control": "private, max-age=0, must-revalidate",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/visualizer/status")
|
||||
@@ -323,12 +380,17 @@ async def set_visualizer_device(
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str | None = Query(None, description="API authentication token"),
|
||||
token: str | None = Query(None, description="API authentication token (legacy)"),
|
||||
) -> None:
|
||||
"""WebSocket endpoint for real-time media status updates.
|
||||
|
||||
Authentication is done via query parameter since WebSocket
|
||||
doesn't support custom headers in the browser.
|
||||
Authentication is accepted from two sources, in priority order:
|
||||
1. ``Sec-WebSocket-Protocol`` subprotocol of the form
|
||||
``media-server.token.<TOKEN>``. This is the preferred path because
|
||||
the token never lands in the URL, request logs, or browser history.
|
||||
The browser WebSocket API supports custom subprotocols natively.
|
||||
2. ``?token=<TOKEN>`` query parameter (legacy, kept for back-compat
|
||||
with older clients and the HA integration).
|
||||
|
||||
Messages sent to client:
|
||||
- {"type": "status", "data": {...}} - Initial status on connect
|
||||
@@ -339,11 +401,40 @@ async def websocket_endpoint(
|
||||
- {"type": "ping"} - Keepalive, server responds with {"type": "pong"}
|
||||
- {"type": "get_status"} - Request current status
|
||||
"""
|
||||
# Pull token from subprotocol if present. WebSocket spec lets either side
|
||||
# negotiate exactly one subprotocol back; we accept the token one and
|
||||
# answer with the same string so browsers consider the negotiation
|
||||
# successful.
|
||||
subprotocol_token: str | None = None
|
||||
accept_subprotocol: str | None = None
|
||||
raw_protocols = websocket.headers.get("sec-websocket-protocol", "")
|
||||
for proto in (p.strip() for p in raw_protocols.split(",") if p.strip()):
|
||||
if proto.startswith("media-server.token."):
|
||||
subprotocol_token = proto[len("media-server.token."):]
|
||||
accept_subprotocol = proto
|
||||
break
|
||||
effective_token = subprotocol_token or token
|
||||
# Origin check — block CSWSH from third-party LAN pages. We accept the same
|
||||
# set of origins as CORS plus the default localhost loopback.
|
||||
allowed_origins = set(
|
||||
settings.cors_origins
|
||||
or [
|
||||
f"http://localhost:{settings.port}",
|
||||
f"http://127.0.0.1:{settings.port}",
|
||||
]
|
||||
)
|
||||
origin = websocket.headers.get("origin")
|
||||
# Same-origin connections from native apps may omit Origin entirely; only
|
||||
# reject when an Origin is present AND not in the allow-list.
|
||||
if origin is not None and origin not in allowed_origins:
|
||||
await websocket.close(code=4003, reason="Origin not allowed")
|
||||
return
|
||||
|
||||
# Verify token
|
||||
from ..auth import auth_enabled, get_token_label, token_label_var
|
||||
|
||||
if auth_enabled():
|
||||
label = get_token_label(token) if token else None
|
||||
label = get_token_label(effective_token) if effective_token else None
|
||||
if label is None:
|
||||
await websocket.close(code=4001, reason="Invalid authentication token")
|
||||
return
|
||||
@@ -351,16 +442,25 @@ async def websocket_endpoint(
|
||||
else:
|
||||
token_label_var.set("anonymous")
|
||||
|
||||
await ws_manager.connect(websocket)
|
||||
# Accept with the negotiated subprotocol if one was used. Starlette's
|
||||
# connect() calls accept() with no subprotocol — we need to accept first
|
||||
# explicitly to echo the subprotocol back, then hand off to the manager.
|
||||
if accept_subprotocol is not None:
|
||||
await websocket.accept(subprotocol=accept_subprotocol)
|
||||
await ws_manager.connect(websocket, already_accepted=True)
|
||||
else:
|
||||
await ws_manager.connect(websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for messages from client (for keepalive/ping)
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "ping":
|
||||
msg_type = data.get("type") if isinstance(data, dict) else None
|
||||
|
||||
if msg_type == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
elif data.get("type") == "get_status":
|
||||
elif msg_type == "get_status":
|
||||
# Allow manual status request
|
||||
controller = get_media_controller()
|
||||
status_data = await controller.get_status()
|
||||
@@ -368,15 +468,20 @@ async def websocket_endpoint(
|
||||
"type": "status",
|
||||
"data": status_data.model_dump(),
|
||||
})
|
||||
elif data.get("type") == "volume":
|
||||
# Low-latency volume control via WebSocket
|
||||
volume = data.get("volume")
|
||||
if volume is not None:
|
||||
controller = get_media_controller()
|
||||
await controller.set_volume(int(volume))
|
||||
elif data.get("type") == "enable_visualizer":
|
||||
elif msg_type == "volume":
|
||||
# Low-latency volume control via WebSocket. Coerce, clamp, and
|
||||
# never drop the socket on a single bad message — that would
|
||||
# turn the WS into a one-shot DoS for any holder of a token.
|
||||
try:
|
||||
volume = int(data.get("volume"))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
volume = max(0, min(100, volume))
|
||||
controller = get_media_controller()
|
||||
await controller.set_volume(volume)
|
||||
elif msg_type == "enable_visualizer":
|
||||
await ws_manager.subscribe_visualizer(websocket)
|
||||
elif data.get("type") == "disable_visualizer":
|
||||
elif msg_type == "disable_visualizer":
|
||||
await ws_manager.unsubscribe_visualizer(websocket)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
|
||||
@@ -10,12 +10,14 @@ import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..auth import verify_token
|
||||
from ..config import ScriptConfig, ScriptParameterConfig, settings
|
||||
from ..config_manager import config_manager
|
||||
from ..services.rate_limit import check as ratelimit_check
|
||||
from ..services.rate_limit import get_peer
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
router = APIRouter(prefix="/api/scripts", tags=["scripts"])
|
||||
@@ -31,6 +33,12 @@ def shutdown_script_executor() -> None:
|
||||
|
||||
|
||||
def _require_scripts_management() -> None:
|
||||
"""Authorise a scripts-CRUD operation.
|
||||
|
||||
Two gates: the operator-level `scripts_management` flag in config.yaml,
|
||||
AND the per-token `admin` scope check (read from request-context). Either
|
||||
failure → 403.
|
||||
"""
|
||||
if not settings.scripts_management:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
@@ -39,6 +47,14 @@ def _require_scripts_management() -> None:
|
||||
" in config.yaml to enable."
|
||||
),
|
||||
)
|
||||
from ..auth import auth_enabled, token_has_scope, token_label_var
|
||||
if auth_enabled():
|
||||
label = token_label_var.get("unknown")
|
||||
if not token_has_scope(label, "admin"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Token '{label}' lacks required scope: admin",
|
||||
)
|
||||
|
||||
|
||||
class ScriptExecuteRequest(BaseModel):
|
||||
@@ -215,6 +231,28 @@ def _validate_params(
|
||||
# string — just convert to str
|
||||
value = str(value)
|
||||
|
||||
# Optional regex constraint, validated against the *string form* of the
|
||||
# value. This is the only practical defence for string parameters that
|
||||
# flow into shell=true scripts via env vars (Windows cmd.exe expands
|
||||
# `%VAR%` after argument parsing, so embedded `&`/`|`/`%` would inject
|
||||
# commands). Authors of shell scripts should ALWAYS define a pattern.
|
||||
if pdef.pattern:
|
||||
try:
|
||||
if not re.fullmatch(pdef.pattern, str(value)):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=(
|
||||
f"Parameter '{pname}' value {value!r} does not match"
|
||||
f" required pattern: {pdef.pattern}"
|
||||
),
|
||||
)
|
||||
except re.error as e:
|
||||
# Bad pattern in config — fail closed.
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Parameter '{pname}' has invalid pattern: {e}",
|
||||
) from e
|
||||
|
||||
env_vars[f"SCRIPT_PARAM_{pname.upper()}"] = str(value)
|
||||
|
||||
return env_vars
|
||||
@@ -223,6 +261,7 @@ def _validate_params(
|
||||
@router.post("/execute/{script_name}")
|
||||
async def execute_script(
|
||||
script_name: str,
|
||||
http_request: Request,
|
||||
request: ScriptExecuteRequest | None = None,
|
||||
_: str = Depends(verify_token),
|
||||
) -> ScriptExecuteResponse:
|
||||
@@ -235,6 +274,16 @@ async def execute_script(
|
||||
Returns:
|
||||
Execution result including stdout, stderr, and exit code
|
||||
"""
|
||||
# Rate-limit script execution per peer so a leaked token can't be used to
|
||||
# spam the shell-exec endpoint.
|
||||
allowed, retry_after = ratelimit_check("execute", get_peer(http_request))
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many script executions, slow down",
|
||||
headers={"Retry-After": str(int(retry_after or 60))},
|
||||
)
|
||||
|
||||
# Check if script exists
|
||||
if script_name not in settings.scripts:
|
||||
raise HTTPException(
|
||||
@@ -249,6 +298,8 @@ async def execute_script(
|
||||
|
||||
logger.info(f"Executing script: {script_name}")
|
||||
|
||||
from ..services.audit_log import record_script_execution
|
||||
|
||||
try:
|
||||
# Execute in dedicated thread pool to not block the default executor
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -263,6 +314,15 @@ async def execute_script(
|
||||
),
|
||||
)
|
||||
|
||||
record_script_execution(
|
||||
kind="script",
|
||||
name=script_name,
|
||||
exit_code=result["exit_code"],
|
||||
duration=result.get("execution_time"),
|
||||
stdout=result.get("stdout"),
|
||||
stderr=result.get("stderr"),
|
||||
)
|
||||
|
||||
return ScriptExecuteResponse(
|
||||
success=result["exit_code"] == 0,
|
||||
script=script_name,
|
||||
@@ -274,6 +334,13 @@ async def execute_script(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Script execution error: {e}")
|
||||
record_script_execution(
|
||||
kind="script",
|
||||
name=script_name,
|
||||
exit_code=None,
|
||||
duration=None,
|
||||
error=str(e),
|
||||
)
|
||||
return ScriptExecuteResponse(
|
||||
success=False,
|
||||
script=script_name,
|
||||
@@ -313,9 +380,21 @@ def _run_script(
|
||||
else:
|
||||
popen_kwargs["start_new_session"] = True
|
||||
|
||||
# When shell=False, the user-provided command string is split via shlex
|
||||
# (POSIX rules — also works for Windows args without backslashes). This
|
||||
# disables shell metacharacter expansion entirely, so SCRIPT_PARAM_* env
|
||||
# vars referenced as $FOO / %FOO% will be treated as literal text by the
|
||||
# process, not interpreted by a shell. Use shell=false for any script
|
||||
# whose params come from external input.
|
||||
if shell:
|
||||
run_command: str | list[str] = command
|
||||
else:
|
||||
import shlex
|
||||
run_command = shlex.split(command, posix=(sys.platform != "win32"))
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
command,
|
||||
run_command,
|
||||
shell=shell,
|
||||
cwd=working_dir,
|
||||
capture_output=True,
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
"""Append-only audit log for sensitive actions (script + callback execution).
|
||||
|
||||
Writes a single JSONL line per event to ``<config_dir>/audit.log``. The log is
|
||||
write-only from the app's perspective — it never reads back, and rotation is
|
||||
left to the operator (the file size is dominated by stdout/stderr truncation,
|
||||
which is already capped at 10 KB per stream in `_run_script`).
|
||||
|
||||
Designed to be cheap: the write goes through a small background thread so the
|
||||
hot path never blocks on disk I/O, and a failure to write is logged at WARNING
|
||||
but never raised to callers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from ..auth import token_label_var
|
||||
from ..config import get_config_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cap on stdout/stderr inside the audit record so a chatty script doesn't
|
||||
# explode the log. Mirrors the 10k cap used by _run_script.
|
||||
_OUTPUT_CAP = 2000
|
||||
|
||||
_audit_queue: "queue.Queue[dict[str, Any] | None]" = queue.Queue(maxsize=1000)
|
||||
_audit_thread: threading.Thread | None = None
|
||||
_audit_lock = threading.Lock()
|
||||
|
||||
|
||||
def _ensure_writer_started() -> None:
|
||||
global _audit_thread
|
||||
with _audit_lock:
|
||||
if _audit_thread is not None and _audit_thread.is_alive():
|
||||
return
|
||||
_audit_thread = threading.Thread(
|
||||
target=_audit_writer_loop,
|
||||
name="audit-log",
|
||||
daemon=True,
|
||||
)
|
||||
_audit_thread.start()
|
||||
|
||||
|
||||
def _audit_writer_loop() -> None:
|
||||
log_path = get_config_dir() / "audit.log"
|
||||
while True:
|
||||
try:
|
||||
record = _audit_queue.get()
|
||||
except Exception:
|
||||
return
|
||||
if record is None:
|
||||
return
|
||||
try:
|
||||
line = json.dumps(record, ensure_ascii=False, default=str)
|
||||
with open(log_path, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except OSError as e:
|
||||
logger.warning("Failed to write audit record: %s", e)
|
||||
|
||||
|
||||
def _truncate(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if len(value) <= _OUTPUT_CAP:
|
||||
return value
|
||||
return value[:_OUTPUT_CAP] + f"\n…[truncated, {len(value) - _OUTPUT_CAP} chars]"
|
||||
|
||||
|
||||
def record_script_execution(
|
||||
*,
|
||||
kind: str,
|
||||
name: str,
|
||||
exit_code: int | None,
|
||||
duration: float | None,
|
||||
stdout: str | None = None,
|
||||
stderr: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Append a single audit record. Never raises."""
|
||||
_ensure_writer_started()
|
||||
try:
|
||||
record = {
|
||||
"ts": time.time(),
|
||||
"iso": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()),
|
||||
"token_label": token_label_var.get("unknown"),
|
||||
"kind": kind,
|
||||
"name": name,
|
||||
"exit_code": exit_code,
|
||||
"duration_s": round(duration, 4) if duration is not None else None,
|
||||
"success": exit_code == 0 if exit_code is not None else False,
|
||||
"stdout": _truncate(stdout),
|
||||
"stderr": _truncate(stderr),
|
||||
"error": error,
|
||||
}
|
||||
_audit_queue.put_nowait(record)
|
||||
except queue.Full:
|
||||
# Backpressure: drop oldest record to make room. We'd rather lose an
|
||||
# old entry than block the script that just ran.
|
||||
try:
|
||||
_audit_queue.get_nowait()
|
||||
_audit_queue.put_nowait(record)
|
||||
except queue.Empty:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("Failed to enqueue audit record: %s", e)
|
||||
|
||||
|
||||
def shutdown_audit_log() -> None:
|
||||
"""Flush the audit queue on app shutdown."""
|
||||
try:
|
||||
_audit_queue.put_nowait(None)
|
||||
except queue.Full:
|
||||
pass
|
||||
if _audit_thread is not None:
|
||||
_audit_thread.join(timeout=2)
|
||||
@@ -192,10 +192,11 @@ _CACHE_TTL = 5.0 # seconds
|
||||
# Per-monitor cache of static capabilities (option lists + support flags).
|
||||
# DDC/CI capability discovery is the slow part — it only changes when a
|
||||
# monitor is replaced or rewired, so we probe it once per monitor and reuse
|
||||
# it across refreshes. Cleared on explicit `rediscover` or when the monitor
|
||||
# count changes (cheap stale-detection for hot-plug events).
|
||||
_static_cache: dict[int, dict] = {}
|
||||
_static_cache_monitor_count: int = -1
|
||||
# it across refreshes. Keyed by a stable identity tuple
|
||||
# (manufacturer, model, edid_hash) so that hot-plug swaps where the new
|
||||
# topology has the same number of monitors but different devices still
|
||||
# refresh the cache for the new monitor instead of serving stale capabilities.
|
||||
_static_cache: dict[tuple, dict] = {}
|
||||
|
||||
|
||||
def _enum_name(value, enum_cls=None) -> str | None:
|
||||
@@ -353,7 +354,7 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
next probe re-runs DDC/CI capability discovery. Use after hot-plug
|
||||
or when a monitor's reported capabilities change.
|
||||
"""
|
||||
global _monitor_cache, _cache_time, _static_cache_monitor_count
|
||||
global _monitor_cache, _cache_time
|
||||
|
||||
if (
|
||||
not force_refresh
|
||||
@@ -372,12 +373,11 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
info_list = sbc.list_monitors_info()
|
||||
brightnesses = sbc.get_brightness()
|
||||
|
||||
# Invalidate the static cache on explicit rediscover OR on topology
|
||||
# change (hot-plug / disconnect). Both indicate the cached probe is
|
||||
# potentially stale.
|
||||
if rediscover or len(info_list) != _static_cache_monitor_count:
|
||||
# Explicit rediscover wipes the whole cache; otherwise rely on stable
|
||||
# per-monitor keys (manufacturer|model|edid_hash) so a hot-plug swap
|
||||
# invalidates the entry for the missing monitor automatically.
|
||||
if rediscover:
|
||||
_static_cache.clear()
|
||||
_static_cache_monitor_count = len(info_list)
|
||||
|
||||
mc = _load_monitorcontrol()
|
||||
ddc_monitors = []
|
||||
@@ -387,6 +387,9 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import hashlib
|
||||
|
||||
seen_keys: set[tuple] = set()
|
||||
for i, info in enumerate(info_list):
|
||||
name = info.get("name", f"Monitor {i}")
|
||||
model = info.get("model", "")
|
||||
@@ -400,6 +403,21 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
edid = info.get("edid", "")
|
||||
resolution = _parse_edid_resolution(edid) if edid else None
|
||||
|
||||
# Stable cache key — EDID hash is unique per physical monitor.
|
||||
# Fall back to (manufacturer, model, serial-ish) when EDID is
|
||||
# missing, then to the legacy index as a last resort.
|
||||
if edid:
|
||||
edid_hash = hashlib.blake2b(
|
||||
edid.encode("utf-8") if isinstance(edid, str) else bytes(edid),
|
||||
digest_size=8,
|
||||
).hexdigest()
|
||||
cache_key: tuple = ("edid", edid_hash)
|
||||
elif manufacturer or model:
|
||||
cache_key = ("mm", manufacturer, model, name)
|
||||
else:
|
||||
cache_key = ("idx", i)
|
||||
seen_keys.add(cache_key)
|
||||
|
||||
static: dict = {}
|
||||
dynamic: dict = {}
|
||||
|
||||
@@ -409,13 +427,13 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
if power_supported and i < len(ddc_monitors):
|
||||
try:
|
||||
with ddc_monitors[i] as mon:
|
||||
if i not in _static_cache:
|
||||
_static_cache[i] = _probe_static_open(mon, mc, i)
|
||||
static = _static_cache[i]
|
||||
if cache_key not in _static_cache:
|
||||
_static_cache[cache_key] = _probe_static_open(mon, mc, i)
|
||||
static = _static_cache[cache_key]
|
||||
dynamic = _probe_dynamic_open(mon, mc, i, static)
|
||||
except Exception as e:
|
||||
logger.debug("Monitor %d: DDC/CI session failed: %s", i, e)
|
||||
static = _static_cache.get(i, {})
|
||||
static = _static_cache.get(cache_key, {})
|
||||
|
||||
monitors.append(MonitorInfo(
|
||||
id=i,
|
||||
@@ -439,6 +457,12 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list
|
||||
available_picture_modes=static.get("available_picture_modes", []),
|
||||
picture_mode_supported=static.get("picture_mode_supported", False),
|
||||
))
|
||||
# Evict cache entries for monitors that disappeared from this scan so
|
||||
# the next hot-plug of a different monitor with the same identity
|
||||
# tuple (e.g. same model) doesn't hit a stale entry first.
|
||||
for stale_key in list(_static_cache.keys()):
|
||||
if stale_key not in seen_keys:
|
||||
_static_cache.pop(stale_key, None)
|
||||
except Exception as e:
|
||||
logger.error("Failed to enumerate monitors: %s", e)
|
||||
|
||||
|
||||
@@ -86,9 +86,29 @@ class _Cache:
|
||||
|
||||
_cache = _Cache()
|
||||
|
||||
# Win32 handles + signatures are declared once at module load (when running on
|
||||
# Windows). The TTL cache fires this hundreds of times per minute; redoing the
|
||||
# DLL load + ~10 argtype assignments per call was the largest chunk of probe
|
||||
# cost. Keep these guarded behind a lazy init so non-Windows platforms don't
|
||||
# pay the import.
|
||||
_WIN32_INITIALIZED = False
|
||||
_win32_user32 = None
|
||||
_win32_kernel32 = None
|
||||
_win32_psapi = None
|
||||
|
||||
|
||||
def _init_win32_apis() -> None:
|
||||
"""Declare ctypes argtypes/restype on every Win32 call we make.
|
||||
|
||||
CRITICAL: ctypes defaults to `c_int` (32-bit) for HANDLE/HWND/HMONITOR
|
||||
which silently truncates 64-bit pointer values on x64 — that corrupts the
|
||||
handle so `CloseHandle()` can either fail or close the wrong kernel
|
||||
object, and pointer-equality comparisons (monitor index lookup) miss.
|
||||
"""
|
||||
global _WIN32_INITIALIZED, _win32_user32, _win32_kernel32, _win32_psapi
|
||||
if _WIN32_INITIALIZED:
|
||||
return
|
||||
|
||||
def _probe_windows() -> ForegroundInfo:
|
||||
"""Probe foreground window state on Windows via Win32 API."""
|
||||
import ctypes
|
||||
import ctypes.wintypes as wt
|
||||
|
||||
@@ -96,11 +116,6 @@ def _probe_windows() -> ForegroundInfo:
|
||||
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
|
||||
psapi = ctypes.WinDLL("psapi", use_last_error=True)
|
||||
|
||||
# CRITICAL: declare argtypes/restype on every Win32 call that returns a
|
||||
# HANDLE/HWND/HMONITOR. ctypes defaults to `c_int` (32-bit) which
|
||||
# silently truncates 64-bit pointer values on x64 — that corrupts the
|
||||
# handle so `CloseHandle()` can either fail or close the wrong kernel
|
||||
# object, and pointer-equality comparisons (monitor index lookup) miss.
|
||||
user32.GetForegroundWindow.restype = wt.HWND
|
||||
user32.GetWindowThreadProcessId.argtypes = [wt.HWND, ctypes.POINTER(wt.DWORD)]
|
||||
user32.GetWindowThreadProcessId.restype = wt.DWORD
|
||||
@@ -137,6 +152,20 @@ def _probe_windows() -> ForegroundInfo:
|
||||
psapi.GetModuleFileNameExW.argtypes = [wt.HANDLE, wt.HMODULE, wt.LPWSTR, wt.DWORD]
|
||||
psapi.GetModuleFileNameExW.restype = wt.DWORD
|
||||
|
||||
_win32_user32, _win32_kernel32, _win32_psapi = user32, kernel32, psapi
|
||||
_WIN32_INITIALIZED = True
|
||||
|
||||
|
||||
def _probe_windows() -> ForegroundInfo:
|
||||
"""Probe foreground window state on Windows via Win32 API."""
|
||||
import ctypes
|
||||
import ctypes.wintypes as wt
|
||||
|
||||
_init_win32_apis()
|
||||
user32 = _win32_user32
|
||||
kernel32 = _win32_kernel32
|
||||
psapi = _win32_psapi
|
||||
|
||||
hwnd = user32.GetForegroundWindow()
|
||||
if not hwnd:
|
||||
return ForegroundInfo(available=True, error="no foreground window")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from typing import Optional
|
||||
@@ -15,6 +16,11 @@ _DEFAULT_BASE_URL = "https://git.dolgolyov-family.by"
|
||||
_DEFAULT_OWNER = "alexei.dolgolyov"
|
||||
_DEFAULT_REPO = "media-player-server"
|
||||
|
||||
# Restrictive tag whitelist — prevents a hostile Gitea response (or MITM) from
|
||||
# injecting `..`, slashes, or URL-altering characters into the release URL we
|
||||
# broadcast to clients. SemVer + pre-release suffix only.
|
||||
_TAG_RE = re.compile(r"^v?\d+\.\d+\.\d+(?:[\w.\-+]{0,32})?$")
|
||||
|
||||
|
||||
class GiteaReleaseProvider(ReleaseProvider):
|
||||
"""Fetches the latest release from a Gitea repository."""
|
||||
@@ -53,6 +59,9 @@ class GiteaReleaseProvider(ReleaseProvider):
|
||||
continue
|
||||
|
||||
tag = release.get("tag_name", "")
|
||||
if not isinstance(tag, str) or not _TAG_RE.match(tag):
|
||||
logger.warning("Rejecting malformed release tag from upstream: %r", tag)
|
||||
continue
|
||||
version = tag.lstrip("v")
|
||||
if not version:
|
||||
continue
|
||||
|
||||
@@ -264,8 +264,12 @@ class MacOSMediaController(MediaController):
|
||||
|
||||
async def set_volume(self, volume: int) -> bool:
|
||||
"""Set system volume."""
|
||||
result = self._run_osascript(f"set volume output volume {volume}")
|
||||
return result is not None or True # osascript returns empty on success
|
||||
# osascript returns empty string on success and None on failure (the
|
||||
# _run_osascript helper catches subprocess errors). The previous
|
||||
# `result is not None or True` always returned True regardless of
|
||||
# outcome — surface real failures so the route can return 503.
|
||||
result = self._run_osascript(f"set volume output volume {int(volume)}")
|
||||
return result is not None
|
||||
|
||||
async def toggle_mute(self) -> bool:
|
||||
"""Toggle mute state."""
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
"""In-process token-bucket rate limiter.
|
||||
|
||||
Light enough for a single-process app: one dict keyed by ``(bucket, peer)``
|
||||
guarded by a thread lock. No extra dependency, no Redis. Good enough for
|
||||
defeating credential-stuffing and runaway clients on a LAN; not a substitute
|
||||
for an upstream WAF in a public deployment.
|
||||
|
||||
Buckets:
|
||||
auth — failed-auth attempts, 5/min/peer (used in auth middleware)
|
||||
execute — script + callback execute calls, 10/min/peer (LAN-friendly)
|
||||
default — generic POST/DELETE writes, 60/min/peer
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BucketConfig:
|
||||
capacity: float # max tokens (= burst size)
|
||||
refill_per_sec: float # tokens added per second
|
||||
|
||||
|
||||
# Defaults — tuned for "trusted LAN" use; operator can override via Settings.
|
||||
BUCKETS: dict[str, BucketConfig] = {
|
||||
"auth": BucketConfig(capacity=5, refill_per_sec=5 / 60), # 5/min
|
||||
"execute": BucketConfig(capacity=10, refill_per_sec=10 / 60), # 10/min
|
||||
"default": BucketConfig(capacity=60, refill_per_sec=60 / 60), # 60/min
|
||||
}
|
||||
|
||||
|
||||
_state: dict[tuple[str, str], tuple[float, float]] = {}
|
||||
_lock = threading.Lock()
|
||||
_LAST_CLEANUP = 0.0
|
||||
|
||||
|
||||
def _evict_stale_locked(now: float) -> None:
|
||||
"""Drop entries whose buckets are full (= idle for capacity / refill seconds)."""
|
||||
global _LAST_CLEANUP
|
||||
if now - _LAST_CLEANUP < 60:
|
||||
return
|
||||
_LAST_CLEANUP = now
|
||||
stale = []
|
||||
for key, (tokens, last) in _state.items():
|
||||
bucket = BUCKETS.get(key[0])
|
||||
if bucket is None:
|
||||
continue
|
||||
if tokens >= bucket.capacity and (now - last) > 3600:
|
||||
stale.append(key)
|
||||
for key in stale:
|
||||
_state.pop(key, None)
|
||||
|
||||
|
||||
def check(bucket: str, peer: str) -> tuple[bool, Optional[float]]:
|
||||
"""Try to consume one token from ``(bucket, peer)``.
|
||||
|
||||
Returns:
|
||||
(allowed, retry_after_seconds). When allowed=True retry_after is None.
|
||||
When allowed=False, retry_after is the seconds to wait for one more token.
|
||||
"""
|
||||
cfg = BUCKETS.get(bucket) or BUCKETS["default"]
|
||||
now = time.monotonic()
|
||||
with _lock:
|
||||
_evict_stale_locked(now)
|
||||
tokens, last = _state.get((bucket, peer), (cfg.capacity, now))
|
||||
elapsed = max(0.0, now - last)
|
||||
tokens = min(cfg.capacity, tokens + elapsed * cfg.refill_per_sec)
|
||||
if tokens >= 1:
|
||||
tokens -= 1
|
||||
_state[(bucket, peer)] = (tokens, now)
|
||||
return True, None
|
||||
deficit = 1 - tokens
|
||||
retry = deficit / cfg.refill_per_sec if cfg.refill_per_sec > 0 else 60
|
||||
_state[(bucket, peer)] = (tokens, now)
|
||||
return False, retry
|
||||
|
||||
|
||||
def get_peer(request) -> str:
|
||||
"""Best-effort peer identifier from a Starlette request.
|
||||
|
||||
Honors X-Forwarded-For (only when settings.proxy_headers is True, which is
|
||||
already enforced by uvicorn's middleware) so a reverse-proxied install
|
||||
still rate-limits per real client.
|
||||
"""
|
||||
client = getattr(request, "client", None)
|
||||
if client and client.host:
|
||||
return client.host
|
||||
return "unknown"
|
||||
@@ -26,12 +26,23 @@ class ThumbnailService:
|
||||
def get_cache_dir() -> Path:
|
||||
"""Get the thumbnail cache directory path.
|
||||
|
||||
Returns:
|
||||
Path to the cache directory (project-local).
|
||||
Returns user-writable platform cache dir so installs under
|
||||
``%PROGRAMFILES%`` / ``/opt`` work without elevated permissions.
|
||||
Mirrors the platform branching of ``config.get_config_dir``.
|
||||
"""
|
||||
# Store cache in project directory: media-server/.cache/thumbnails/
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
cache_dir = project_root / ".cache" / "thumbnails"
|
||||
import os
|
||||
|
||||
if os.name == "nt":
|
||||
# %LOCALAPPDATA% so the cache survives roaming-profile sync.
|
||||
base = Path(os.environ.get("LOCALAPPDATA")
|
||||
or os.environ.get("APPDATA")
|
||||
or Path.home() / "AppData" / "Local")
|
||||
cache_dir = base / "media-server" / "cache" / "thumbnails"
|
||||
else:
|
||||
# XDG_CACHE_HOME convention; falls back to ~/.cache.
|
||||
xdg = os.environ.get("XDG_CACHE_HOME")
|
||||
base = Path(xdg) if xdg else Path.home() / ".cache"
|
||||
cache_dir = base / "media-server" / "thumbnails"
|
||||
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cache_dir
|
||||
|
||||
@@ -33,9 +33,15 @@ class ConnectionManager:
|
||||
self._audio_task: asyncio.Task | None = None
|
||||
self._audio_analyzer = None
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
async def connect(self, websocket: WebSocket, already_accepted: bool = False) -> None:
|
||||
"""Accept a new WebSocket connection.
|
||||
|
||||
``already_accepted=True`` is for callers that needed to call
|
||||
``websocket.accept(subprotocol=...)`` themselves (token-via-subprotocol
|
||||
auth path).
|
||||
"""
|
||||
if not already_accepted:
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
self._active_connections.add(websocket)
|
||||
logger.info(
|
||||
|
||||
@@ -31,8 +31,15 @@ def _thread_loop() -> asyncio.AbstractEventLoop:
|
||||
_thread_local.loop = loop
|
||||
return loop
|
||||
|
||||
# Global storage for current album art (as bytes)
|
||||
# Global storage for current album art (as bytes). Guarded by _art_lock so the
|
||||
# WinRT polling thread and the FastAPI handler thread don't race on swap.
|
||||
_current_album_art_bytes: bytes | None = None
|
||||
_art_lock = threading.Lock()
|
||||
|
||||
# Identity of the track whose art is currently in _current_album_art_bytes.
|
||||
# Used to gate the expensive WinRT thumbnail.open_read_async() so the bytes
|
||||
# aren't re-decoded on every 500ms status poll.
|
||||
_current_album_art_key: tuple | None = None
|
||||
|
||||
# Lock protecting _position_cache and _track_skip_pending from concurrent access
|
||||
_position_lock = threading.Lock()
|
||||
@@ -56,8 +63,9 @@ _track_skip_pending = {
|
||||
|
||||
|
||||
def get_current_album_art() -> bytes | None:
|
||||
"""Get the current album art bytes."""
|
||||
return _current_album_art_bytes
|
||||
"""Get the current album art bytes (thread-safe snapshot)."""
|
||||
with _art_lock:
|
||||
return _current_album_art_bytes
|
||||
|
||||
# Windows-specific imports
|
||||
try:
|
||||
@@ -379,28 +387,48 @@ def _sync_get_media_status() -> dict[str, Any]:
|
||||
except Exception as e:
|
||||
logger.debug(f"Timeline parse error: {e}")
|
||||
|
||||
# Try to get album art (requires media_props)
|
||||
# Try to get album art (requires media_props). Gated by track key so
|
||||
# the WinRT IPC + bytes copy only runs when the track actually
|
||||
# changes; otherwise we just preserve the existing cached bytes.
|
||||
if media_props:
|
||||
try:
|
||||
thumbnail = media_props.thumbnail
|
||||
if thumbnail:
|
||||
stream = loop.run_until_complete(thumbnail.open_read_async())
|
||||
if stream:
|
||||
size = stream.size
|
||||
if size > 0 and size < 10 * 1024 * 1024: # Max 10MB
|
||||
from winsdk.windows.storage.streams import DataReader
|
||||
reader = DataReader(stream)
|
||||
loop.run_until_complete(reader.load_async(size))
|
||||
buffer = bytearray(size)
|
||||
reader.read_bytes(buffer)
|
||||
reader.close()
|
||||
stream.close()
|
||||
track_key = (
|
||||
getattr(media_props, "title", "") or "",
|
||||
getattr(media_props, "artist", "") or "",
|
||||
getattr(media_props, "album_title", "") or "",
|
||||
)
|
||||
global _current_album_art_bytes, _current_album_art_key
|
||||
if track_key == _current_album_art_key and _current_album_art_bytes:
|
||||
# Same track — reuse cached art bytes without touching WinRT.
|
||||
result["album_art_url"] = "/api/media/artwork"
|
||||
else:
|
||||
try:
|
||||
thumbnail = media_props.thumbnail
|
||||
if thumbnail:
|
||||
stream = loop.run_until_complete(thumbnail.open_read_async())
|
||||
if stream:
|
||||
size = stream.size
|
||||
if size > 0 and size < 10 * 1024 * 1024: # Max 10MB
|
||||
from winsdk.windows.storage.streams import DataReader
|
||||
reader = DataReader(stream)
|
||||
loop.run_until_complete(reader.load_async(size))
|
||||
buffer = bytearray(size)
|
||||
reader.read_bytes(buffer)
|
||||
reader.close()
|
||||
stream.close()
|
||||
|
||||
global _current_album_art_bytes
|
||||
_current_album_art_bytes = bytes(buffer)
|
||||
result["album_art_url"] = "/api/media/artwork"
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get album art: {e}")
|
||||
with _art_lock:
|
||||
_current_album_art_bytes = bytes(buffer)
|
||||
_current_album_art_key = track_key
|
||||
result["album_art_url"] = "/api/media/artwork"
|
||||
else:
|
||||
# No thumbnail on this track — drop stale bytes so
|
||||
# the ETag flips and clients don't keep showing the
|
||||
# previous album's cover.
|
||||
with _art_lock:
|
||||
_current_album_art_bytes = None
|
||||
_current_album_art_key = track_key
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get album art: {e}")
|
||||
|
||||
result["source"] = session.source_app_user_model_id
|
||||
|
||||
|
||||
@@ -26,16 +26,16 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="mini-controls">
|
||||
<button class="mini-control-btn mini-nav-btn" data-onclick="previousTrack()" data-i18n-title="player.previous" title="Previous">
|
||||
<svg viewBox="0 0 24 24"><path d="M6 6h2v12H6zm3.5 6l8.5 6V6z"/></svg>
|
||||
<button class="mini-control-btn mini-nav-btn" data-onclick="previousTrack()" data-i18n-title="player.previous" data-i18n-aria-label="player.previous" title="Previous" aria-label="Previous">
|
||||
<svg viewBox="0 0 24 24" aria-hidden="true" focusable="false"><path d="M6 6h2v12H6zm3.5 6l8.5 6V6z"/></svg>
|
||||
</button>
|
||||
<button class="mini-control-btn" data-onclick="togglePlayPause()" id="mini-btn-play-pause" title="Play/Pause">
|
||||
<svg viewBox="0 0 24 24" id="mini-play-pause-icon">
|
||||
<button class="mini-control-btn" data-onclick="togglePlayPause()" id="mini-btn-play-pause" title="Play/Pause" aria-label="Play/Pause">
|
||||
<svg viewBox="0 0 24 24" id="mini-play-pause-icon" aria-hidden="true" focusable="false">
|
||||
<path d="M8 5v14l11-7z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<button class="mini-control-btn mini-nav-btn" data-onclick="nextTrack()" data-i18n-title="player.next" title="Next">
|
||||
<svg viewBox="0 0 24 24"><path d="M6 18l8.5-6L6 6v12zM16 6v12h2V6h-2z"/></svg>
|
||||
<button class="mini-control-btn mini-nav-btn" data-onclick="nextTrack()" data-i18n-title="player.next" data-i18n-aria-label="player.next" title="Next" aria-label="Next">
|
||||
<svg viewBox="0 0 24 24" aria-hidden="true" focusable="false"><path d="M6 18l8.5-6L6 6v12zM16 6v12h2V6h-2z"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<div class="mini-progress-container">
|
||||
@@ -48,8 +48,8 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="mini-volume-container">
|
||||
<button class="mini-control-btn" data-onclick="toggleMute()" id="mini-btn-mute" title="Mute">
|
||||
<svg viewBox="0 0 24 24" id="mini-mute-icon">
|
||||
<button class="mini-control-btn" data-onclick="toggleMute()" id="mini-btn-mute" title="Mute" aria-label="Mute" aria-pressed="false">
|
||||
<svg viewBox="0 0 24 24" id="mini-mute-icon" aria-hidden="true" focusable="false">
|
||||
<path d="M3 9v6h4l5 5V4L7 9H3zm13.5 3c0-1.77-1.02-3.29-2.5-4.03v8.05c1.48-.73 2.5-2.25 2.5-4.02z"/>
|
||||
</svg>
|
||||
</button>
|
||||
@@ -88,7 +88,7 @@
|
||||
</div>
|
||||
<div class="header-toolbar">
|
||||
<div id="headerLinks" class="header-links"></div>
|
||||
<a class="header-btn" href="/docs" target="_blank" title="API Documentation" aria-label="API Documentation">
|
||||
<a class="header-btn" href="/docs" target="_blank" rel="noopener noreferrer" referrerpolicy="no-referrer" title="API Documentation" aria-label="API Documentation">
|
||||
<svg viewBox="0 0 24 24"><path fill="currentColor" d="M14 2H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h12c1.1 0 2-.9 2-2V8l-6-6zm-1 2l5 5h-5V4zM6 20V4h5v7h7v9H6zm2-4h8v2H8v-2zm0-3h8v2H8v-2z"/></svg>
|
||||
</a>
|
||||
<button class="header-btn" data-onclick="showAboutDialog()" data-i18n-title="about.button_title" title="About" aria-label="About">
|
||||
|
||||
@@ -536,6 +536,8 @@ export async function loadHeaderLinks() {
|
||||
a.href = link.url;
|
||||
a.target = '_blank';
|
||||
a.rel = 'noopener noreferrer';
|
||||
// Prevent leaking the WebUI URL (with ?token=) via Referer.
|
||||
a.referrerPolicy = 'no-referrer';
|
||||
a.className = 'header-link';
|
||||
a.title = link.label || link.url;
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
lastStatus, setLastStatus, currentPlayState, setCurrentPlayState,
|
||||
POSITION_INTERPOLATION_MS, seek, notifyRemoteVolume,
|
||||
getAuthHeaders, hasCredentials,
|
||||
togglePlayPause, nextTrack, previousTrack,
|
||||
} from './core.js';
|
||||
import { updateBackgroundColors } from './background.js';
|
||||
import { loadDisplayMonitors } from './links.js';
|
||||
@@ -381,11 +382,85 @@ function buildVisualizerGradient() {
|
||||
|
||||
function startVisualizerRender() {
|
||||
if (visualizerAnimFrame) return;
|
||||
// Don't even queue a frame while the tab is hidden — rAF still fires on
|
||||
// hidden tabs (throttled but not paused) and would burn CPU + battery
|
||||
// smoothing into bars no one can see. We resume on `visibilitychange`.
|
||||
if (typeof document !== 'undefined' && document.hidden) return;
|
||||
// Cache editorial spectrum bar refs once per start.
|
||||
cacheEditorialSpectrumBars();
|
||||
renderVisualizerFrame();
|
||||
}
|
||||
|
||||
// ─── OS Media Session integration ─────────────────────────────
|
||||
// Hooks the page into the system's media session so headset / lockscreen /
|
||||
// Bluetooth play-pause-skip buttons drive the active track. Action handlers
|
||||
// are set once and never re-registered; only the metadata + playback state
|
||||
// flip when a track changes.
|
||||
let _mediaSessionInitialised = false;
|
||||
let _lastMediaSessionKey = '';
|
||||
function syncMediaSession(status) {
|
||||
if (typeof navigator === 'undefined' || !('mediaSession' in navigator)) return;
|
||||
const session = navigator.mediaSession;
|
||||
|
||||
if (!_mediaSessionInitialised) {
|
||||
const setHandler = (name, fn) => {
|
||||
try { session.setActionHandler(name, fn); } catch { /* unsupported action */ }
|
||||
};
|
||||
setHandler('play', () => togglePlayPause());
|
||||
setHandler('pause', () => togglePlayPause());
|
||||
setHandler('nexttrack', () => nextTrack());
|
||||
setHandler('previoustrack', () => previousTrack());
|
||||
setHandler('seekto', (ev) => { if (ev && typeof ev.seekTime === 'number') seek(ev.seekTime); });
|
||||
_mediaSessionInitialised = true;
|
||||
}
|
||||
|
||||
// Track-identity key — re-build metadata only when title/artist/album change.
|
||||
const artworkSrc = status && status.album_art_url ? '/api/media/artwork' : '';
|
||||
const key = `${status.title || ''}|${status.artist || ''}|${status.album || ''}|${artworkSrc}`;
|
||||
if (key !== _lastMediaSessionKey) {
|
||||
_lastMediaSessionKey = key;
|
||||
try {
|
||||
session.metadata = new MediaMetadata({
|
||||
title: status.title || '',
|
||||
artist: status.artist || '',
|
||||
album: status.album || '',
|
||||
artwork: artworkSrc ? [{ src: artworkSrc, sizes: '512x512', type: 'image/*' }] : [],
|
||||
});
|
||||
} catch { /* MediaMetadata unsupported on very old browsers */ }
|
||||
}
|
||||
|
||||
session.playbackState =
|
||||
status.state === 'playing' ? 'playing'
|
||||
: status.state === 'paused' ? 'paused'
|
||||
: 'none';
|
||||
|
||||
if (typeof session.setPositionState === 'function'
|
||||
&& status.duration && status.duration > 0
|
||||
&& typeof status.position === 'number') {
|
||||
try {
|
||||
session.setPositionState({
|
||||
duration: status.duration,
|
||||
position: Math.min(status.position, status.duration),
|
||||
playbackRate: 1.0,
|
||||
});
|
||||
} catch { /* invalid range — ignore */ }
|
||||
}
|
||||
}
|
||||
|
||||
// Pause / resume the visualizer with tab visibility. Idempotent: called once
|
||||
// at module init below, no-op if no listener support.
|
||||
if (typeof document !== 'undefined' && document.addEventListener) {
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.hidden) {
|
||||
stopVisualizerRender();
|
||||
} else if (frequencyData) {
|
||||
// Only restart if a payload is live (otherwise startVisualizerRender
|
||||
// would queue a no-op rAF chain forever waiting for one).
|
||||
startVisualizerRender();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
export function stopVisualizerRender() {
|
||||
if (visualizerAnimFrame) {
|
||||
cancelAnimationFrame(visualizerAnimFrame);
|
||||
@@ -903,6 +978,11 @@ export function updateUI(status) {
|
||||
|
||||
updateMuteIcon(status.muted);
|
||||
|
||||
// Wire the OS Media Session so headset buttons, lockscreen controls, and
|
||||
// Bluetooth remotes drive the active media (not the WebUI tab). Cheap and
|
||||
// idempotent — re-running setActionHandler with the same fn is a no-op.
|
||||
syncMediaSession(status);
|
||||
|
||||
const src = resolveMediaSource(status.source);
|
||||
dom.source.textContent = src ? src.name : t('player.unknown_source');
|
||||
dom.sourceIcon.innerHTML = src?.icon || '';
|
||||
|
||||
@@ -80,6 +80,9 @@ export async function displayQuickAccess() {
|
||||
card.href = link.url;
|
||||
card.target = '_blank';
|
||||
card.rel = 'noopener noreferrer';
|
||||
// Prevent the WebUI's URL (which may carry ?token=...) from
|
||||
// leaking to third-party sites via Referer.
|
||||
card.referrerPolicy = 'no-referrer';
|
||||
|
||||
if (link.icon) {
|
||||
const iconEl = document.createElement('div');
|
||||
|
||||
@@ -88,9 +88,22 @@ export function connectWebSocket(token) {
|
||||
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const wsBase = `${protocol}//${window.location.host}/api/media/ws`;
|
||||
const wsUrl = token ? `${wsBase}?token=${encodeURIComponent(token)}` : wsBase;
|
||||
|
||||
const newWs = new WebSocket(wsUrl);
|
||||
// Prefer Sec-WebSocket-Protocol-based auth so the token never appears in
|
||||
// the URL (which would otherwise land in browser history, server access
|
||||
// logs, and Referer headers). Keep the ?token=... fallback for clients
|
||||
// that pre-date this change and don't speak the subprotocol.
|
||||
let newWs;
|
||||
if (token) {
|
||||
try {
|
||||
newWs = new WebSocket(wsBase, [`media-server.token.${token}`]);
|
||||
} catch (e) {
|
||||
console.warn('Subprotocol WS handshake failed, falling back to ?token=', e);
|
||||
newWs = new WebSocket(`${wsBase}?token=${encodeURIComponent(token)}`);
|
||||
}
|
||||
} else {
|
||||
newWs = new WebSocket(wsBase);
|
||||
}
|
||||
activeSocket = newWs;
|
||||
setWs(newWs);
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
{
|
||||
"id": "/",
|
||||
"name": "Media Server",
|
||||
"short_name": "Media",
|
||||
"description": "Remote media player control and file browser",
|
||||
"start_url": "/",
|
||||
"scope": "/",
|
||||
"display": "standalone",
|
||||
"orientation": "any",
|
||||
"background_color": "#121212",
|
||||
"theme_color": "#121212",
|
||||
"background_color": "#0E0D0B",
|
||||
"theme_color": "#0E0D0B",
|
||||
"icons": [
|
||||
{
|
||||
"src": "/static/icons/icon.svg",
|
||||
|
||||
@@ -101,6 +101,9 @@ class TrayManager:
|
||||
|
||||
self._port = port
|
||||
self._on_exit = on_exit
|
||||
# Initialize so the property and any cross-thread reader cannot ever
|
||||
# observe an uninitialized attribute. Set before _on_exit() fires.
|
||||
self._restart_requested = False
|
||||
|
||||
menu = pystray.Menu(
|
||||
pystray.MenuItem("Show UI", self._show_ui, default=True),
|
||||
@@ -123,13 +126,16 @@ class TrayManager:
|
||||
if not _confirm("Media Server", "Restart the server?"):
|
||||
return
|
||||
logger.info("Restart requested from tray")
|
||||
# Set the flag BEFORE signalling exit so the main thread observes it
|
||||
# when it wakes from server_thread.join() — order matters across the
|
||||
# tray/uvicorn handoff.
|
||||
self._restart_requested = True
|
||||
self._on_exit()
|
||||
self._icon.stop()
|
||||
|
||||
@property
|
||||
def restart_requested(self) -> bool:
|
||||
return getattr(self, "_restart_requested", False)
|
||||
return self._restart_requested
|
||||
|
||||
def _shutdown(self, icon: "pystray.Icon", item: "pystray.MenuItem") -> None:
|
||||
if not _confirm("Media Server", "Shut down the server?"):
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Tests for token scope hierarchy + back-compat with legacy bare-string tokens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from media_server.config import Settings, TokenSpec
|
||||
|
||||
|
||||
def test_bare_string_token_promotes_to_admin_scope():
|
||||
"""Legacy `label: <token>` form must still work and grant admin."""
|
||||
s = Settings(api_tokens={"legacy": "deadbeef-deadbeef-deadbeef-deadbeef"})
|
||||
spec = s.api_tokens["legacy"]
|
||||
assert isinstance(spec, TokenSpec)
|
||||
assert spec.token == "deadbeef-deadbeef-deadbeef-deadbeef"
|
||||
assert spec.scopes == ["admin"]
|
||||
assert spec.grants("admin")
|
||||
assert spec.grants("control")
|
||||
assert spec.grants("read")
|
||||
|
||||
|
||||
def test_dict_token_with_explicit_scopes():
|
||||
s = Settings(api_tokens={
|
||||
"ha": {"token": "aaaaaaaaaaaaaaaa", "scopes": ["read", "control"]},
|
||||
})
|
||||
spec = s.api_tokens["ha"]
|
||||
assert spec.grants("control")
|
||||
assert spec.grants("read")
|
||||
assert not spec.grants("admin")
|
||||
|
||||
|
||||
def test_read_only_scope_grants_only_read():
|
||||
spec = TokenSpec(token="xxxxxxxxxxxxxxxx", scopes=["read"])
|
||||
assert spec.grants("read")
|
||||
assert not spec.grants("control")
|
||||
assert not spec.grants("admin")
|
||||
|
||||
|
||||
def test_admin_scope_implies_control_and_read():
|
||||
spec = TokenSpec(token="xxxxxxxxxxxxxxxx", scopes=["admin"])
|
||||
assert spec.grants("read")
|
||||
assert spec.grants("control")
|
||||
assert spec.grants("admin")
|
||||
|
||||
|
||||
def test_unknown_scope_rejected():
|
||||
with pytest.raises(ValueError, match="unknown scopes"):
|
||||
TokenSpec(token="xxxxxxxxxxxxxxxx", scopes=["root"])
|
||||
|
||||
|
||||
def test_empty_scopes_rejected():
|
||||
with pytest.raises(ValueError, match="at least one"):
|
||||
TokenSpec(token="xxxxxxxxxxxxxxxx", scopes=[])
|
||||
|
||||
|
||||
def test_short_token_rejected():
|
||||
with pytest.raises(ValueError):
|
||||
TokenSpec(token="short", scopes=["read"])
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Path traversal defence for BrowserService.validate_path.
|
||||
|
||||
The browser endpoint is the single most security-critical filesystem entry
|
||||
point in the app: it serves file contents and folder listings to the WebUI.
|
||||
A bypass here = arbitrary read of any file the server process can see.
|
||||
|
||||
The current implementation signals rejection by *raising* (ValueError for
|
||||
traversal/NUL/unknown folder, FileNotFoundError for non-existent absolute
|
||||
paths). Either rejection mode is acceptable — these tests assert that the
|
||||
adversarial input never returns a path *inside* the configured base.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from media_server.services.browser_service import BrowserService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_media_folder():
|
||||
"""A real temp dir registered as a media folder for the test duration."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
base = Path(tmp).resolve()
|
||||
(base / "ok.mp3").write_bytes(b"id3")
|
||||
(base / "sub").mkdir()
|
||||
(base / "sub" / "nested.mp3").write_bytes(b"id3")
|
||||
|
||||
from media_server.config import MediaFolderConfig
|
||||
folders = {"test": MediaFolderConfig(path=str(base), label="Test", enabled=True)}
|
||||
with patch("media_server.services.browser_service.settings.media_folders", folders):
|
||||
yield base
|
||||
|
||||
|
||||
def _is_rejected(folder_id: str, rel: str) -> bool:
|
||||
"""Helper: True iff validate_path either raises or returns None."""
|
||||
try:
|
||||
result = BrowserService.validate_path(folder_id, rel)
|
||||
except (ValueError, FileNotFoundError, OSError):
|
||||
return True
|
||||
return result is None
|
||||
|
||||
|
||||
def test_validate_path_accepts_a_real_file(tmp_media_folder: Path):
|
||||
p = BrowserService.validate_path("test", "ok.mp3")
|
||||
assert p is not None
|
||||
assert p.is_file()
|
||||
# Defence-in-depth: resolved path must live inside the base.
|
||||
assert tmp_media_folder in p.resolve().parents or p.resolve().parent == tmp_media_folder
|
||||
|
||||
|
||||
def test_validate_path_accepts_nested(tmp_media_folder: Path):
|
||||
p = BrowserService.validate_path("test", "sub/nested.mp3")
|
||||
assert p is not None
|
||||
|
||||
|
||||
def test_unknown_folder_rejected(tmp_media_folder: Path):
|
||||
assert _is_rejected("ghost", "ok.mp3")
|
||||
|
||||
|
||||
def test_dotdot_traversal_rejected(tmp_media_folder: Path):
|
||||
assert _is_rejected("test", "../etc/passwd")
|
||||
assert _is_rejected("test", "..\\..\\Windows\\System32")
|
||||
assert _is_rejected("test", "sub/../../etc/passwd")
|
||||
|
||||
|
||||
def test_absolute_path_rejected(tmp_media_folder: Path):
|
||||
assert _is_rejected("test", "/etc/passwd")
|
||||
assert _is_rejected("test", "C:\\Windows\\System32")
|
||||
assert _is_rejected("test", "C:/Windows")
|
||||
|
||||
|
||||
def test_unc_path_rejected(tmp_media_folder: Path):
|
||||
assert _is_rejected("test", "\\\\server\\share")
|
||||
assert _is_rejected("test", "//server/share")
|
||||
|
||||
|
||||
def test_null_byte_rejected(tmp_media_folder: Path):
|
||||
assert _is_rejected("test", "ok.mp3\x00.png")
|
||||
assert _is_rejected("test", "sub\x00/nested.mp3")
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Atomic config writes + POSIX permission hardening."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from media_server.config import _restrict_config_perms, _write_yaml_atomic
|
||||
|
||||
|
||||
def test_atomic_write_round_trip():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "config.yaml"
|
||||
_write_yaml_atomic(path, {"port": 8765, "host": "127.0.0.1"})
|
||||
assert path.exists()
|
||||
# Tmp file from the rename should be gone.
|
||||
assert not path.with_suffix(path.suffix + ".tmp").exists()
|
||||
# Contents are valid YAML and round-trip.
|
||||
import yaml
|
||||
data = yaml.safe_load(path.read_text())
|
||||
assert data == {"port": 8765, "host": "127.0.0.1"}
|
||||
|
||||
|
||||
def test_atomic_write_replaces_existing():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "config.yaml"
|
||||
path.write_text("old: 1\n")
|
||||
_write_yaml_atomic(path, {"new": 2})
|
||||
import yaml
|
||||
assert yaml.safe_load(path.read_text()) == {"new": 2}
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="POSIX-only permission check")
|
||||
def test_restrict_config_perms_posix():
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = Path(tmp) / "config.yaml"
|
||||
path.write_text("token: secret\n")
|
||||
_restrict_config_perms(path)
|
||||
mode = stat.S_IMODE(os.stat(path).st_mode)
|
||||
# Owner read+write only.
|
||||
assert mode == 0o600, f"got {oct(mode)}"
|
||||
@@ -8,10 +8,6 @@ Windows/Linux/macOS probes themselves are exercised through manual runs.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from media_server.services import foreground_service as fg
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Tag-name validation in the Gitea release provider.
|
||||
|
||||
Whitelist regex protects the URL we broadcast to clients from any path
|
||||
traversal or character-set abuse in a hostile (or MITM'd) upstream response.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from media_server.services.gitea_release_provider import _TAG_RE
|
||||
|
||||
|
||||
def test_accepts_plain_semver():
|
||||
assert _TAG_RE.match("1.0.0")
|
||||
assert _TAG_RE.match("v1.0.0")
|
||||
assert _TAG_RE.match("0.3.7")
|
||||
|
||||
|
||||
def test_accepts_pre_release_suffix():
|
||||
assert _TAG_RE.match("v1.0.0-alpha.1")
|
||||
assert _TAG_RE.match("v2.3.4-rc.10")
|
||||
assert _TAG_RE.match("v0.2.7+build.42")
|
||||
|
||||
|
||||
def test_rejects_path_traversal():
|
||||
assert not _TAG_RE.match("../etc/passwd")
|
||||
assert not _TAG_RE.match("v1.0.0/../../evil")
|
||||
assert not _TAG_RE.match("v1.0.0/secret")
|
||||
|
||||
|
||||
def test_rejects_url_injection():
|
||||
assert not _TAG_RE.match("v1.0.0?evil=1")
|
||||
assert not _TAG_RE.match("v1.0.0#frag")
|
||||
assert not _TAG_RE.match("v1.0.0 OR 1=1")
|
||||
assert not _TAG_RE.match("https://evil.example/")
|
||||
|
||||
|
||||
def test_rejects_empty_and_garbage():
|
||||
assert not _TAG_RE.match("")
|
||||
assert not _TAG_RE.match("not-a-version")
|
||||
assert not _TAG_RE.match("v")
|
||||
|
||||
|
||||
def test_rejects_excessively_long_suffix():
|
||||
long_suffix = "x" * 40
|
||||
assert not _TAG_RE.match(f"v1.0.0-{long_suffix}")
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Token-bucket rate limiter behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from media_server.services import rate_limit
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_state():
|
||||
rate_limit._state.clear()
|
||||
rate_limit._LAST_CLEANUP = 0.0
|
||||
yield
|
||||
rate_limit._state.clear()
|
||||
|
||||
|
||||
def test_allows_up_to_capacity_then_blocks(monkeypatch):
|
||||
"""Default execute bucket = 10/min."""
|
||||
peer = "10.0.0.1"
|
||||
for i in range(10):
|
||||
ok, retry = rate_limit.check("execute", peer)
|
||||
assert ok, f"expected allow on attempt {i + 1}, got block (retry={retry})"
|
||||
ok, retry = rate_limit.check("execute", peer)
|
||||
assert ok is False
|
||||
assert retry is not None and retry > 0
|
||||
|
||||
|
||||
def test_different_peers_independent():
|
||||
for _ in range(10):
|
||||
assert rate_limit.check("execute", "10.0.0.1")[0]
|
||||
# Different peer should still be allowed.
|
||||
assert rate_limit.check("execute", "10.0.0.2")[0]
|
||||
|
||||
|
||||
def test_unknown_bucket_uses_default():
|
||||
peer = "10.0.0.3"
|
||||
# default = 60/min — first call always allowed.
|
||||
allowed, _ = rate_limit.check("nonexistent-bucket", peer)
|
||||
assert allowed
|
||||
|
||||
|
||||
def test_auth_bucket_is_strict():
|
||||
"""auth bucket = 5/min."""
|
||||
peer = "10.0.0.4"
|
||||
for _ in range(5):
|
||||
assert rate_limit.check("auth", peer)[0]
|
||||
blocked, retry = rate_limit.check("auth", peer)
|
||||
assert not blocked
|
||||
assert retry is not None
|
||||
|
||||
|
||||
def test_refill_eventually_unblocks(monkeypatch):
|
||||
"""Verify the bucket refills — exhaust then wait one refill period."""
|
||||
peer = "10.0.0.5"
|
||||
# Replace BUCKETS with a fast-refilling one for the test only.
|
||||
monkeypatch.setitem(
|
||||
rate_limit.BUCKETS,
|
||||
"fast-test",
|
||||
rate_limit.BucketConfig(capacity=2, refill_per_sec=10.0),
|
||||
)
|
||||
assert rate_limit.check("fast-test", peer)[0]
|
||||
assert rate_limit.check("fast-test", peer)[0]
|
||||
assert not rate_limit.check("fast-test", peer)[0]
|
||||
time.sleep(0.15) # 0.15 * 10 = 1.5 tokens
|
||||
assert rate_limit.check("fast-test", peer)[0]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""Validation rules for script parameters (type coercion, regex pattern)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from media_server.config import ScriptParameterConfig
|
||||
from media_server.routes.scripts import _validate_params
|
||||
|
||||
|
||||
def _defs(**kwargs) -> dict[str, ScriptParameterConfig]:
|
||||
return {name: ScriptParameterConfig(**spec) for name, spec in kwargs.items()}
|
||||
|
||||
|
||||
def test_unknown_param_rejected():
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
_validate_params({"x": "1"}, _defs())
|
||||
assert ei.value.status_code == 400
|
||||
assert "Unknown" in ei.value.detail
|
||||
|
||||
|
||||
def test_missing_required_rejected():
|
||||
defs = _defs(name={"type": "string", "required": True})
|
||||
with pytest.raises(HTTPException, match="missing"):
|
||||
_validate_params({}, defs)
|
||||
|
||||
|
||||
def test_integer_coercion_and_bounds():
|
||||
defs = _defs(volume={"type": "integer", "min": 0, "max": 100})
|
||||
out = _validate_params({"volume": "42"}, defs)
|
||||
assert out == {"SCRIPT_PARAM_VOLUME": "42"}
|
||||
|
||||
with pytest.raises(HTTPException, match="<="):
|
||||
_validate_params({"volume": 200}, defs)
|
||||
with pytest.raises(HTTPException, match=">="):
|
||||
_validate_params({"volume": -1}, defs)
|
||||
with pytest.raises(HTTPException, match="integer"):
|
||||
_validate_params({"volume": "not-a-number"}, defs)
|
||||
|
||||
|
||||
def test_boolean_coercion():
|
||||
defs = _defs(flag={"type": "boolean"})
|
||||
assert _validate_params({"flag": "true"}, defs) == {"SCRIPT_PARAM_FLAG": "True"}
|
||||
assert _validate_params({"flag": "no"}, defs) == {"SCRIPT_PARAM_FLAG": "False"}
|
||||
with pytest.raises(HTTPException, match="boolean"):
|
||||
_validate_params({"flag": "maybe"}, defs)
|
||||
|
||||
|
||||
def test_select_rejects_non_option():
|
||||
defs = _defs(mode={"type": "select", "options": ["a", "b", "c"]})
|
||||
assert _validate_params({"mode": "a"}, defs) == {"SCRIPT_PARAM_MODE": "a"}
|
||||
with pytest.raises(HTTPException, match="must be one of"):
|
||||
_validate_params({"mode": "z"}, defs)
|
||||
|
||||
|
||||
def test_pattern_enforced_on_string():
|
||||
"""Regex pattern is the defence against shell metachars in shell=true scripts."""
|
||||
defs = _defs(host={"type": "string", "pattern": r"^[a-z0-9.\-]+$"})
|
||||
assert _validate_params({"host": "example.com"}, defs) == {"SCRIPT_PARAM_HOST": "example.com"}
|
||||
with pytest.raises(HTTPException, match="pattern"):
|
||||
_validate_params({"host": "evil & calc.exe"}, defs)
|
||||
with pytest.raises(HTTPException, match="pattern"):
|
||||
_validate_params({"host": "$(rm -rf /)"}, defs)
|
||||
|
||||
|
||||
def test_pattern_can_disallow_empty():
|
||||
defs = _defs(host={"type": "string", "pattern": r"^[a-z]+$"})
|
||||
with pytest.raises(HTTPException, match="pattern"):
|
||||
_validate_params({"host": ""}, defs)
|
||||
|
||||
|
||||
def test_invalid_pattern_in_config_fails_closed():
|
||||
defs = _defs(host={"type": "string", "pattern": r"["}) # unmatched bracket
|
||||
with pytest.raises(HTTPException) as ei:
|
||||
_validate_params({"host": "x"}, defs)
|
||||
assert ei.value.status_code == 500
|
||||
Reference in New Issue
Block a user