9b9a2b5c9f
When `cors_origins` was unset, the WS endpoint only allowed `http://localhost:<port>` and `http://127.0.0.1:<port>` as origins, so a browser opening the UI via the LAN IP (e.g. `http://192.168.2.100:8765` when bound to `0.0.0.0`) had its WebSocket closed with code 4003 and never recovered — leaving the Web UI in a permanent reconnect loop. Also accept any `Origin` whose authority matches the request's `Host` header (both `http://` and `https://` schemes). Same-origin is by definition not CSWSH, so the cross-origin defence added in v0.3.0 remains intact for genuine third-party LAN pages.
506 lines
17 KiB
Python
506 lines
17 KiB
Python
"""Media control API endpoints."""
|
|
|
|
import asyncio
|
|
import logging
|
|
|
|
from fastapi import (
|
|
APIRouter,
|
|
Depends,
|
|
HTTPException,
|
|
Query,
|
|
Request,
|
|
WebSocket,
|
|
WebSocketDisconnect,
|
|
status,
|
|
)
|
|
from fastapi.responses import Response
|
|
|
|
from ..auth import verify_token, verify_token_or_query
|
|
from ..config import settings
|
|
from ..models import MediaStatus, SeekRequest, VolumeRequest
|
|
from ..services import get_current_album_art, get_media_controller
|
|
from ..services.websocket_manager import ws_manager
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /api/media and grouped under
# the "media" tag in the generated OpenAPI schema.
router = APIRouter(
    tags=["media"],
    prefix="/api/media",
)

# In-flight callback tasks. The event loop only keeps weak references to the
# tasks it runs, so a fire-and-forget task could be garbage-collected before it
# finishes; holding it here (and discarding it when done) prevents that.
# Mirrors the pattern used in routes/browser.py.
_background_callback_tasks: set[asyncio.Task] = set()
|
|
|
|
|
|
def _run_callback(callback_name: str) -> None:
    """Schedule the configured callback *callback_name* as a background task.

    Does nothing when ``settings.callbacks`` has no entry with that name.
    Otherwise the callback's script is run on the dedicated callback executor.
    The outcome (or the exception that prevented it) is written to the audit
    log, and non-zero exits and errors are logged.

    The caller returns immediately after scheduling. Failures never propagate
    back to it, and are not left behind as un-retrieved task exceptions.
    Must be called while an event loop is running, because it uses
    ``asyncio.create_task``.
    """
    if not settings.callbacks or callback_name not in settings.callbacks:
        return

    async def _execute() -> None:
        # Use the dedicated callback executor (not the default loop pool) so a
        # misbehaving callback can't starve the rest of the app's sync tasks.
        from ..services.audit_log import record_script_execution
        from .callbacks import _callback_executor
        from .scripts import _run_script

        def _audit(**fields) -> None:
            # Audit logging is best-effort. A failure here must neither be
            # mistaken for a callback failure (and audited a second time) nor
            # escape the task as an un-retrieved exception.
            try:
                record_script_execution(
                    kind="event-callback", name=callback_name, **fields
                )
            except Exception:
                logger.exception(
                    "Failed to record audit entry for callback %s", callback_name
                )

        try:
            callback = settings.callbacks[callback_name]
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                _callback_executor,
                lambda: _run_script(
                    command=callback.command,
                    timeout=callback.timeout,
                    shell=callback.shell,
                    working_dir=callback.working_dir,
                ),
            )
            exit_code = result["exit_code"]
        except Exception as e:
            logger.error("Callback %s error: %s", callback_name, e)
            _audit(exit_code=None, duration=None, error=str(e))
            return

        # The script actually ran: record exactly one audit entry for it.
        _audit(
            exit_code=exit_code,
            duration=result.get("execution_time"),
            stdout=result.get("stdout"),
            stderr=result.get("stderr"),
        )
        if exit_code != 0:
            logger.warning(
                "Callback %s failed with exit code %s: %s",
                callback_name,
                exit_code,
                result.get("stderr"),
            )

    task = asyncio.create_task(_execute())
    _background_callback_tasks.add(task)
    task.add_done_callback(_background_callback_tasks.discard)
