Add WebSocket support for real-time media status updates

Replace HTTP polling with WebSocket push notifications for instant
state change responses. Server broadcasts updates only when significant
changes occur (state, track, volume, etc.) while letting Home Assistant
interpolate position during playback. Includes seek detection for
timeline updates and automatic fallback to HTTP polling if WebSocket
disconnects.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-04 14:02:53 +03:00
parent 67a89e8349
commit 5519e449cd
8 changed files with 521 additions and 8 deletions

View File

@@ -11,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
from .config import settings, generate_default_config, get_config_dir
from .routes import health_router, media_router, scripts_router
from .services import get_media_controller
from .services.websocket_manager import ws_manager
def setup_logging():
@@ -29,7 +31,16 @@ async def lifespan(app: FastAPI):
logger = logging.getLogger(__name__)
logger.info(f"Media Server starting on {settings.host}:{settings.port}")
logger.info(f"API Token: {settings.api_token[:8]}...")
# Start WebSocket status monitor
controller = get_media_controller()
await ws_manager.start_status_monitor(controller.get_status)
logger.info("WebSocket status monitor started")
yield
# Stop WebSocket status monitor
await ws_manager.stop_status_monitor()
logger.info("Media Server shutting down")

View File

@@ -1,11 +1,18 @@
"""Media control API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, status
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi import status
from fastapi.responses import Response
from ..auth import verify_token, verify_token_or_query
from ..config import settings
from ..models import MediaStatus, VolumeRequest, SeekRequest
from ..services import get_media_controller, get_current_album_art
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["media"])
@@ -184,3 +191,52 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response:
content_type = "image/webp"
return Response(content=art_bytes, media_type=content_type)
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(..., description="API authentication token"),
) -> None:
"""WebSocket endpoint for real-time media status updates.
Authentication is done via query parameter since WebSocket
doesn't support custom headers in the browser.
Messages sent to client:
- {"type": "status", "data": {...}} - Initial status on connect
- {"type": "status_update", "data": {...}} - Status changes
- {"type": "error", "message": "..."} - Error messages
Client can send:
- {"type": "ping"} - Keepalive, server responds with {"type": "pong"}
- {"type": "get_status"} - Request current status
"""
# Verify token
if token != settings.api_token:
await websocket.close(code=4001, reason="Invalid authentication token")
return
await ws_manager.connect(websocket)
try:
while True:
# Wait for messages from client (for keepalive/ping)
data = await websocket.receive_json()
if data.get("type") == "ping":
await websocket.send_json({"type": "pong"})
elif data.get("type") == "get_status":
# Allow manual status request
controller = get_media_controller()
status_data = await controller.get_status()
await websocket.send_json({
"type": "status",
"data": status_data.model_dump(),
})
except WebSocketDisconnect:
await ws_manager.disconnect(websocket)
except Exception as e:
logger.error("WebSocket error: %s", e)
await ws_manager.disconnect(websocket)

View File

@@ -0,0 +1,189 @@
"""WebSocket connection manager and status broadcaster."""
import asyncio
import logging
import time
from typing import Any, Callable, Coroutine
from fastapi import WebSocket
logger = logging.getLogger(__name__)
class ConnectionManager:
"""Manages WebSocket connections and broadcasts status updates."""
def __init__(self) -> None:
"""Initialize the connection manager."""
self._active_connections: set[WebSocket] = set()
self._lock = asyncio.Lock()
self._last_status: dict[str, Any] | None = None
self._broadcast_task: asyncio.Task | None = None
self._poll_interval: float = 0.5 # Internal poll interval for change detection
self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback
self._last_broadcast_time: float = 0.0
self._running: bool = False
async def connect(self, websocket: WebSocket) -> None:
"""Accept a new WebSocket connection."""
await websocket.accept()
async with self._lock:
self._active_connections.add(websocket)
logger.info(
"WebSocket client connected. Total: %d", len(self._active_connections)
)
# Send current status immediately upon connection
if self._last_status:
try:
await websocket.send_json({"type": "status", "data": self._last_status})
except Exception as e:
logger.debug("Failed to send initial status: %s", e)
async def disconnect(self, websocket: WebSocket) -> None:
"""Remove a WebSocket connection."""
async with self._lock:
self._active_connections.discard(websocket)
logger.info(
"WebSocket client disconnected. Total: %d", len(self._active_connections)
)
async def broadcast(self, message: dict[str, Any]) -> None:
"""Broadcast a message to all connected clients."""
async with self._lock:
connections = list(self._active_connections)
if not connections:
return
disconnected = []
for websocket in connections:
try:
await websocket.send_json(message)
except Exception as e:
logger.debug("Failed to send to client: %s", e)
disconnected.append(websocket)
# Clean up disconnected clients
for ws in disconnected:
await self.disconnect(ws)
def status_changed(
self, old: dict[str, Any] | None, new: dict[str, Any]
) -> bool:
"""Detect if media status has meaningfully changed.
Position is NOT included for normal playback (let HA interpolate).
But seeks (large unexpected jumps) are detected.
"""
if old is None:
return True
# Fields to compare for changes (NO position - let HA interpolate)
significant_fields = [
"state",
"title",
"artist",
"album",
"volume",
"muted",
"duration",
"source",
"album_art_url",
]
for field in significant_fields:
if old.get(field) != new.get(field):
return True
# Detect seeks - large position jumps that aren't normal playback
old_pos = old.get("position") or 0
new_pos = new.get("position") or 0
pos_diff = new_pos - old_pos
# During playback, position should increase by ~0.5s (our poll interval)
# A seek is when position jumps backwards OR forward by more than expected
if new.get("state") == "playing":
# Backward seek or forward jump > 3s indicates seek
if pos_diff < -1.0 or pos_diff > 3.0:
return True
else:
# When paused, any significant position change is a seek
if abs(pos_diff) > 1.0:
return True
return False
async def start_status_monitor(
self,
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
"""Start the background status monitoring loop."""
if self._running:
return
self._running = True
self._broadcast_task = asyncio.create_task(
self._status_monitor_loop(get_status_func)
)
logger.info("WebSocket status monitor started")
async def stop_status_monitor(self) -> None:
"""Stop the background status monitoring loop."""
self._running = False
if self._broadcast_task:
self._broadcast_task.cancel()
try:
await self._broadcast_task
except asyncio.CancelledError:
pass
logger.info("WebSocket status monitor stopped")
async def _status_monitor_loop(
self,
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
"""Background loop that polls for status changes and broadcasts."""
while self._running:
try:
# Only poll if we have connected clients
async with self._lock:
has_clients = len(self._active_connections) > 0
if has_clients:
status = await get_status_func()
status_dict = status.model_dump()
# Only broadcast on actual state changes
# Let HA handle position interpolation during playback
if self.status_changed(self._last_status, status_dict):
self._last_status = status_dict
self._last_broadcast_time = time.time()
await self.broadcast(
{"type": "status_update", "data": status_dict}
)
logger.debug("Broadcast sent: status change")
else:
# Update cached status even without broadcast
self._last_status = status_dict
else:
# Still update cache for when clients connect
status = await get_status_func()
self._last_status = status.model_dump()
await asyncio.sleep(self._poll_interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in status monitor: %s", e)
await asyncio.sleep(self._poll_interval)
@property
def client_count(self) -> int:
"""Return the number of connected clients."""
return len(self._active_connections)
# Global instance
ws_manager = ConnectionManager()