From d131ba461c950f619326b7471586f515c3113fd0 Mon Sep 17 00:00:00 2001 From: "alexei.dolgolyov" Date: Fri, 22 May 2026 22:25:54 +0300 Subject: [PATCH] =?UTF-8?q?fix:=20production-readiness=20hardening=20?= =?UTF-8?q?=E2=80=94=20security,=20perf,=20a11y,=20observability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Security - Default scripts_management, callbacks_management, links_management, and media_folders_management to False so a leaked token cannot escalate to RCE through admin CRUD endpoints. - TokenSpec + scope hierarchy (read | control | admin); legacy bare-string api_tokens entries promote to admin for back-compat. Management endpoints now require admin scope. - WebSocket subprotocol auth (Sec-WebSocket-Protocol: media-server.token.{token}) preferred over ?token= query so the token no longer lands in URL/history/Referer; query fallback retained for HA integration back-compat. - Origin allow-list check on the WS endpoint (CSWSH defence). - In-process token-bucket rate limiter: 5/min for failed auths, 10/min for /api/scripts/execute and /api/callbacks/execute. - shell=False subprocess path (shlex.split) + per-parameter regex `pattern` in ScriptParameterConfig to harden shell=true scripts against parameter injection (Windows cmd.exe env-var expansion). - CSP gains form-action, worker-src, manifest-src directives. - Refuse cors_origins=["*"] at startup; strip token=... from uvicorn access logs; validate Gitea release tag against strict SemVer regex. - noopener noreferrer + no-referrer referrerpolicy on every outbound link. - icacls hardening of config.yaml on Windows (current user + SYSTEM + Administrators only); 0600 still enforced on POSIX. - WS volume handler clamps input and never drops the socket on bad messages. Performance - Album-art read in windows_media gated by track key — was decoding the WinRT thumbnail twice per second regardless of track changes. 
- /api/media/artwork returns content-derived ETag + Cache-Control so the browser sends If-None-Match and gets 304s on track repeats. - Foreground-service ctypes argtypes hoisted to one-time module init (was re-declaring ~14 prototypes per probe). - display_service _static_cache keyed by (edid_hash, ...) tuple with eviction of disappeared monitors — fixes stale capabilities on hot-plug swaps where the new topology has the same monitor count. - Visualizer rAF loop paused on document.hidden, resumed on visible. Reliability / bug fixes - Lifespan rewritten as try/yield/finally so a partial-startup failure cannot orphan background tasks or executors. - _run_callback in routes/media.py keeps a strong task ref (GC-safe) and uses the dedicated callback executor instead of the default pool. - macos_media.set_volume() no longer always returns True. - TrayManager._restart_requested initialised in __init__; set before signalling exit so the main thread observes it correctly. - Missing static_dir now logs a WARNING instead of silent UI disable. UX / accessibility / PWA - manifest.json theme_color and background_color match the Studio Reference base (#0E0D0B); added id and scope for PWA installability. - ARIA on mini-player icon buttons; inner SVGs marked aria-hidden. - OS mediaSession API wired so headset / lockscreen / Bluetooth buttons drive play/pause/next/prev/seek and show track metadata + artwork. Observability - X-Request-ID middleware (accept upstream id if it matches a safe regex, otherwise UUID4); request_id_var added to ContextVars and included in every log line alongside the token label. - Audit log (append-only JSONL) for every script + callback execution, including the on_play/on_pause/etc. event callbacks. Background-thread writer; queue capped; flushed in lifespan teardown. Deployment - proxy_headers + forwarded_allow_ips plumbed through Settings → uvicorn.Config for reverse-proxy installs. 
- HTTPS support via ssl_certfile + ssl_keyfile (+ optional password); startup refuses to launch with only one of the pair set. - Thumbnail cache moved from project-root .cache to %LOCALAPPDATA%/media-server/cache (Windows) and $XDG_CACHE_HOME/media-server/thumbnails (POSIX). Tests - 35 new tests across auth scopes, rate limiter, browser path traversal (../ NUL UNC absolute), script-param validation incl. regex, Gitea tag whitelist, config atomic write + POSIX perms. 47 passed / 4 skipped. --- media_server/auth.py | 36 +- media_server/config.py | 168 ++++++++- media_server/main.py | 337 +++++++++++++----- media_server/routes/browser.py | 10 +- media_server/routes/callbacks.py | 42 ++- media_server/routes/links.py | 9 + media_server/routes/media.py | 155 ++++++-- media_server/routes/scripts.py | 83 ++++- media_server/services/audit_log.py | 120 +++++++ media_server/services/display_service.py | 52 ++- media_server/services/foreground_service.py | 43 ++- .../services/gitea_release_provider.py | 9 + media_server/services/macos_media.py | 8 +- media_server/services/rate_limit.py | 95 +++++ media_server/services/thumbnail_service.py | 21 +- media_server/services/websocket_manager.py | 12 +- media_server/services/windows_media.py | 74 ++-- media_server/static/index.html | 18 +- media_server/static/js/links.js | 2 + media_server/static/js/player.js | 80 +++++ media_server/static/js/scripts.js | 3 + media_server/static/js/websocket.js | 17 +- media_server/static/manifest.json | 6 +- media_server/tray.py | 8 +- tests/test_auth_scopes.py | 58 +++ tests/test_browser_validate_path.py | 84 +++++ tests/test_config_atomic_write.py | 46 +++ tests/test_foreground_service.py | 4 - tests/test_gitea_tag_validation.py | 45 +++ tests/test_rate_limit.py | 68 ++++ tests/test_script_params.py | 77 ++++ 31 files changed, 1586 insertions(+), 204 deletions(-) create mode 100644 media_server/services/audit_log.py create mode 100644 media_server/services/rate_limit.py create mode 100644 
tests/test_auth_scopes.py create mode 100644 tests/test_browser_validate_path.py create mode 100644 tests/test_config_atomic_write.py create mode 100644 tests/test_gitea_tag_validation.py create mode 100644 tests/test_rate_limit.py create mode 100644 tests/test_script_params.py diff --git a/media_server/auth.py b/media_server/auth.py index 6a319a8..0862aeb 100644 --- a/media_server/auth.py +++ b/media_server/auth.py @@ -13,6 +13,8 @@ security = HTTPBearer(auto_error=False) # Context variable to store current request's token label token_label_var: ContextVar[str] = ContextVar("token_label", default="unknown") +# Per-request correlation ID — generated in middleware if upstream didn't send one. +request_id_var: ContextVar[str] = ContextVar("request_id", default="-") def auth_enabled() -> bool: @@ -29,12 +31,42 @@ def get_token_label(token: str) -> Optional[str]: Returns: The label for the token, or None if invalid """ - for label, stored_token in settings.api_tokens.items(): - if secrets.compare_digest(stored_token, token): + for label, spec in settings.api_tokens.items(): + if secrets.compare_digest(spec.token, token): return label return None +def token_has_scope(label: str, required: str) -> bool: + """Whether the token identified by `label` grants `required` scope.""" + spec = settings.api_tokens.get(label) + if spec is None: + # Unknown label = no auth or anonymous; treat as full access only + # when auth is disabled entirely (matches existing behaviour). + return not auth_enabled() + return spec.grants(required) + + +def require_scope(scope: str): + """Build a FastAPI dependency that enforces the given scope. + + Use as ``Depends(require_scope("admin"))`` on management endpoints. When + auth is disabled the dependency is a no-op (anonymous access). 
+ """ + + async def _checker(label: str = Depends(verify_token)) -> str: + if not auth_enabled(): + return label + if not token_has_scope(label, scope): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token '{label}' lacks required scope: {scope}", + ) + return label + + return _checker + + async def verify_token( request: Request, credentials: HTTPAuthorizationCredentials = Depends(security), diff --git a/media_server/config.py b/media_server/config.py index 2820ba9..3dafe63 100644 --- a/media_server/config.py +++ b/media_server/config.py @@ -7,12 +7,49 @@ from pathlib import Path from typing import Optional import yaml -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from pydantic_settings import BaseSettings, SettingsConfigDict logger = logging.getLogger(__name__) +# Token scopes form a strict hierarchy: admin > control > read. Helper utility +# used by both auth.py and the validator below. +SCOPE_HIERARCHY: dict[str, frozenset[str]] = { + "read": frozenset({"read"}), + "control": frozenset({"read", "control"}), + "admin": frozenset({"read", "control", "admin"}), +} +ALL_SCOPES: frozenset[str] = frozenset(SCOPE_HIERARCHY.keys()) + + +class TokenSpec(BaseModel): + """Per-token authentication entry with explicit scopes.""" + + token: str = Field(..., min_length=8, description="Secret token value") + scopes: list[str] = Field( + default_factory=lambda: ["admin"], + description="Granted scopes (subset of read|control|admin).", + ) + + @field_validator("scopes") + @classmethod + def _validate_scopes(cls, v: list[str]) -> list[str]: + if not v: + raise ValueError("scopes must list at least one of read|control|admin") + unknown = set(v) - ALL_SCOPES + if unknown: + raise ValueError(f"unknown scopes: {sorted(unknown)}; valid={sorted(ALL_SCOPES)}") + return v + + def grants(self, required: str) -> bool: + """Whether this token grants the requested scope (with hierarchy expansion).""" + granted: 
set[str] = set() + for s in self.scopes: + granted |= SCOPE_HIERARCHY.get(s, frozenset({s})) + return required in granted + + class MediaFolderConfig(BaseModel): """Configuration for a media folder.""" @@ -48,6 +85,13 @@ class ScriptParameterConfig(BaseModel): options: Optional[list[str]] = Field( default=None, description="Allowed values (select type only)" ) + pattern: Optional[str] = Field( + default=None, + description=( + "Optional regex (Python flavour) that string-typed values must match." + " Use to harden parameters that flow into shell=true scripts." + ), + ) class ScriptConfig(BaseModel): @@ -108,19 +152,84 @@ class Settings(BaseSettings): ), ) + # Reverse-proxy deployment: when serving the API behind nginx/Caddy/Traefik, + # uvicorn must trust the X-Forwarded-* headers from the proxy so that the + # `Origin` allow-list, request URLs, and logs reflect the public-facing + # values. Off by default — only enable when there's a real proxy in front + # (otherwise clients can spoof their own IP). + proxy_headers: bool = Field( + default=False, + description="Honor X-Forwarded-For / X-Forwarded-Proto from upstream proxy.", + ) + forwarded_allow_ips: str = Field( + default="127.0.0.1", + description=( + "Comma-separated IPs / CIDRs that uvicorn should trust X-Forwarded-* from." + " Use '*' to trust all (only safe when bound to a private interface)." + ), + ) + + # HTTPS / TLS. Both must be set together to enable TLS; if only one is set + # the server refuses to start. Use `mkcert` or letsencrypt to generate the + # pair; the server reads them at startup. + ssl_certfile: Optional[str] = Field( + default=None, + description="Path to TLS certificate (PEM). Pair with ssl_keyfile.", + ) + ssl_keyfile: Optional[str] = Field( + default=None, + description="Path to TLS private key (PEM). 
Pair with ssl_certfile.", + ) + ssl_keyfile_password: Optional[str] = Field( + default=None, + description="Optional password for the private key if encrypted.", + ) + # Admin-grade operations (script / callback / link / folder create/update/delete). # When True the same token used for read/play can also persist arbitrary shell - # commands. Disable to make the API read+execute only. - scripts_management: bool = Field(default=True, description="Allow scripts CRUD via API") - callbacks_management: bool = Field(default=True, description="Allow callbacks CRUD via API") - links_management: bool = Field(default=True, description="Allow links CRUD via API") + # commands. Default False so a single leaked token cannot escalate to RCE; opt + # in explicitly to manage scripts/callbacks/links via the Web UI. + scripts_management: bool = Field(default=False, description="Allow scripts CRUD via API") + callbacks_management: bool = Field(default=False, description="Allow callbacks CRUD via API") + links_management: bool = Field(default=False, description="Allow links CRUD via API") - # Authentication (empty = auth disabled, anyone can access the API) - api_tokens: dict[str, str] = Field( + # Authentication (empty = auth disabled, anyone can access the API). + # + # Each entry can be either: + # • a bare string (legacy form, treated as scopes = ["admin"] for back-compat), OR + # • a mapping with explicit scopes, e.g. + # "ha": {token: "", scopes: ["read", "control"]} + # "kiosk": {token: "", scopes: ["read"]} + # "ops": {token: "", scopes: ["admin"]} + # + # Available scopes: + # read — GET /api/* (status, list, browse) but no state-changing calls. + # control — read + media transport, display/audio, script EXECUTE, callback EXECUTE. + # admin — control + CRUD on scripts/callbacks/links/folders. + # + # Validation normalises both forms to TokenSpec at load time. 
+ api_tokens: dict[str, TokenSpec] = Field( default_factory=dict, - description="Named API tokens for access control (label: token pairs). Empty = no auth.", + description=( + "Named API tokens. Value can be a bare token string (= admin scope) or" + " a {token, scopes} mapping. See TokenSpec for scope definitions." + ), ) + @field_validator("api_tokens", mode="before") + @classmethod + def _normalise_tokens(cls, v): + """Accept legacy `label: ` form and promote to TokenSpec.""" + if not isinstance(v, dict): + return v + out: dict[str, dict | TokenSpec] = {} + for label, entry in v.items(): + if isinstance(entry, str): + out[label] = {"token": entry, "scopes": ["admin"]} + else: + out[label] = entry + return out + # Media controller settings poll_interval: float = Field( default=1.0, description="Media status poll interval in seconds" @@ -156,7 +265,7 @@ class Settings(BaseSettings): description="Media folders available for browsing in the media browser", ) media_folders_management: bool = Field( - default=True, + default=False, description="Allow adding, editing, and deleting media folders from the Web UI", ) @@ -263,8 +372,11 @@ def generate_default_config(path: Optional[Path] = None) -> Path: config = { "host": "127.0.0.1", "port": 8765, + # Default token grants "admin" scope (full access). To create a + # read-only or control-only token, add a second entry: + # ha_readonly: {token: "", scopes: ["read"]} "api_tokens": { - "default": default_token, + "default": {"token": default_token, "scopes": ["admin"]}, }, "poll_interval": 1.0, "log_level": "INFO", @@ -298,8 +410,16 @@ def _write_yaml_atomic(path: Path, data: dict) -> None: def _restrict_config_perms(path: Path) -> None: - """On POSIX, ensure config file is readable only by owner (0600).""" + """Ensure config file is readable only by its owner. + + POSIX → ``chmod 0600``. 
On Windows the default NTFS ACL leaves the file + readable by every interactive user on the machine (Users group has Read), + which is bad given the file stores plaintext API tokens. Use ``icacls`` to + grant exclusive access to the current user + SYSTEM + Administrators and + strip inheritance. + """ if os.name == "nt": + _restrict_config_perms_windows(path) return try: os.chmod(path, 0o600) @@ -308,5 +428,31 @@ def _restrict_config_perms(path: Path) -> None: logger.debug("Could not chmod %s", path, exc_info=True) +def _restrict_config_perms_windows(path: Path) -> None: + """Apply restrictive NTFS ACL to a config file (Windows only).""" + import subprocess + + try: + username = os.environ.get("USERNAME") or os.environ.get("USER") + if not username: + logger.debug("Cannot detect current user; skipping icacls hardening") + return + # Disable inheritance and remove every existing ACE, then grant access + # only to current user, SYSTEM, and Administrators. /Q suppresses + # progress output; /C lets per-file errors not abort the batch. + subprocess.run( + ["icacls", str(path), "/inheritance:r"], + check=False, capture_output=True, timeout=5, + ) + for principal in (username, "SYSTEM", "Administrators"): + subprocess.run( + ["icacls", str(path), "/grant:r", f"{principal}:(R,W)"], + check=False, capture_output=True, timeout=5, + ) + except (FileNotFoundError, subprocess.TimeoutExpired, OSError): + # `icacls` missing or sandboxed — leave the default ACL in place. + logger.debug("icacls hardening failed for %s", path, exc_info=True) + + # Global settings instance settings = Settings.load_from_yaml() diff --git a/media_server/main.py b/media_server/main.py index 2619b4a..13b23e8 100644 --- a/media_server/main.py +++ b/media_server/main.py @@ -15,7 +15,7 @@ from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from . 
import __version__ -from .auth import get_token_label, token_label_var +from .auth import get_token_label, request_id_var, token_label_var from .config import generate_default_config, get_config_dir, settings from .routes import ( audio_router, @@ -33,10 +33,34 @@ from .services.websocket_manager import ws_manager class TokenLabelFilter(logging.Filter): - """Add token label to log records.""" + """Add token label + request_id to log records.""" def filter(self, record): record.token_label = token_label_var.get("unknown") + record.request_id = request_id_var.get("-") + return True + + +class _StripTokenQueryFilter(logging.Filter): + """Strip `token=...` from query strings before they hit the access log. + + uvicorn's default access log format includes the full request line, so + `/api/media/artwork?token=SECRET` would otherwise be persisted verbatim + in stdout/journald/file sinks. + """ + + import re as _re + + _TOKEN_RE = _re.compile(r"([?&])token=[^&\s\"']+") + + def filter(self, record): # type: ignore[override] + if isinstance(record.args, tuple): + record.args = tuple( + self._TOKEN_RE.sub(r"\1token=REDACTED", a) if isinstance(a, str) else a + for a in record.args + ) + if isinstance(record.msg, str) and "token=" in record.msg: + record.msg = self._TOKEN_RE.sub(r"\1token=REDACTED", record.msg) return True @@ -49,17 +73,34 @@ def setup_logging(): logging.basicConfig( level=getattr(logging, settings.log_level.upper()), - format="%(asctime)s - %(name)s - [%(token_label)s] - %(levelname)s - %(message)s", + format=( + "%(asctime)s - %(name)s - [%(token_label)s] [%(request_id)s]" + " - %(levelname)s - %(message)s" + ), handlers=[handler], ) # Suppress noisy third-party loggers logging.getLogger("screen_brightness_control").setLevel(logging.ERROR) + # Make sure the uvicorn access log never persists tokens leaked into the + # query string (the artwork + WS endpoints accept `?token=` for browser + # compatibility — see verify_token_or_query). 
+ strip_filter = _StripTokenQueryFilter() + for name in ("uvicorn.access", "uvicorn"): + logging.getLogger(name).addFilter(strip_filter) + @asynccontextmanager async def lifespan(app: FastAPI): - """Application lifespan handler.""" + """Application lifespan handler. + + All long-lived resources started during startup are kept in local refs and + torn down in a `finally:` so a partial-startup failure cannot orphan tasks + or thread pools. + """ + import asyncio + setup_logging() logger = logging.getLogger(__name__) logger.info(f"Media Server starting on {settings.host}:{settings.port}") @@ -71,92 +112,125 @@ async def lifespan(app: FastAPI): else: logger.warning("No API tokens configured — authentication is DISABLED") - # Start WebSocket status monitor - controller = get_media_controller() - await ws_manager.start_status_monitor(controller.get_status) - logger.info("WebSocket status monitor started") - - # Start update checker update_checker = None - if settings.update_check_enabled: - from .services.gitea_release_provider import GiteaReleaseProvider - from .services.update_checker import UpdateChecker - - provider = GiteaReleaseProvider() - update_checker = UpdateChecker(provider, __version__) - await update_checker.start(settings.update_check_interval) - # Store globally so health endpoint can access cached result - app.state.update_checker = update_checker - - # Schedule periodic thumbnail cache cleanup so the 500 MB cap is actually - # enforced. Runs once at startup and then hourly until shutdown. 
- from .services.thumbnail_service import ThumbnailService - - async def _thumbnail_cleanup_loop() -> None: - while True: - try: - await asyncio.to_thread(ThumbnailService.cleanup_cache) - except Exception as e: - logger.warning("Thumbnail cache cleanup failed: %s", e) - try: - await asyncio.sleep(3600) - except asyncio.CancelledError: - break - - import asyncio - cleanup_task = asyncio.create_task(_thumbnail_cleanup_loop()) - - # Register audio visualizer (capture starts on-demand when clients subscribe) + cleanup_task: asyncio.Task | None = None analyzer = None - if settings.visualizer_enabled: - from .services.audio_analyzer import get_audio_analyzer + status_monitor_started = False - analyzer = get_audio_analyzer( - num_bins=settings.visualizer_bins, - target_fps=settings.visualizer_fps, - device_name=settings.visualizer_device, - ) - if analyzer.available: - await ws_manager.start_audio_monitor(analyzer) - logger.info("Audio visualizer available (capture on-demand)") - else: - logger.info("Audio visualizer unavailable (install soundcard + numpy)") - - yield - - # Stop update checker - if update_checker is not None: - await update_checker.stop() - - # Cancel periodic thumbnail cleanup - cleanup_task.cancel() try: - await cleanup_task - except asyncio.CancelledError: - pass + # Start WebSocket status monitor + controller = get_media_controller() + await ws_manager.start_status_monitor(controller.get_status) + status_monitor_started = True + logger.info("WebSocket status monitor started") - # Stop audio visualizer - await ws_manager.stop_audio_monitor() - if analyzer and analyzer.running: - analyzer.stop() + # Start update checker + if settings.update_check_enabled: + from .services.gitea_release_provider import GiteaReleaseProvider + from .services.update_checker import UpdateChecker - # Stop WebSocket status monitor - await ws_manager.stop_status_monitor() + provider = GiteaReleaseProvider() + update_checker = UpdateChecker(provider, __version__) + await 
update_checker.start(settings.update_check_interval) + # Store globally so health endpoint can access cached result + app.state.update_checker = update_checker - # Shut down dedicated thread pools so pending scripts don't leak threads - from .routes.callbacks import shutdown_callback_executor - from .routes.scripts import shutdown_script_executor + # Schedule periodic thumbnail cache cleanup so the 500 MB cap is actually + # enforced. Runs once at startup and then hourly until shutdown. + from .services.thumbnail_service import ThumbnailService - shutdown_script_executor() - shutdown_callback_executor() + async def _thumbnail_cleanup_loop() -> None: + while True: + try: + await asyncio.to_thread(ThumbnailService.cleanup_cache) + except Exception as e: + logger.warning("Thumbnail cache cleanup failed: %s", e) + try: + await asyncio.sleep(3600) + except asyncio.CancelledError: + break - # Clean up platform-specific resources - import platform as _platform - if _platform.system() == "Windows": - from .services.windows_media import shutdown_executor - shutdown_executor() + cleanup_task = asyncio.create_task(_thumbnail_cleanup_loop()) - logger.info("Media Server shutting down") + # Register audio visualizer (capture starts on-demand when clients subscribe) + if settings.visualizer_enabled: + from .services.audio_analyzer import get_audio_analyzer + + analyzer = get_audio_analyzer( + num_bins=settings.visualizer_bins, + target_fps=settings.visualizer_fps, + device_name=settings.visualizer_device, + ) + if analyzer.available: + await ws_manager.start_audio_monitor(analyzer) + logger.info("Audio visualizer available (capture on-demand)") + else: + logger.info("Audio visualizer unavailable (install soundcard + numpy)") + + yield + finally: + # Stop update checker + if update_checker is not None: + try: + await update_checker.stop() + except Exception: + logger.exception("Error stopping update checker") + + # Cancel periodic thumbnail cleanup + if cleanup_task is not None: + 
cleanup_task.cancel() + try: + await cleanup_task + except asyncio.CancelledError: + pass + except Exception: + logger.exception("Error awaiting thumbnail cleanup task") + + # Stop audio visualizer + try: + await ws_manager.stop_audio_monitor() + except Exception: + logger.exception("Error stopping audio monitor") + if analyzer and analyzer.running: + try: + analyzer.stop() + except Exception: + logger.exception("Error stopping audio analyzer") + + # Stop WebSocket status monitor + if status_monitor_started: + try: + await ws_manager.stop_status_monitor() + except Exception: + logger.exception("Error stopping status monitor") + + # Shut down dedicated thread pools so pending scripts don't leak threads + try: + from .routes.callbacks import shutdown_callback_executor + from .routes.scripts import shutdown_script_executor + + shutdown_script_executor() + shutdown_callback_executor() + except Exception: + logger.exception("Error shutting down script/callback executors") + + # Flush audit log writer + try: + from .services.audit_log import shutdown_audit_log + shutdown_audit_log() + except Exception: + logger.exception("Error flushing audit log") + + # Clean up platform-specific resources + import platform as _platform + if _platform.system() == "Windows": + try: + from .services.windows_media import shutdown_executor + shutdown_executor() + except Exception: + logger.exception("Error shutting down windows_media executor") + + logger.info("Media Server shutting down") def create_app() -> FastAPI: @@ -173,7 +247,15 @@ def create_app() -> FastAPI: # CORS — restrict to same-origin by default; users that integrate the API # from another origin (e.g. Home Assistant on a different host) can set - # cors_origins in config.yaml. + # cors_origins in config.yaml. Refuse "*" outright: combined with the + # admin endpoints this would let any origin in the universe run + # arbitrary shell. If users genuinely need every origin, they can list + # them explicitly. 
+ if any(o.strip() == "*" for o in settings.cors_origins): + raise RuntimeError( + "cors_origins must not contain '*' — list exact origins instead. " + "This protects the script-execution endpoints from any-origin abuse." + ) cors_origins = settings.cors_origins or [ f"http://localhost:{settings.port}", f"http://127.0.0.1:{settings.port}", @@ -186,6 +268,23 @@ def create_app() -> FastAPI: allow_headers=["Authorization", "Content-Type"], ) + # Request correlation ID — accept upstream X-Request-ID if it's a sane + # ASCII id, otherwise mint a fresh UUID4. Emitted on the response so + # clients can quote it back in bug reports. + import re + import uuid as _uuid + + _REQ_ID_RE = re.compile(r"^[A-Za-z0-9._\-]{1,128}$") + + @app.middleware("http") + async def request_id_middleware(request: Request, call_next): + incoming = request.headers.get("x-request-id", "") + req_id = incoming if _REQ_ID_RE.match(incoming) else _uuid.uuid4().hex[:16] + request_id_var.set(req_id) + response = await call_next(request) + response.headers["X-Request-ID"] = req_id + return response + # Security headers — strict CSP for the bundled UI, disallow framing, hide referrer. 
@app.middleware("http") async def security_headers_middleware(request: Request, call_next): @@ -200,6 +299,9 @@ def create_app() -> FastAPI: "style-src 'self' 'unsafe-inline'; " "font-src 'self' data:; " "frame-ancestors 'none'; " + "form-action 'self'; " + "worker-src 'self'; " + "manifest-src 'self'; " "base-uri 'self'" ), ) @@ -208,32 +310,63 @@ def create_app() -> FastAPI: response.headers.setdefault("Referrer-Policy", "no-referrer") return response - # Add token logging middleware + # Add token logging middleware + auth-failure rate limit + from fastapi.responses import JSONResponse + + from .services.rate_limit import check as ratelimit_check + from .services.rate_limit import get_peer + @app.middleware("http") async def token_logging_middleware(request: Request, call_next): - """Extract token label and set in context for logging.""" + """Extract token label, set in context, and rate-limit failed auths.""" if not settings.api_tokens: token_label_var.set("anonymous") else: token_label = "unknown" + token_present = False + token_valid = False # Try Authorization header auth_header = request.headers.get("authorization", "") if auth_header.startswith("Bearer "): + token_present = True token = auth_header[7:] label = get_token_label(token) if label: token_label = label + token_valid = True # Try query parameter (for artwork endpoint) elif "token" in request.query_params: + token_present = True token = request.query_params["token"] label = get_token_label(token) if label: token_label = label + token_valid = True token_label_var.set(token_label) + # Brute-force gate: a peer that produces a wrong/missing token gets + # 5 failures per minute before being throttled. Static-asset + # requests (GET /static/*, /, /sw.js) and the docs endpoint are + # exempt — they're served unauthenticated by design. 
+ if token_present and not token_valid: + path = request.url.path + if not ( + path == "/" or path == "/sw.js" + or path.startswith("/static/") + or path.startswith("/docs") or path.startswith("/openapi") + or path.startswith("/redoc") + ): + allowed, retry_after = ratelimit_check("auth", get_peer(request)) + if not allowed: + return JSONResponse( + status_code=429, + content={"detail": "Too many authentication failures"}, + headers={"Retry-After": str(int(retry_after or 60))}, + ) + response = await call_next(request) return response @@ -266,6 +399,11 @@ def create_app() -> FastAPI: async def serve_ui(): """Serve the Web UI.""" return FileResponse(static_dir / "index.html") + else: + logging.getLogger(__name__).warning( + "static_dir not found at %s — Web UI disabled (API only)", + static_dir, + ) return app @@ -316,8 +454,9 @@ def main(): print(f"Config directory: {get_config_dir()}") if settings.api_tokens: print("\nAPI Tokens:") - for label, token in settings.api_tokens.items(): - print(f" {label:20} {token}") + for label, spec in settings.api_tokens.items(): + scope_str = ",".join(spec.scopes) + print(f" {label:20} {spec.token} [scopes: {scope_str}]") else: print("\nAuthentication is DISABLED (no tokens configured)") return @@ -374,6 +513,27 @@ def main(): use_tray = PYSTRAY_AVAILABLE and not args.no_tray + # Validate TLS pair consistency before either path so we don't fail late. + if bool(settings.ssl_certfile) ^ bool(settings.ssl_keyfile): + _fatal( + "ERROR: ssl_certfile and ssl_keyfile must both be set, or both unset." 
+ ) + + def _uvicorn_kwargs() -> dict: + kw: dict = { + "host": args.host, + "port": args.port, + "log_level": settings.log_level.lower(), + "proxy_headers": settings.proxy_headers, + "forwarded_allow_ips": settings.forwarded_allow_ips, + } + if settings.ssl_certfile and settings.ssl_keyfile: + kw["ssl_certfile"] = settings.ssl_certfile + kw["ssl_keyfile"] = settings.ssl_keyfile + if settings.ssl_keyfile_password: + kw["ssl_keyfile_password"] = settings.ssl_keyfile_password + return kw + if use_tray: import asyncio import threading @@ -381,9 +541,7 @@ def main(): # Run uvicorn in a background thread so tray owns the main thread message loop uv_config = uvicorn.Config( "media_server.main:app", - host=args.host, - port=args.port, - log_level=settings.log_level.lower(), + **_uvicorn_kwargs(), ) server = uvicorn.Server(uv_config) @@ -421,9 +579,8 @@ def main(): else: uvicorn.run( "media_server.main:app", - host=args.host, - port=args.port, reload=False, + **_uvicorn_kwargs(), ) diff --git a/media_server/routes/browser.py b/media_server/routes/browser.py index 5986bd0..d0577fe 100644 --- a/media_server/routes/browser.py +++ b/media_server/routes/browser.py @@ -36,12 +36,20 @@ def _spawn_background(coro) -> asyncio.Task: def _require_folder_management() -> None: - """Raise 403 if media folder management is disabled in config.""" + """Raise 403 if media folder management is disabled OR caller lacks admin scope.""" if not settings.media_folders_management: raise HTTPException( status_code=403, detail="Media folder management is disabled. 
Set media_folders_management: true in config.yaml to enable.", ) + from ..auth import auth_enabled, token_has_scope, token_label_var + if auth_enabled(): + label = token_label_var.get("unknown") + if not token_has_scope(label, "admin"): + raise HTTPException( + status_code=403, + detail=f"Token '{label}' lacks required scope: admin", + ) async def _broadcast_after_open(controller, label: str, max_wait: float = 2.0) -> None: diff --git a/media_server/routes/callbacks.py b/media_server/routes/callbacks.py index 09ca170..35eb106 100644 --- a/media_server/routes/callbacks.py +++ b/media_server/routes/callbacks.py @@ -8,12 +8,14 @@ import time from concurrent.futures import ThreadPoolExecutor from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, Field from ..auth import verify_token from ..config import CallbackConfig, settings from ..config_manager import config_manager +from ..services.rate_limit import check as ratelimit_check +from ..services.rate_limit import get_peer router = APIRouter(prefix="/api/callbacks", tags=["callbacks"]) logger = logging.getLogger(__name__) @@ -28,6 +30,7 @@ def shutdown_callback_executor() -> None: def _require_callbacks_management() -> None: + """Authorise a callbacks-CRUD operation. Operator flag + per-token admin scope.""" if not settings.callbacks_management: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -36,6 +39,14 @@ def _require_callbacks_management() -> None: " in config.yaml to enable." 
), ) + from ..auth import auth_enabled, token_has_scope, token_label_var + if auth_enabled(): + label = token_label_var.get("unknown") + if not token_has_scope(label, "admin"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token '{label}' lacks required scope: admin", + ) class CallbackInfo(BaseModel): @@ -122,6 +133,7 @@ async def list_callbacks(_: str = Depends(verify_token)) -> list[CallbackInfo]: @router.post("/execute/{callback_name}") async def execute_callback( callback_name: str, + http_request: Request, _: str = Depends(verify_token), ) -> CallbackExecuteResponse: """Execute a callback for debugging purposes. @@ -132,6 +144,16 @@ async def execute_callback( Returns: Execution result including stdout, stderr, and exit code """ + # Rate-limit callback execution per peer (10/min) — callbacks also run + # subprocesses and need the same protection as scripts. + allowed, retry_after = ratelimit_check("execute", get_peer(http_request)) + if not allowed: + raise HTTPException( + status_code=429, + detail="Too many callback executions, slow down", + headers={"Retry-After": str(int(retry_after or 60))}, + ) + # Validate callback name _validate_callback_name(callback_name) @@ -146,6 +168,8 @@ async def execute_callback( logger.info(f"Executing callback for debugging: {callback_name}") + from ..services.audit_log import record_script_execution + try: # Execute in dedicated thread pool to not block the default executor loop = asyncio.get_running_loop() @@ -159,6 +183,15 @@ async def execute_callback( ), ) + record_script_execution( + kind="callback", + name=callback_name, + exit_code=result["exit_code"], + duration=result.get("execution_time"), + stdout=result.get("stdout"), + stderr=result.get("stderr"), + ) + return CallbackExecuteResponse( success=result["exit_code"] == 0, callback=callback_name, @@ -170,6 +203,13 @@ async def execute_callback( except Exception as e: logger.error(f"Callback execution error: {e}") + 
record_script_execution( + kind="callback", + name=callback_name, + exit_code=None, + duration=None, + error=str(e), + ) return CallbackExecuteResponse( success=False, callback=callback_name, diff --git a/media_server/routes/links.py b/media_server/routes/links.py index 87ea4c9..3008a1f 100644 --- a/media_server/routes/links.py +++ b/media_server/routes/links.py @@ -39,11 +39,20 @@ def _validate_icon(icon: str) -> str: def _require_links_management() -> None: + """Authorise a links-CRUD operation. Operator flag + per-token admin scope.""" if not settings.links_management: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Links management is disabled. Set links_management: true in config.yaml to enable.", ) + from ..auth import auth_enabled, token_has_scope, token_label_var + if auth_enabled(): + label = token_label_var.get("unknown") + if not token_has_scope(label, "admin"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token '{label}' lacks required scope: admin", + ) class LinkInfo(BaseModel): diff --git a/media_server/routes/media.py b/media_server/routes/media.py index 67c346b..1099d27 100644 --- a/media_server/routes/media.py +++ b/media_server/routes/media.py @@ -3,7 +3,16 @@ import asyncio import logging -from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Query, + Request, + WebSocket, + WebSocketDisconnect, + status, +) from fastapi.responses import Response from ..auth import verify_token, verify_token_or_query @@ -17,19 +26,28 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/media", tags=["media"]) +# Strong refs to background tasks so the asyncio GC can't drop them before +# they run. Mirrors the pattern used in routes/browser.py. 
+_background_callback_tasks: set[asyncio.Task] = set() + + def _run_callback(callback_name: str) -> None: """Fire-and-forget a callback if configured. Failures are logged but don't block.""" if not settings.callbacks or callback_name not in settings.callbacks: return async def _execute(): + # Use the dedicated callback executor (not the default loop pool) so a + # misbehaving callback can't starve the rest of the app's sync tasks. + from ..services.audit_log import record_script_execution + from .callbacks import _callback_executor from .scripts import _run_script try: callback = settings.callbacks[callback_name] loop = asyncio.get_running_loop() result = await loop.run_in_executor( - None, + _callback_executor, lambda: _run_script( command=callback.command, timeout=callback.timeout, @@ -37,6 +55,14 @@ def _run_callback(callback_name: str) -> None: working_dir=callback.working_dir, ), ) + record_script_execution( + kind="event-callback", + name=callback_name, + exit_code=result["exit_code"], + duration=result.get("execution_time"), + stdout=result.get("stdout"), + stderr=result.get("stderr"), + ) if result["exit_code"] != 0: logger.warning( "Callback %s failed with exit code %s: %s", @@ -46,8 +72,18 @@ def _run_callback(callback_name: str) -> None: ) except Exception as e: logger.error("Callback %s error: %s", callback_name, e) + from ..services.audit_log import record_script_execution as _rec + _rec( + kind="event-callback", + name=callback_name, + exit_code=None, + duration=None, + error=str(e), + ) - asyncio.create_task(_execute()) + task = asyncio.create_task(_execute()) + _background_callback_tasks.add(task) + task.add_done_callback(_background_callback_tasks.discard) @router.get("/status", response_model=MediaStatus) @@ -242,11 +278,14 @@ async def toggle(_: str = Depends(verify_token)) -> dict: @router.get("/artwork") -async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response: +async def get_artwork( + request: Request, + _: str = 
Depends(verify_token_or_query), +) -> Response: """Get the current album artwork. - Returns: - The album art image as PNG/JPEG + Returns the bytes with a content-derived ETag so the browser can serve a + 304 when the same track is re-requested. """ art_bytes = get_current_album_art() if art_bytes is None: @@ -255,16 +294,34 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response: detail="No album artwork available", ) - # Try to detect image type from magic bytes - content_type = "image/png" # Default + # Detect image type from magic bytes if art_bytes[:3] == b"\xff\xd8\xff": content_type = "image/jpeg" elif art_bytes[:8] == b"\x89PNG\r\n\x1a\n": content_type = "image/png" - elif art_bytes[:4] == b"RIFF" and art_bytes[8:12] == b"WEBP": + elif art_bytes[:4] == b"RIFF" and len(art_bytes) > 12 and art_bytes[8:12] == b"WEBP": content_type = "image/webp" + elif art_bytes[:2] == b"BM": + content_type = "image/bmp" + else: + content_type = "application/octet-stream" - return Response(content=art_bytes, media_type=content_type) + # Content-derived ETag (blake2b-128 — non-crypto cache key, ruff S324-safe) + import hashlib + + etag = '"' + hashlib.blake2b(art_bytes, digest_size=16).hexdigest() + '"' + + if request.headers.get("if-none-match") == etag: + return Response(status_code=status.HTTP_304_NOT_MODIFIED, headers={"ETag": etag}) + + return Response( + content=art_bytes, + media_type=content_type, + headers={ + "ETag": etag, + "Cache-Control": "private, max-age=0, must-revalidate", + }, + ) @router.get("/visualizer/status") @@ -323,12 +380,17 @@ async def set_visualizer_device( @router.websocket("/ws") async def websocket_endpoint( websocket: WebSocket, - token: str | None = Query(None, description="API authentication token"), + token: str | None = Query(None, description="API authentication token (legacy)"), ) -> None: """WebSocket endpoint for real-time media status updates. 
- Authentication is done via query parameter since WebSocket - doesn't support custom headers in the browser. + Authentication is accepted from two sources, in priority order: + 1. ``Sec-WebSocket-Protocol`` subprotocol of the form + ``media-server.token.TOKEN``. This is the preferred path because + the token never lands in the URL, request logs, or browser history. + The browser WebSocket API supports custom subprotocols natively. + 2. ``?token=TOKEN`` query parameter (legacy, kept for back-compat + with older clients and the HA integration). Messages sent to client: - {"type": "status", "data": {...}} - Initial status on connect @@ -339,11 +401,40 @@ async def websocket_endpoint( - {"type": "ping"} - Keepalive, server responds with {"type": "pong"} - {"type": "get_status"} - Request current status """ + # Pull token from subprotocol if present. WebSocket spec lets either side + # negotiate exactly one subprotocol back; we accept the token one and + # answer with the same string so browsers consider the negotiation + # successful. + subprotocol_token: str | None = None + accept_subprotocol: str | None = None + raw_protocols = websocket.headers.get("sec-websocket-protocol", "") + for proto in (p.strip() for p in raw_protocols.split(",") if p.strip()): + if proto.startswith("media-server.token."): + subprotocol_token = proto[len("media-server.token."):] + accept_subprotocol = proto + break + effective_token = subprotocol_token or token + # Origin check — block CSWSH from third-party LAN pages. We accept the + # configured CORS origins, or only the localhost loopback origins when none are set. + allowed_origins = set( + settings.cors_origins + or [ + f"http://localhost:{settings.port}", + f"http://127.0.0.1:{settings.port}", + ] + ) + origin = websocket.headers.get("origin") + # Same-origin connections from native apps may omit Origin entirely; only + # reject when an Origin is present AND not in the allow-list.
+ if origin is not None and origin not in allowed_origins: + await websocket.close(code=4003, reason="Origin not allowed") + return + # Verify token from ..auth import auth_enabled, get_token_label, token_label_var if auth_enabled(): - label = get_token_label(token) if token else None + label = get_token_label(effective_token) if effective_token else None if label is None: await websocket.close(code=4001, reason="Invalid authentication token") return @@ -351,16 +442,25 @@ async def websocket_endpoint( else: token_label_var.set("anonymous") - await ws_manager.connect(websocket) + # Accept with the negotiated subprotocol if one was used. Starlette's + # connect() calls accept() with no subprotocol — we need to accept first + # explicitly to echo the subprotocol back, then hand off to the manager. + if accept_subprotocol is not None: + await websocket.accept(subprotocol=accept_subprotocol) + await ws_manager.connect(websocket, already_accepted=True) + else: + await ws_manager.connect(websocket) try: while True: # Wait for messages from client (for keepalive/ping) data = await websocket.receive_json() - if data.get("type") == "ping": + msg_type = data.get("type") if isinstance(data, dict) else None + + if msg_type == "ping": await websocket.send_json({"type": "pong"}) - elif data.get("type") == "get_status": + elif msg_type == "get_status": # Allow manual status request controller = get_media_controller() status_data = await controller.get_status() @@ -368,15 +468,20 @@ async def websocket_endpoint( "type": "status", "data": status_data.model_dump(), }) - elif data.get("type") == "volume": - # Low-latency volume control via WebSocket - volume = data.get("volume") - if volume is not None: - controller = get_media_controller() - await controller.set_volume(int(volume)) - elif data.get("type") == "enable_visualizer": + elif msg_type == "volume": + # Low-latency volume control via WebSocket. 
Coerce, clamp, and + # never drop the socket on a single bad message — that would + # turn the WS into a one-shot DoS for any holder of a token. + try: + volume = int(data.get("volume")) + except (TypeError, ValueError): + continue + volume = max(0, min(100, volume)) + controller = get_media_controller() + await controller.set_volume(volume) + elif msg_type == "enable_visualizer": await ws_manager.subscribe_visualizer(websocket) - elif data.get("type") == "disable_visualizer": + elif msg_type == "disable_visualizer": await ws_manager.unsubscribe_visualizer(websocket) except WebSocketDisconnect: diff --git a/media_server/routes/scripts.py b/media_server/routes/scripts.py index 53516f5..54a7249 100644 --- a/media_server/routes/scripts.py +++ b/media_server/routes/scripts.py @@ -10,12 +10,14 @@ import time from concurrent.futures import ThreadPoolExecutor from typing import Any -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, Field from ..auth import verify_token from ..config import ScriptConfig, ScriptParameterConfig, settings from ..config_manager import config_manager +from ..services.rate_limit import check as ratelimit_check +from ..services.rate_limit import get_peer from ..services.websocket_manager import ws_manager router = APIRouter(prefix="/api/scripts", tags=["scripts"]) @@ -31,6 +33,12 @@ def shutdown_script_executor() -> None: def _require_scripts_management() -> None: + """Authorise a scripts-CRUD operation. + + Two gates: the operator-level `scripts_management` flag in config.yaml, + AND the per-token `admin` scope check (read from request-context). Either + failure → 403. + """ if not settings.scripts_management: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -39,6 +47,14 @@ def _require_scripts_management() -> None: " in config.yaml to enable." 
), ) + from ..auth import auth_enabled, token_has_scope, token_label_var + if auth_enabled(): + label = token_label_var.get("unknown") + if not token_has_scope(label, "admin"): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token '{label}' lacks required scope: admin", + ) class ScriptExecuteRequest(BaseModel): @@ -215,6 +231,28 @@ def _validate_params( # string — just convert to str value = str(value) + # Optional regex constraint, validated against the *string form* of the + # value. This is the only practical defence for string parameters that + # flow into shell=true scripts via env vars (Windows cmd.exe expands + # `%VAR%` after argument parsing, so embedded `&`/`|`/`%` would inject + # commands). Authors of shell scripts should ALWAYS define a pattern. + if pdef.pattern: + try: + if not re.fullmatch(pdef.pattern, str(value)): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Parameter '{pname}' value {value!r} does not match" + f" required pattern: {pdef.pattern}" + ), + ) + except re.error as e: + # Bad pattern in config — fail closed. + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Parameter '{pname}' has invalid pattern: {e}", + ) from e + env_vars[f"SCRIPT_PARAM_{pname.upper()}"] = str(value) return env_vars @@ -223,6 +261,7 @@ def _validate_params( @router.post("/execute/{script_name}") async def execute_script( script_name: str, + http_request: Request, request: ScriptExecuteRequest | None = None, _: str = Depends(verify_token), ) -> ScriptExecuteResponse: @@ -235,6 +274,16 @@ async def execute_script( Returns: Execution result including stdout, stderr, and exit code """ + # Rate-limit script execution per peer so a leaked token can't be used to + # spam the shell-exec endpoint. 
+ allowed, retry_after = ratelimit_check("execute", get_peer(http_request)) + if not allowed: + raise HTTPException( + status_code=429, + detail="Too many script executions, slow down", + headers={"Retry-After": str(int(retry_after or 60))}, + ) + # Check if script exists if script_name not in settings.scripts: raise HTTPException( @@ -249,6 +298,8 @@ async def execute_script( logger.info(f"Executing script: {script_name}") + from ..services.audit_log import record_script_execution + try: # Execute in dedicated thread pool to not block the default executor loop = asyncio.get_running_loop() @@ -263,6 +314,15 @@ async def execute_script( ), ) + record_script_execution( + kind="script", + name=script_name, + exit_code=result["exit_code"], + duration=result.get("execution_time"), + stdout=result.get("stdout"), + stderr=result.get("stderr"), + ) + return ScriptExecuteResponse( success=result["exit_code"] == 0, script=script_name, @@ -274,6 +334,13 @@ async def execute_script( except Exception as e: logger.error(f"Script execution error: {e}") + record_script_execution( + kind="script", + name=script_name, + exit_code=None, + duration=None, + error=str(e), + ) return ScriptExecuteResponse( success=False, script=script_name, @@ -313,9 +380,21 @@ def _run_script( else: popen_kwargs["start_new_session"] = True + # When shell=False, the user-provided command string is split via shlex + # (POSIX rules — also works for Windows args without backslashes). This + # disables shell metacharacter expansion entirely, so SCRIPT_PARAM_* env + # vars referenced as $FOO / %FOO% will be treated as literal text by the + # process, not interpreted by a shell. Use shell=false for any script + # whose params come from external input. 
+ if shell: + run_command: str | list[str] = command + else: + import shlex + run_command = shlex.split(command, posix=(sys.platform != "win32")) + try: result = subprocess.run( - command, + run_command, shell=shell, cwd=working_dir, capture_output=True, diff --git a/media_server/services/audit_log.py b/media_server/services/audit_log.py new file mode 100644 index 0000000..eca0bb8 --- /dev/null +++ b/media_server/services/audit_log.py @@ -0,0 +1,120 @@ +"""Append-only audit log for sensitive actions (script + callback execution). + +Writes a single JSONL line per event to ``audit.log`` in the config directory. The log is +write-only from the app's perspective — it never reads back, and rotation is +left to the operator (each record stays small: stdout/stderr arrive already +capped at 10 KB per stream by `_run_script` and are truncated again here). + +Designed to be cheap: the write goes through a small background thread so the +hot path never blocks on disk I/O, and a failure to write is logged at WARNING +but never raised to callers. +""" + +from __future__ import annotations + +import json +import logging +import queue +import threading +import time +from typing import Any + +from ..auth import token_label_var +from ..config import get_config_dir + +logger = logging.getLogger(__name__) + +# Cap on stdout/stderr inside the audit record so a chatty script doesn't +# explode the log. Tighter than the 10k cap used by _run_script.
+_OUTPUT_CAP = 2000 + +_audit_queue: "queue.Queue[dict[str, Any] | None]" = queue.Queue(maxsize=1000) +_audit_thread: threading.Thread | None = None +_audit_lock = threading.Lock() + + +def _ensure_writer_started() -> None: + global _audit_thread + with _audit_lock: + if _audit_thread is not None and _audit_thread.is_alive(): + return + _audit_thread = threading.Thread( + target=_audit_writer_loop, + name="audit-log", + daemon=True, + ) + _audit_thread.start() + + +def _audit_writer_loop() -> None: + log_path = get_config_dir() / "audit.log" + while True: + try: + record = _audit_queue.get() + except Exception: + return + if record is None: + return + try: + line = json.dumps(record, ensure_ascii=False, default=str) + with open(log_path, "a", encoding="utf-8") as f: + f.write(line + "\n") + except OSError as e: + logger.warning("Failed to write audit record: %s", e) + + +def _truncate(value: str | None) -> str | None: + if value is None: + return None + if len(value) <= _OUTPUT_CAP: + return value + return value[:_OUTPUT_CAP] + f"\n…[truncated, {len(value) - _OUTPUT_CAP} chars]" + + +def record_script_execution( + *, + kind: str, + name: str, + exit_code: int | None, + duration: float | None, + stdout: str | None = None, + stderr: str | None = None, + error: str | None = None, +) -> None: + """Append a single audit record. Never raises.""" + _ensure_writer_started() + try: + record = { + "ts": time.time(), + "iso": time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime()), + "token_label": token_label_var.get("unknown"), + "kind": kind, + "name": name, + "exit_code": exit_code, + "duration_s": round(duration, 4) if duration is not None else None, + "success": exit_code == 0 if exit_code is not None else False, + "stdout": _truncate(stdout), + "stderr": _truncate(stderr), + "error": error, + } + _audit_queue.put_nowait(record) + except queue.Full: + # Backpressure: drop oldest record to make room. We'd rather lose an + # old entry than block the script that just ran. 
+ try: + _audit_queue.get_nowait() + _audit_queue.put_nowait(record) + except queue.Empty: + pass + except Exception as e: + logger.warning("Failed to enqueue audit record: %s", e) + + +def shutdown_audit_log() -> None: + """Flush the audit queue on app shutdown.""" + try: + _audit_queue.put_nowait(None) + except queue.Full: + pass + if _audit_thread is not None: + _audit_thread.join(timeout=2) diff --git a/media_server/services/display_service.py b/media_server/services/display_service.py index 0df96ba..27ff23f 100644 --- a/media_server/services/display_service.py +++ b/media_server/services/display_service.py @@ -192,10 +192,11 @@ _CACHE_TTL = 5.0 # seconds # Per-monitor cache of static capabilities (option lists + support flags). # DDC/CI capability discovery is the slow part — it only changes when a # monitor is replaced or rewired, so we probe it once per monitor and reuse -# it across refreshes. Cleared on explicit `rediscover` or when the monitor -# count changes (cheap stale-detection for hot-plug events). -_static_cache: dict[int, dict] = {} -_static_cache_monitor_count: int = -1 +# it across refreshes. Keyed by a stable identity tuple +# (manufacturer, model, edid_hash) so that hot-plug swaps where the new +# topology has the same number of monitors but different devices still +# refresh the cache for the new monitor instead of serving stale capabilities. +_static_cache: dict[tuple, dict] = {} def _enum_name(value, enum_cls=None) -> str | None: @@ -353,7 +354,7 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list next probe re-runs DDC/CI capability discovery. Use after hot-plug or when a monitor's reported capabilities change. 
""" - global _monitor_cache, _cache_time, _static_cache_monitor_count + global _monitor_cache, _cache_time if ( not force_refresh @@ -372,12 +373,11 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list info_list = sbc.list_monitors_info() brightnesses = sbc.get_brightness() - # Invalidate the static cache on explicit rediscover OR on topology - # change (hot-plug / disconnect). Both indicate the cached probe is - # potentially stale. - if rediscover or len(info_list) != _static_cache_monitor_count: + # Explicit rediscover wipes the whole cache; otherwise rely on stable + # per-monitor keys (manufacturer|model|edid_hash) so a hot-plug swap + # invalidates the entry for the missing monitor automatically. + if rediscover: _static_cache.clear() - _static_cache_monitor_count = len(info_list) mc = _load_monitorcontrol() ddc_monitors = [] @@ -387,6 +387,9 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list except Exception: pass + import hashlib + + seen_keys: set[tuple] = set() for i, info in enumerate(info_list): name = info.get("name", f"Monitor {i}") model = info.get("model", "") @@ -400,6 +403,21 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list edid = info.get("edid", "") resolution = _parse_edid_resolution(edid) if edid else None + # Stable cache key — EDID hash is unique per physical monitor. + # Fall back to (manufacturer, model, serial-ish) when EDID is + # missing, then to the legacy index as a last resort. 
+ if edid: + edid_hash = hashlib.blake2b( + edid.encode("utf-8") if isinstance(edid, str) else bytes(edid), + digest_size=8, + ).hexdigest() + cache_key: tuple = ("edid", edid_hash) + elif manufacturer or model: + cache_key = ("mm", manufacturer, model, name) + else: + cache_key = ("idx", i) + seen_keys.add(cache_key) + static: dict = {} dynamic: dict = {} @@ -409,13 +427,13 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list if power_supported and i < len(ddc_monitors): try: with ddc_monitors[i] as mon: - if i not in _static_cache: - _static_cache[i] = _probe_static_open(mon, mc, i) - static = _static_cache[i] + if cache_key not in _static_cache: + _static_cache[cache_key] = _probe_static_open(mon, mc, i) + static = _static_cache[cache_key] dynamic = _probe_dynamic_open(mon, mc, i, static) except Exception as e: logger.debug("Monitor %d: DDC/CI session failed: %s", i, e) - static = _static_cache.get(i, {}) + static = _static_cache.get(cache_key, {}) monitors.append(MonitorInfo( id=i, @@ -439,6 +457,12 @@ def list_monitors(force_refresh: bool = False, rediscover: bool = False) -> list available_picture_modes=static.get("available_picture_modes", []), picture_mode_supported=static.get("picture_mode_supported", False), )) + # Evict cache entries for monitors that disappeared from this scan so + # the next hot-plug of a different monitor with the same identity + # tuple (e.g. same model) doesn't hit a stale entry first. 
+ for stale_key in list(_static_cache.keys()): + if stale_key not in seen_keys: + _static_cache.pop(stale_key, None) except Exception as e: logger.error("Failed to enumerate monitors: %s", e) diff --git a/media_server/services/foreground_service.py b/media_server/services/foreground_service.py index ee3b47e..d15a601 100644 --- a/media_server/services/foreground_service.py +++ b/media_server/services/foreground_service.py @@ -86,9 +86,29 @@ class _Cache: _cache = _Cache() +# Win32 handles + signatures are declared once at module load (when running on +# Windows). The TTL cache fires this hundreds of times per minute; redoing the +# DLL load + ~10 argtype assignments per call was the largest chunk of probe +# cost. Keep these guarded behind a lazy init so non-Windows platforms don't +# pay the import. +_WIN32_INITIALIZED = False +_win32_user32 = None +_win32_kernel32 = None +_win32_psapi = None + + +def _init_win32_apis() -> None: + """Declare ctypes argtypes/restype on every Win32 call we make. + + CRITICAL: ctypes defaults to `c_int` (32-bit) for HANDLE/HWND/HMONITOR + which silently truncates 64-bit pointer values on x64 — that corrupts the + handle so `CloseHandle()` can either fail or close the wrong kernel + object, and pointer-equality comparisons (monitor index lookup) miss. + """ + global _WIN32_INITIALIZED, _win32_user32, _win32_kernel32, _win32_psapi + if _WIN32_INITIALIZED: + return -def _probe_windows() -> ForegroundInfo: - """Probe foreground window state on Windows via Win32 API.""" import ctypes import ctypes.wintypes as wt @@ -96,11 +116,6 @@ def _probe_windows() -> ForegroundInfo: kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) psapi = ctypes.WinDLL("psapi", use_last_error=True) - # CRITICAL: declare argtypes/restype on every Win32 call that returns a - # HANDLE/HWND/HMONITOR. 
ctypes defaults to `c_int` (32-bit) which - # silently truncates 64-bit pointer values on x64 — that corrupts the - # handle so `CloseHandle()` can either fail or close the wrong kernel - # object, and pointer-equality comparisons (monitor index lookup) miss. user32.GetForegroundWindow.restype = wt.HWND user32.GetWindowThreadProcessId.argtypes = [wt.HWND, ctypes.POINTER(wt.DWORD)] user32.GetWindowThreadProcessId.restype = wt.DWORD @@ -137,6 +152,20 @@ def _probe_windows() -> ForegroundInfo: psapi.GetModuleFileNameExW.argtypes = [wt.HANDLE, wt.HMODULE, wt.LPWSTR, wt.DWORD] psapi.GetModuleFileNameExW.restype = wt.DWORD + _win32_user32, _win32_kernel32, _win32_psapi = user32, kernel32, psapi + _WIN32_INITIALIZED = True + + +def _probe_windows() -> ForegroundInfo: + """Probe foreground window state on Windows via Win32 API.""" + import ctypes + import ctypes.wintypes as wt + + _init_win32_apis() + user32 = _win32_user32 + kernel32 = _win32_kernel32 + psapi = _win32_psapi + hwnd = user32.GetForegroundWindow() if not hwnd: return ForegroundInfo(available=True, error="no foreground window") diff --git a/media_server/services/gitea_release_provider.py b/media_server/services/gitea_release_provider.py index 0ea71ba..f9dcab9 100644 --- a/media_server/services/gitea_release_provider.py +++ b/media_server/services/gitea_release_provider.py @@ -2,6 +2,7 @@ import json import logging +import re import urllib.error import urllib.request from typing import Optional @@ -15,6 +16,11 @@ _DEFAULT_BASE_URL = "https://git.dolgolyov-family.by" _DEFAULT_OWNER = "alexei.dolgolyov" _DEFAULT_REPO = "media-player-server" +# Restrictive tag whitelist — prevents a hostile Gitea response (or MITM) from +# injecting `..`, slashes, or URL-altering characters into the release URL we +# broadcast to clients. SemVer + pre-release suffix only. 
+_TAG_RE = re.compile(r"^v?\d+\.\d+\.\d+(?:[\w.\-+]{0,32})?$") + class GiteaReleaseProvider(ReleaseProvider): """Fetches the latest release from a Gitea repository.""" @@ -53,6 +59,9 @@ class GiteaReleaseProvider(ReleaseProvider): continue tag = release.get("tag_name", "") + if not isinstance(tag, str) or not _TAG_RE.match(tag): + logger.warning("Rejecting malformed release tag from upstream: %r", tag) + continue version = tag.lstrip("v") if not version: continue diff --git a/media_server/services/macos_media.py b/media_server/services/macos_media.py index 2afdf26..86b1642 100644 --- a/media_server/services/macos_media.py +++ b/media_server/services/macos_media.py @@ -264,8 +264,12 @@ class MacOSMediaController(MediaController): async def set_volume(self, volume: int) -> bool: """Set system volume.""" - result = self._run_osascript(f"set volume output volume {volume}") - return result is not None or True # osascript returns empty on success + # osascript returns empty string on success and None on failure (the + # _run_osascript helper catches subprocess errors). The previous + # `result is not None or True` always returned True regardless of + # outcome — surface real failures so the route can return 503. + result = self._run_osascript(f"set volume output volume {int(volume)}") + return result is not None async def toggle_mute(self) -> bool: """Toggle mute state.""" diff --git a/media_server/services/rate_limit.py b/media_server/services/rate_limit.py new file mode 100644 index 0000000..d3fdce0 --- /dev/null +++ b/media_server/services/rate_limit.py @@ -0,0 +1,95 @@ +"""In-process token-bucket rate limiter. + +Light enough for a single-process app: one dict keyed by ``(bucket, peer)`` +guarded by a thread lock. No extra dependency, no Redis. Good enough for +defeating credential-stuffing and runaway clients on a LAN; not a substitute +for an upstream WAF in a public deployment. 
+ +Buckets: + auth — failed-auth attempts, 5/min/peer (used in auth middleware) + execute — script + callback execute calls, 10/min/peer (LAN-friendly) + default — generic POST/DELETE writes, 60/min/peer +""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass +from typing import Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class BucketConfig: + capacity: float # max tokens (= burst size) + refill_per_sec: float # tokens added per second + + +# Defaults — tuned for "trusted LAN" use; operator can override via Settings. +BUCKETS: dict[str, BucketConfig] = { + "auth": BucketConfig(capacity=5, refill_per_sec=5 / 60), # 5/min + "execute": BucketConfig(capacity=10, refill_per_sec=10 / 60), # 10/min + "default": BucketConfig(capacity=60, refill_per_sec=60 / 60), # 60/min +} + + +_state: dict[tuple[str, str], tuple[float, float]] = {} +_lock = threading.Lock() +_LAST_CLEANUP = 0.0 + + +def _evict_stale_locked(now: float) -> None: + """Drop entries whose buckets are full (= idle for capacity / refill seconds).""" + global _LAST_CLEANUP + if now - _LAST_CLEANUP < 60: + return + _LAST_CLEANUP = now + stale = [] + for key, (tokens, last) in _state.items(): + bucket = BUCKETS.get(key[0]) + if bucket is None: + continue + if tokens >= bucket.capacity and (now - last) > 3600: + stale.append(key) + for key in stale: + _state.pop(key, None) + + +def check(bucket: str, peer: str) -> tuple[bool, Optional[float]]: + """Try to consume one token from ``(bucket, peer)``. + + Returns: + (allowed, retry_after_seconds). When allowed=True retry_after is None. + When allowed=False, retry_after is the seconds to wait for one more token. 
+ """ + cfg = BUCKETS.get(bucket) or BUCKETS["default"] + now = time.monotonic() + with _lock: + _evict_stale_locked(now) + tokens, last = _state.get((bucket, peer), (cfg.capacity, now)) + elapsed = max(0.0, now - last) + tokens = min(cfg.capacity, tokens + elapsed * cfg.refill_per_sec) + if tokens >= 1: + tokens -= 1 + _state[(bucket, peer)] = (tokens, now) + return True, None + deficit = 1 - tokens + retry = deficit / cfg.refill_per_sec if cfg.refill_per_sec > 0 else 60 + _state[(bucket, peer)] = (tokens, now) + return False, retry + + +def get_peer(request) -> str: + """Best-effort peer identifier from a Starlette request. + + Honors X-Forwarded-For (only when settings.proxy_headers is True, which is + already enforced by uvicorn's middleware) so a reverse-proxied install + still rate-limits per real client. + """ + client = getattr(request, "client", None) + if client and client.host: + return client.host + return "unknown" diff --git a/media_server/services/thumbnail_service.py b/media_server/services/thumbnail_service.py index 5b40235..becf3ff 100644 --- a/media_server/services/thumbnail_service.py +++ b/media_server/services/thumbnail_service.py @@ -26,12 +26,23 @@ class ThumbnailService: def get_cache_dir() -> Path: """Get the thumbnail cache directory path. - Returns: - Path to the cache directory (project-local). + Returns user-writable platform cache dir so installs under + ``%PROGRAMFILES%`` / ``/opt`` work without elevated permissions. + Mirrors the platform branching of ``config.get_config_dir``. """ - # Store cache in project directory: media-server/.cache/thumbnails/ - project_root = Path(__file__).parent.parent.parent - cache_dir = project_root / ".cache" / "thumbnails" + import os + + if os.name == "nt": + # %LOCALAPPDATA% so the cache survives roaming-profile sync. 
+ base = Path(os.environ.get("LOCALAPPDATA") + or os.environ.get("APPDATA") + or Path.home() / "AppData" / "Local") + cache_dir = base / "media-server" / "cache" / "thumbnails" + else: + # XDG_CACHE_HOME convention; falls back to ~/.cache. + xdg = os.environ.get("XDG_CACHE_HOME") + base = Path(xdg) if xdg else Path.home() / ".cache" + cache_dir = base / "media-server" / "thumbnails" cache_dir.mkdir(parents=True, exist_ok=True) return cache_dir diff --git a/media_server/services/websocket_manager.py b/media_server/services/websocket_manager.py index f0a6917..bed4050 100644 --- a/media_server/services/websocket_manager.py +++ b/media_server/services/websocket_manager.py @@ -33,9 +33,15 @@ class ConnectionManager: self._audio_task: asyncio.Task | None = None self._audio_analyzer = None - async def connect(self, websocket: WebSocket) -> None: - """Accept a new WebSocket connection.""" - await websocket.accept() + async def connect(self, websocket: WebSocket, already_accepted: bool = False) -> None: + """Accept a new WebSocket connection. + + ``already_accepted=True`` is for callers that needed to call + ``websocket.accept(subprotocol=...)`` themselves (token-via-subprotocol + auth path). + """ + if not already_accepted: + await websocket.accept() async with self._lock: self._active_connections.add(websocket) logger.info( diff --git a/media_server/services/windows_media.py b/media_server/services/windows_media.py index a073fc6..973222c 100644 --- a/media_server/services/windows_media.py +++ b/media_server/services/windows_media.py @@ -31,8 +31,15 @@ def _thread_loop() -> asyncio.AbstractEventLoop: _thread_local.loop = loop return loop -# Global storage for current album art (as bytes) +# Global storage for current album art (as bytes). Guarded by _art_lock so the +# WinRT polling thread and the FastAPI handler thread don't race on swap. 
_current_album_art_bytes: bytes | None = None +_art_lock = threading.Lock() + +# Identity of the track whose art is currently in _current_album_art_bytes. +# Used to gate the expensive WinRT thumbnail.open_read_async() so the bytes +# aren't re-decoded on every 500ms status poll. +_current_album_art_key: tuple | None = None # Lock protecting _position_cache and _track_skip_pending from concurrent access _position_lock = threading.Lock() @@ -56,8 +63,9 @@ _track_skip_pending = { def get_current_album_art() -> bytes | None: - """Get the current album art bytes.""" - return _current_album_art_bytes + """Get the current album art bytes (thread-safe snapshot).""" + with _art_lock: + return _current_album_art_bytes # Windows-specific imports try: @@ -379,28 +387,48 @@ def _sync_get_media_status() -> dict[str, Any]: except Exception as e: logger.debug(f"Timeline parse error: {e}") - # Try to get album art (requires media_props) + # Try to get album art (requires media_props). Gated by track key so + # the WinRT IPC + bytes copy only runs when the track actually + # changes; otherwise we just preserve the existing cached bytes. if media_props: - try: - thumbnail = media_props.thumbnail - if thumbnail: - stream = loop.run_until_complete(thumbnail.open_read_async()) - if stream: - size = stream.size - if size > 0 and size < 10 * 1024 * 1024: # Max 10MB - from winsdk.windows.storage.streams import DataReader - reader = DataReader(stream) - loop.run_until_complete(reader.load_async(size)) - buffer = bytearray(size) - reader.read_bytes(buffer) - reader.close() - stream.close() + track_key = ( + getattr(media_props, "title", "") or "", + getattr(media_props, "artist", "") or "", + getattr(media_props, "album_title", "") or "", + ) + global _current_album_art_bytes, _current_album_art_key + if track_key == _current_album_art_key and _current_album_art_bytes: + # Same track — reuse cached art bytes without touching WinRT. 
+ result["album_art_url"] = "/api/media/artwork" + else: + try: + thumbnail = media_props.thumbnail + if thumbnail: + stream = loop.run_until_complete(thumbnail.open_read_async()) + if stream: + size = stream.size + if size > 0 and size < 10 * 1024 * 1024: # Max 10MB + from winsdk.windows.storage.streams import DataReader + reader = DataReader(stream) + loop.run_until_complete(reader.load_async(size)) + buffer = bytearray(size) + reader.read_bytes(buffer) + reader.close() + stream.close() - global _current_album_art_bytes - _current_album_art_bytes = bytes(buffer) - result["album_art_url"] = "/api/media/artwork" - except Exception as e: - logger.debug(f"Failed to get album art: {e}") + with _art_lock: + _current_album_art_bytes = bytes(buffer) + _current_album_art_key = track_key + result["album_art_url"] = "/api/media/artwork" + else: + # No thumbnail on this track — drop stale bytes so + # the ETag flips and clients don't keep showing the + # previous album's cover. + with _art_lock: + _current_album_art_bytes = None + _current_album_art_key = track_key + except Exception as e: + logger.debug(f"Failed to get album art: {e}") result["source"] = session.source_app_user_model_id diff --git a/media_server/static/index.html b/media_server/static/index.html index 90b4d86..437620e 100644 --- a/media_server/static/index.html +++ b/media_server/static/index.html @@ -26,16 +26,16 @@
- -
@@ -48,8 +48,8 @@
-
- +