|
|
|
|
|
|
@router.get("/status", response_model=MediaStatus)
async def get_media_status(_: str = Depends(verify_token)) -> MediaStatus:
    """Report the current media playback status.

    Returns:
        The media controller's ``MediaStatus`` snapshot (playback state,
        media info, volume, etc.).
    """
    snapshot = await get_media_controller().get_status()
    return snapshot
|
|
|
|
|
|
@router.post("/play")
async def play(_: str = Depends(verify_token)) -> dict:
    """Start playback, or resume it if paused.

    On success the ``on_play`` callback (if configured) is scheduled.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    started = await get_media_controller().play()
    if started:
        _run_callback("on_play")
        return {"success": True}
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Failed to start playback - no active media session",
    )
|
|
|
|
|
|
@router.post("/pause")
async def pause(_: str = Depends(verify_token)) -> dict:
    """Pause playback.

    On success the ``on_pause`` callback (if configured) is scheduled.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    paused = await get_media_controller().pause()
    if paused:
        _run_callback("on_pause")
        return {"success": True}
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Failed to pause - no active media session",
    )
|
|
|
|
|
|
@router.post("/stop")
async def stop(_: str = Depends(verify_token)) -> dict:
    """Stop playback.

    On success the ``on_stop`` callback (if configured) is scheduled.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    stopped = await get_media_controller().stop()
    if stopped:
        _run_callback("on_stop")
        return {"success": True}
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Failed to stop - no active media session",
    )
|
|
|
|
|
|
@router.post("/next")
async def next_track(_: str = Depends(verify_token)) -> dict:
    """Skip forward to the next track.

    On success the ``on_next`` callback (if configured) is scheduled.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    skipped = await get_media_controller().next_track()
    if skipped:
        _run_callback("on_next")
        return {"success": True}
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Failed to skip - no active media session",
    )
|
|
|
|
|
|
@router.post("/previous")
async def previous_track(_: str = Depends(verify_token)) -> dict:
    """Go back to the previous track.

    On success the ``on_previous`` callback (if configured) is scheduled.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    went_back = await get_media_controller().previous_track()
    if went_back:
        _run_callback("on_previous")
        return {"success": True}
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Failed to go back - no active media session",
    )
|
|
|
|
|
|
@router.post("/volume")
async def set_volume(
    request: VolumeRequest, _: str = Depends(verify_token)
) -> dict:
    """Set the system volume.

    On success the ``on_volume`` callback (if configured) is scheduled.

    Args:
        request: Body carrying the target volume level (0-100).

    Returns:
        ``{"success": True, "volume": <level>}`` echoing the requested level.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    level = request.volume
    applied = await get_media_controller().set_volume(level)
    if not applied:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Failed to set volume",
        )
    _run_callback("on_volume")
    return {"success": True, "volume": level}
|
|
|
|
|
|
@router.post("/mute")
async def toggle_mute(_: str = Depends(verify_token)) -> dict:
    """Flip the mute state and schedule the ``on_mute`` callback (if configured).

    Returns:
        ``{"success": True, "muted": <value>}``, where ``<value>`` is whatever
        the controller's ``toggle_mute()`` returned.
    """
    muted_now = await get_media_controller().toggle_mute()
    _run_callback("on_mute")
    return {"success": True, "muted": muted_now}
|
|
|
|
|
|
@router.post("/seek")
async def seek(request: SeekRequest, _: str = Depends(verify_token)) -> dict:
    """Jump to a position in the current track.

    On success the ``on_seek`` callback (if configured) is scheduled.

    Args:
        request: Body carrying the target position in seconds.

    Returns:
        ``{"success": True, "position": <position>}`` echoing the request.

    Raises:
        HTTPException: 503 when the controller reports failure.
    """
    target = request.position
    moved = await get_media_controller().seek(target)
    if not moved:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Failed to seek - no active media session or seek not supported",
        )
    _run_callback("on_seek")
    return {"success": True, "position": target}
|
|
|
|
|
|
@router.post("/turn_on")
async def turn_on(_: str = Depends(verify_token)) -> dict:
    """Schedule the ``on_turn_on`` callback, if one is configured.

    The callback runs as a background task, so the response neither waits for
    it nor reflects its outcome.

    Returns:
        Always ``{"success": True}``.
    """
    _run_callback("on_turn_on")
    return dict(success=True)
|
|
|
|
|
|
@router.post("/turn_off")
async def turn_off(_: str = Depends(verify_token)) -> dict:
    """Schedule the ``on_turn_off`` callback, if one is configured.

    The callback runs as a background task, so the response neither waits for
    it nor reflects its outcome.

    Returns:
        Always ``{"success": True}``.
    """
    _run_callback("on_turn_off")
    return dict(success=True)
|
|
|
|
|
|
@router.post("/toggle")
async def toggle(_: str = Depends(verify_token)) -> dict:
    """Schedule the ``on_toggle`` callback, if one is configured.

    The callback runs as a background task, so the response neither waits for
    it nor reflects its outcome.

    Returns:
        Always ``{"success": True}``.
    """
    _run_callback("on_toggle")
    return dict(success=True)
|
|
|
|
|
|
def _sniff_image_type(data: bytes) -> str:
    """Return the MIME type implied by *data*'s leading magic bytes.

    Recognizes JPEG, PNG, WebP, GIF and BMP. Anything else, including empty
    input, is reported as ``application/octet-stream``.
    """
    if data[:3] == b"\xff\xd8\xff":
        return "image/jpeg"
    if data[:8] == b"\x89PNG\r\n\x1a\n":
        return "image/png"
    # WebP is a RIFF container whose form type (bytes 8-11) is "WEBP".
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "image/webp"
    if data[:6] in (b"GIF87a", b"GIF89a"):
        return "image/gif"
    if data[:2] == b"BM":
        return "image/bmp"
    return "application/octet-stream"


def _if_none_match_hits(header_value: str | None, etag: str) -> bool:
    """Return True if an ``If-None-Match`` header value matches *etag*.

    The header may be ``*`` (matches any existing representation) or a
    comma-separated list of entity tags. ``If-None-Match`` uses weak
    comparison, so a ``W/`` prefix on either side is ignored.
    """
    if not header_value:
        return False
    if header_value.strip() == "*":
        return True
    wanted = etag.removeprefix("W/")
    return any(
        candidate.strip().removeprefix("W/") == wanted
        for candidate in header_value.split(",")
    )


@router.get("/artwork")
async def get_artwork(
    request: Request,
    _: str = Depends(verify_token_or_query),
) -> Response:
    """Get the current album artwork.

    Returns the bytes with a content-derived ETag so the browser can
    revalidate and receive a 304 when the same track is re-requested.

    Raises:
        HTTPException: 404 when no artwork is available.
    """
    import hashlib

    art_bytes = get_current_album_art()
    if art_bytes is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No album artwork available",
        )

    # Content-derived ETag (blake2b-128 — non-crypto cache key, ruff S324-safe)
    etag = '"' + hashlib.blake2b(art_bytes, digest_size=16).hexdigest() + '"'
    # A 304 must carry the same caching headers a 200 would (RFC 9110
    # §15.4.5), so both responses share this dict.
    cache_headers = {
        "ETag": etag,
        "Cache-Control": "private, max-age=0, must-revalidate",
    }

    if _if_none_match_hits(request.headers.get("if-none-match"), etag):
        return Response(
            status_code=status.HTTP_304_NOT_MODIFIED, headers=cache_headers
        )

    return Response(
        content=art_bytes,
        media_type=_sniff_image_type(art_bytes),
        headers=cache_headers,
    )
|
|
|
|
|
|
@router.get("/visualizer/status")
async def visualizer_status(_: str = Depends(verify_token)) -> dict:
    """Check if audio visualizer is available and running."""
    from ..services.audio_analyzer import get_audio_analyzer

    analyzer = get_audio_analyzer()
    # Read the attributes in the same order as the response keys.
    available = analyzer.available
    running = analyzer.running
    device = analyzer.current_device
    return dict(available=available, running=running, current_device=device)
|
|
|
|
|
|
@router.get("/visualizer/devices")
async def visualizer_devices(_: str = Depends(verify_token)) -> list[dict[str, str]]:
    """List available loopback audio devices for the visualizer."""
    from ..services.audio_analyzer import AudioAnalyzer

    # Enumerate devices on the loop's default thread-pool executor rather than
    # on the event-loop thread itself.
    devices = await asyncio.get_running_loop().run_in_executor(
        None, AudioAnalyzer.list_loopback_devices
    )
    return devices
|
|
|
|
|
|
@router.post("/visualizer/device")
async def set_visualizer_device(
    request: dict,
    _: str = Depends(verify_token),
) -> dict:
    """Set the loopback audio device for the visualizer.

    Body: {"device_name": "Device Name" | null}
    Passing null (or omitting the key) resets to auto-detect.

    On success the selection is persisted to config.yaml via ``config_manager``.

    Returns:
        ``success`` as reported by the analyzer, plus the analyzer's resulting
        ``current_device`` and ``running`` state.

    Raises:
        HTTPException: 400 when ``device_name`` is neither a string nor null.
    """
    from ..services.audio_analyzer import get_audio_analyzer

    device_name = request.get("device_name")
    # The body is arbitrary client JSON: reject numbers/lists/objects before
    # they reach the analyzer or get persisted to the config file.
    if device_name is not None and not isinstance(device_name, str):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="device_name must be a string or null",
        )

    analyzer = get_audio_analyzer()

    # set_device() handles stop/start internally if capture was running
    success = analyzer.set_device(device_name)

    # Persist selection to config.yaml so it survives server restarts
    if success:
        from ..config_manager import config_manager

        config_manager.set_setting("visualizer_device", device_name)

    return {
        "success": success,
        "current_device": analyzer.current_device,
        "running": analyzer.running,
    }
|
|
|
|
|
|
_WS_TOKEN_SUBPROTOCOL_PREFIX = "media-server.token."


def _token_from_subprotocols(header_value: str) -> tuple[str | None, str | None]:
    """Find the auth-token entry in a ``Sec-WebSocket-Protocol`` header value.

    Returns:
        ``(token, subprotocol)`` for the first comma-separated entry that
        starts with ``media-server.token.``, or ``(None, None)`` if there is
        none. ``subprotocol`` is the full entry. It must be echoed back on
        accept so browsers consider the negotiation successful.
    """
    for proto in (p.strip() for p in header_value.split(",")):
        if proto.startswith(_WS_TOKEN_SUBPROTOCOL_PREFIX):
            return proto[len(_WS_TOKEN_SUBPROTOCOL_PREFIX):], proto
    return None, None


def _ws_origin_allowed(origin: str | None, host_header: str) -> bool:
    """Return whether a WebSocket handshake's ``Origin`` passes the CSWSH check.

    The origin is accepted when any of these holds:

    * no ``Origin`` header was sent (native clients may omit it);
    * it is listed in ``settings.cors_origins``. When that setting is unset,
      the list defaults to ``http://localhost:<port>`` and
      ``http://127.0.0.1:<port>``;
    * it equals the request's own ``Host`` under ``http`` or ``https``
      (same-origin). CSWSH is by definition a cross-origin attack, and this
      rule is what makes LAN-IP access work when bound to 0.0.0.0. Hostnames
      are case-insensitive, so this comparison ignores case.

    NOTE: the same-origin rule trusts ``Host``, so it does not stop DNS
    rebinding, where the attacker's hostname resolves to this server. There,
    only token authentication (when enabled) stands in the way.
    """
    if origin is None:
        return True
    allowed_origins = set(
        settings.cors_origins
        or [
            f"http://localhost:{settings.port}",
            f"http://127.0.0.1:{settings.port}",
        ]
    )
    if origin in allowed_origins:
        return True
    if not host_header:
        return False
    host = host_header.lower()
    return origin.lower() in {f"http://{host}", f"https://{host}"}


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    token: str | None = Query(None, description="API authentication token (legacy)"),
) -> None:
    """WebSocket endpoint for real-time media status updates.

    Authentication is accepted from two sources, in priority order:
      1. ``Sec-WebSocket-Protocol`` subprotocol of the form
         ``media-server.token.<TOKEN>``. This is the preferred path because
         the token never lands in the URL, request logs, or browser history.
         The browser WebSocket API supports custom subprotocols natively.
      2. ``?token=<TOKEN>`` query parameter (legacy, kept for back-compat
         with older clients and the HA integration).

    Before accepting, the connection is closed with code 4003 if its Origin
    fails the cross-origin check, or 4001 if auth is enabled and the token is
    invalid.

    Messages sent to client:
    - {"type": "status", "data": {...}} - Initial status on connect
    - {"type": "status_update", "data": {...}} - Status changes
    - {"type": "error", "message": "..."} - Error messages

    Client can send:
    - {"type": "ping"} - Keepalive, server responds with {"type": "pong"}
    - {"type": "get_status"} - Request current status
    - {"type": "volume", "volume": <0-100>} - Set volume (clamped)
    - {"type": "enable_visualizer"} / {"type": "disable_visualizer"}

    Malformed messages are ignored rather than closing the socket.
    """
    import json

    subprotocol_token, accept_subprotocol = _token_from_subprotocols(
        websocket.headers.get("sec-websocket-protocol", "")
    )
    effective_token = subprotocol_token or token

    # Origin check — block CSWSH from third-party LAN pages.
    if not _ws_origin_allowed(
        websocket.headers.get("origin"), websocket.headers.get("host", "")
    ):
        await websocket.close(code=4003, reason="Origin not allowed")
        return

    # Verify token
    from ..auth import auth_enabled, get_token_label, token_label_var

    if auth_enabled():
        label = get_token_label(effective_token) if effective_token else None
        if label is None:
            await websocket.close(code=4001, reason="Invalid authentication token")
            return
        token_label_var.set(label)
    else:
        token_label_var.set("anonymous")

    # Accept with the negotiated subprotocol if one was used. Starlette's
    # connect() calls accept() with no subprotocol — we need to accept first
    # explicitly to echo the subprotocol back, then hand off to the manager.
    if accept_subprotocol is not None:
        await websocket.accept(subprotocol=accept_subprotocol)
        await ws_manager.connect(websocket, already_accepted=True)
    else:
        await ws_manager.connect(websocket)

    try:
        while True:
            # Wait for messages from client (for keepalive/ping). Decode JSON
            # ourselves so one malformed frame doesn't tear down the socket.
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except ValueError:
                continue

            msg_type = data.get("type") if isinstance(data, dict) else None

            if msg_type == "ping":
                await websocket.send_json({"type": "pong"})
            elif msg_type == "get_status":
                # Allow manual status request
                controller = get_media_controller()
                status_data = await controller.get_status()
                await websocket.send_json({
                    "type": "status",
                    "data": status_data.model_dump(),
                })
            elif msg_type == "volume":
                # Low-latency volume control via WebSocket. Coerce, clamp, and
                # never drop the socket on a single bad message — that would
                # turn the WS into a one-shot DoS for any holder of a token.
                # OverflowError covers ±inf, which json parses from e.g. 1e400.
                try:
                    volume = int(data.get("volume"))
                except (TypeError, ValueError, OverflowError):
                    continue
                volume = max(0, min(100, volume))
                controller = get_media_controller()
                await controller.set_volume(volume)
            elif msg_type == "enable_visualizer":
                await ws_manager.subscribe_visualizer(websocket)
            elif msg_type == "disable_visualizer":
                await ws_manager.unsubscribe_visualizer(websocket)

    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error("WebSocket error: %s", e)
    finally:
        # Also runs on task cancellation (CancelledError is a BaseException and
        # bypasses the handlers above), so the manager never keeps a dead socket.
        await ws_manager.disconnect(websocket)
|