From 5519e449cd8a722ba986c48c11e1cd00aeb7bd45 Mon Sep 17 00:00:00 2001 From: "alexei.dolgolyov" Date: Wed, 4 Feb 2026 14:02:53 +0300 Subject: [PATCH] Add WebSocket support for real-time media status updates Replace HTTP polling with WebSocket push notifications for instant state change responses. Server broadcasts updates only when significant changes occur (state, track, volume, etc.) while letting Home Assistant interpolate position during playback. Includes seek detection for timeline updates and automatic fallback to HTTP polling if WebSocket disconnects. Co-Authored-By: Claude Opus 4.5 --- .../remote_media_player/__init__.py | 6 + .../remote_media_player/api_client.py | 140 +++++++++++++ .../remote_media_player/const.py | 4 + .../remote_media_player/manifest.json | 2 +- .../remote_media_player/media_player.py | 119 ++++++++++- media_server/main.py | 11 + media_server/routes/media.py | 58 +++++- media_server/services/websocket_manager.py | 189 ++++++++++++++++++ 8 files changed, 521 insertions(+), 8 deletions(-) create mode 100644 media_server/services/websocket_manager.py diff --git a/custom_components/remote_media_player/__init__.py b/custom_components/remote_media_player/__init__.py index ab976cf..f71f3bb 100644 --- a/custom_components/remote_media_player/__init__.py +++ b/custom_components/remote_media_player/__init__.py @@ -138,6 +138,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if unload_ok: # Close client and remove data data = hass.data[DOMAIN].pop(entry.entry_id) + + # Shutdown coordinator (WebSocket cleanup) + if "coordinator" in data: + await data["coordinator"].async_shutdown() + + # Close HTTP client await data["client"].close() # Remove services if this was the last entry diff --git a/custom_components/remote_media_player/api_client.py b/custom_components/remote_media_player/api_client.py index 812ed0e..7dbcfc0 100644 --- a/custom_components/remote_media_player/api_client.py +++ b/custom_components/remote_media_player/api_client.py @@ -2,7 +2,10 @@ from __future__ import annotations +import asyncio +import hashlib import logging +from collections.abc import Callable from typing import Any import aiohttp @@ -265,3 +268,140 @@ class MediaServerClient: endpoint = f"{API_SCRIPTS_EXECUTE}/{script_name}" json_data = {"args": args or []} return await self._request("POST", endpoint, json_data) + + +class MediaServerWebSocket: + """WebSocket client for real-time media status updates.""" + + def __init__( + self, + host: str, + port: int, + token: str, + on_status_update: Callable[[dict[str, Any]], None], + on_disconnect: Callable[[], None] | None = None, + ) -> None: + """Initialize the WebSocket client. + + Args: + host: Server hostname or IP + port: Server port + token: API authentication token + on_status_update: Callback when status update received + on_disconnect: Callback when connection lost + """ + self._host = host + self._port = int(port) + self._token = token + self._on_status_update = on_status_update + self._on_disconnect = on_disconnect + self._ws_url = f"ws://{host}:{self._port}/api/media/ws?token={token}" + self._session: aiohttp.ClientSession | None = None + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._receive_task: asyncio.Task | None = None + self._running = False + + async def connect(self) -> bool: + """Establish WebSocket connection. + + Returns: + True if connection successful + """ + try: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + + self._ws = await self._session.ws_connect( + self._ws_url, + heartbeat=30, + timeout=aiohttp.ClientTimeout(total=10), + ) + self._running = True + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + _LOGGER.info("WebSocket connected to %s:%s", self._host, self._port) + return True + + except Exception as err: + _LOGGER.warning("WebSocket connection failed: %s", err) + return False + + async def disconnect(self) -> None: + """Close WebSocket connection.""" + self._running = False + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + self._receive_task = None + + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + _LOGGER.debug("WebSocket disconnected") + + async def _receive_loop(self) -> None: + """Background loop to receive WebSocket messages.""" + while self._running and self._ws and not self._ws.closed: + try: + msg = await self._ws.receive(timeout=60) + + if msg.type == aiohttp.WSMsgType.TEXT: + data = msg.json() + msg_type = data.get("type") + + if msg_type in ("status", "status_update"): + status_data = data.get("data", {}) + # Convert album art URL to absolute + if ( + status_data.get("album_art_url") + and status_data["album_art_url"].startswith("/") + ): + track_id = f"{status_data.get('title', '')}-{status_data.get('artist', '')}" + track_hash = hashlib.md5(track_id.encode()).hexdigest()[:8] + status_data["album_art_url"] = ( + f"http://{self._host}:{self._port}" + f"{status_data['album_art_url']}?token={self._token}&t={track_hash}" + ) + self._on_status_update(status_data) + elif msg_type == "pong": + _LOGGER.debug("Received pong") + + elif msg.type == aiohttp.WSMsgType.CLOSED: + _LOGGER.warning("WebSocket closed by server") + break + elif msg.type == aiohttp.WSMsgType.ERROR: + _LOGGER.error("WebSocket error: %s", self._ws.exception()) + break + + except asyncio.TimeoutError: + # Send ping to keep connection alive + if self._ws and not self._ws.closed: + try: + await self._ws.send_json({"type": "ping"}) + except Exception: + break + except asyncio.CancelledError: + break + except Exception as err: + _LOGGER.error("WebSocket receive error: %s", err) + break + + # Connection lost, notify callback + if self._on_disconnect: + self._on_disconnect() + + @property + def is_connected(self) -> bool: + """Return True if WebSocket is connected.""" + return self._ws is not None and not self._ws.closed diff --git a/custom_components/remote_media_player/const.py b/custom_components/remote_media_player/const.py index c437701..c9dafd5 100644 --- a/custom_components/remote_media_player/const.py +++ b/custom_components/remote_media_player/const.py @@ -8,11 +8,14 @@ CONF_PORT = "port" CONF_TOKEN = "token" CONF_POLL_INTERVAL = "poll_interval" CONF_NAME = "name" +CONF_USE_WEBSOCKET = "use_websocket" # Default values DEFAULT_PORT = 8765 DEFAULT_POLL_INTERVAL = 5 DEFAULT_NAME = "Remote Media Player" +DEFAULT_USE_WEBSOCKET = True +DEFAULT_RECONNECT_INTERVAL = 30 # API endpoints API_HEALTH = "/api/health" @@ -27,6 +30,7 @@ API_MUTE = "/api/media/mute" API_SEEK = "/api/media/seek" API_SCRIPTS_LIST = "/api/scripts/list" API_SCRIPTS_EXECUTE = "/api/scripts/execute" +API_WEBSOCKET = "/api/media/ws" # Service names SERVICE_EXECUTE_SCRIPT = "execute_script" diff --git a/custom_components/remote_media_player/manifest.json b/custom_components/remote_media_player/manifest.json index edaab2b..be908fe 100644 --- a/custom_components/remote_media_player/manifest.json +++ b/custom_components/remote_media_player/manifest.json @@ -6,7 +6,7 @@ "dependencies": [], "documentation": "https://github.com/your-username/haos-integration-media-player", "integration_type": "device", - "iot_class": "local_polling", + "iot_class": "local_push", "requirements": ["aiohttp>=3.8.0"], "version": "1.0.0" } diff --git a/custom_components/remote_media_player/media_player.py b/custom_components/remote_media_player/media_player.py index 2a5385c..319afb7 100644 --- a/custom_components/remote_media_player/media_player.py +++ b/custom_components/remote_media_player/media_player.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from datetime import datetime, timedelta from typing import Any @@ -14,7 +15,7 @@ from homeassistant.components.media_player import ( ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import ( @@ -23,14 +24,18 @@ from homeassistant.helpers.update_coordinator import ( UpdateFailed, ) -from .api_client import MediaServerClient, MediaServerError +from .api_client import MediaServerClient, MediaServerError, MediaServerWebSocket from .const import ( DOMAIN, CONF_HOST, CONF_PORT, + CONF_TOKEN, CONF_POLL_INTERVAL, + CONF_USE_WEBSOCKET, DEFAULT_POLL_INTERVAL, DEFAULT_NAME, + DEFAULT_USE_WEBSOCKET, + DEFAULT_RECONNECT_INTERVAL, ) _LOGGER = logging.getLogger(__name__) @@ -62,13 +67,26 @@ async def async_setup_entry( entry.data.get(CONF_POLL_INTERVAL, DEFAULT_POLL_INTERVAL), ) - # Create update coordinator + # Get WebSocket setting from options or data + use_websocket = entry.options.get( + CONF_USE_WEBSOCKET, + entry.data.get(CONF_USE_WEBSOCKET, DEFAULT_USE_WEBSOCKET), + ) + + # Create update coordinator with WebSocket support coordinator = MediaPlayerCoordinator( hass, client, poll_interval, + host=entry.data[CONF_HOST], + port=entry.data[CONF_PORT], + token=entry.data[CONF_TOKEN], + use_websocket=use_websocket, ) + # Set up WebSocket connection if enabled + await coordinator.async_setup() + # Fetch initial data - don't fail setup if this fails try: await coordinator.async_config_entry_first_refresh() @@ -76,6 +94,9 @@ async def async_setup_entry( _LOGGER.warning("Initial data fetch failed, will retry: %s", err) # Continue anyway - the coordinator will retry + # Store coordinator for cleanup + hass.data[DOMAIN][entry.entry_id]["coordinator"] = coordinator + # Create and add entity entity = RemoteMediaPlayerEntity( coordinator, @@ -86,13 +107,17 @@ async def async_setup_entry( class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): - """Coordinator for fetching media player data.""" + """Coordinator for fetching media player data with WebSocket support.""" def __init__( self, hass: HomeAssistant, client: MediaServerClient, poll_interval: int, + host: str, + port: int, + token: str, + use_websocket: bool = True, ) -> None: """Initialize the coordinator. @@ -100,6 +125,10 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): hass: Home Assistant instance client: Media Server API client poll_interval: Update interval in seconds + host: Server hostname + port: Server port + token: API token + use_websocket: Whether to use WebSocket for updates """ super().__init__( hass, @@ -108,9 +137,76 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): update_interval=timedelta(seconds=poll_interval), ) self.client = client + self._host = host + self._port = port + self._token = token + self._use_websocket = use_websocket + self._ws_client: MediaServerWebSocket | None = None + self._ws_connected = False + self._reconnect_task: asyncio.Task | None = None + self._poll_interval = poll_interval + + async def async_setup(self) -> None: + """Set up the coordinator with WebSocket if enabled.""" + if self._use_websocket: + await self._connect_websocket() + + async def _connect_websocket(self) -> None: + """Establish WebSocket connection.""" + if self._ws_client: + await self._ws_client.disconnect() + + self._ws_client = MediaServerWebSocket( + host=self._host, + port=self._port, + token=self._token, + on_status_update=self._handle_ws_status_update, + on_disconnect=self._handle_ws_disconnect, + ) + + if await self._ws_client.connect(): + self._ws_connected = True + # Disable polling - WebSocket handles all updates including position + self.update_interval = None + _LOGGER.info("WebSocket connected, polling disabled") + else: + self._ws_connected = False + # Keep polling as fallback + self.update_interval = timedelta(seconds=self._poll_interval) + _LOGGER.warning("WebSocket failed, falling back to polling") + # Schedule reconnect attempt + self._schedule_reconnect() + + @callback + def _handle_ws_status_update(self, status_data: dict[str, Any]) -> None: + """Handle status update from WebSocket.""" + self.async_set_updated_data(status_data) + + @callback + def _handle_ws_disconnect(self) -> None: + """Handle WebSocket disconnection.""" + self._ws_connected = False + # Re-enable polling as fallback + self.update_interval = timedelta(seconds=self._poll_interval) + _LOGGER.warning("WebSocket disconnected, falling back to polling") + # Schedule reconnect attempt + self._schedule_reconnect() + + def _schedule_reconnect(self) -> None: + """Schedule a WebSocket reconnection attempt.""" + if self._reconnect_task and not self._reconnect_task.done(): + return # Already scheduled + + async def reconnect() -> None: + await asyncio.sleep(DEFAULT_RECONNECT_INTERVAL) + if self._use_websocket and not self._ws_connected: + _LOGGER.info("Attempting WebSocket reconnect...") + await self._connect_websocket() + + self._reconnect_task = self.hass.async_create_task(reconnect()) async def _async_update_data(self) -> dict[str, Any]: - """Fetch data from the API. + """Fetch data from the API (fallback when WebSocket unavailable). Returns: Media status data @@ -120,7 +216,7 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): """ try: data = await self.client.get_status() - _LOGGER.debug("Received media status: %s", data) + _LOGGER.debug("HTTP poll received status: %s", data.get("state")) return data except MediaServerError as err: raise UpdateFailed(f"Error communicating with server: {err}") from err @@ -128,6 +224,17 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): _LOGGER.exception("Unexpected error fetching media status") raise UpdateFailed(f"Unexpected error: {err}") from err + async def async_shutdown(self) -> None: + """Clean up resources.""" + if self._reconnect_task: + self._reconnect_task.cancel() + try: + await self._reconnect_task + except asyncio.CancelledError: + pass + if self._ws_client: + await self._ws_client.disconnect() + class RemoteMediaPlayerEntity(CoordinatorEntity[MediaPlayerCoordinator], MediaPlayerEntity): """Representation of a Remote Media Player.""" diff --git a/media_server/main.py b/media_server/main.py index 1f57e66..1cecaa6 100644 --- a/media_server/main.py +++ b/media_server/main.py @@ -11,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware from .config import settings, generate_default_config, get_config_dir from .routes import health_router, media_router, scripts_router +from .services import get_media_controller +from .services.websocket_manager import ws_manager def setup_logging(): @@ -29,7 +31,16 @@ async def lifespan(app: FastAPI): logger = logging.getLogger(__name__) logger.info(f"Media Server starting on {settings.host}:{settings.port}") logger.info(f"API Token: {settings.api_token[:8]}...") + + # Start WebSocket status monitor + controller = get_media_controller() + await ws_manager.start_status_monitor(controller.get_status) + logger.info("WebSocket status monitor started") + yield + + # Stop WebSocket status monitor + await ws_manager.stop_status_monitor() logger.info("Media Server shutting down") diff --git a/media_server/routes/media.py b/media_server/routes/media.py index 107ea3f..618d0aa 100644 --- a/media_server/routes/media.py +++ b/media_server/routes/media.py @@ -1,11 +1,18 @@ """Media control API endpoints.""" -from fastapi import APIRouter, Depends, HTTPException, status +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect +from fastapi import status from fastapi.responses import Response from ..auth import verify_token, verify_token_or_query +from ..config import settings from ..models import MediaStatus, VolumeRequest, SeekRequest from ..services import get_media_controller, get_current_album_art +from ..services.websocket_manager import ws_manager + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/media", tags=["media"]) @@ -184,3 +191,52 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response: content_type = "image/webp" return Response(content=art_bytes, media_type=content_type) + + +@router.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + token: str = Query(..., description="API authentication token"), +) -> None: + """WebSocket endpoint for real-time media status updates. + + Authentication is done via query parameter since WebSocket + doesn't support custom headers in the browser. + + Messages sent to client: + - {"type": "status", "data": {...}} - Initial status on connect + - {"type": "status_update", "data": {...}} - Status changes + - {"type": "error", "message": "..."} - Error messages + + Client can send: + - {"type": "ping"} - Keepalive, server responds with {"type": "pong"} + - {"type": "get_status"} - Request current status + """ + # Verify token + if token != settings.api_token: + await websocket.close(code=4001, reason="Invalid authentication token") + return + + await ws_manager.connect(websocket) + + try: + while True: + # Wait for messages from client (for keepalive/ping) + data = await websocket.receive_json() + + if data.get("type") == "ping": + await websocket.send_json({"type": "pong"}) + elif data.get("type") == "get_status": + # Allow manual status request + controller = get_media_controller() + status_data = await controller.get_status() + await websocket.send_json({ + "type": "status", + "data": status_data.model_dump(), + }) + + except WebSocketDisconnect: + await ws_manager.disconnect(websocket) + except Exception as e: + logger.error("WebSocket error: %s", e) + await ws_manager.disconnect(websocket) diff --git a/media_server/services/websocket_manager.py b/media_server/services/websocket_manager.py new file mode 100644 index 0000000..7289cb7 --- /dev/null +++ b/media_server/services/websocket_manager.py @@ -0,0 +1,189 @@ +"""WebSocket connection manager and status broadcaster.""" + +import asyncio +import logging +import time +from typing import Any, Callable, Coroutine + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + + +class ConnectionManager: + """Manages WebSocket connections and broadcasts status updates.""" + + def __init__(self) -> None: + """Initialize the connection manager.""" + self._active_connections: set[WebSocket] = set() + self._lock = asyncio.Lock() + self._last_status: dict[str, Any] | None = None + self._broadcast_task: asyncio.Task | None = None + self._poll_interval: float = 0.5 # Internal poll interval for change detection + self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback + self._last_broadcast_time: float = 0.0 + self._running: bool = False + + async def connect(self, websocket: WebSocket) -> None: + """Accept a new WebSocket connection.""" + await websocket.accept() + async with self._lock: + self._active_connections.add(websocket) + logger.info( + "WebSocket client connected. Total: %d", len(self._active_connections) + ) + + # Send current status immediately upon connection + if self._last_status: + try: + await websocket.send_json({"type": "status", "data": self._last_status}) + except Exception as e: + logger.debug("Failed to send initial status: %s", e) + + async def disconnect(self, websocket: WebSocket) -> None: + """Remove a WebSocket connection.""" + async with self._lock: + self._active_connections.discard(websocket) + logger.info( + "WebSocket client disconnected. Total: %d", len(self._active_connections) + ) + + async def broadcast(self, message: dict[str, Any]) -> None: + """Broadcast a message to all connected clients.""" + async with self._lock: + connections = list(self._active_connections) + + if not connections: + return + + disconnected = [] + for websocket in connections: + try: + await websocket.send_json(message) + except Exception as e: + logger.debug("Failed to send to client: %s", e) + disconnected.append(websocket) + + # Clean up disconnected clients + for ws in disconnected: + await self.disconnect(ws) + + def status_changed( + self, old: dict[str, Any] | None, new: dict[str, Any] + ) -> bool: + """Detect if media status has meaningfully changed. + + Position is NOT included for normal playback (let HA interpolate). + But seeks (large unexpected jumps) are detected. + """ + if old is None: + return True + + # Fields to compare for changes (NO position - let HA interpolate) + significant_fields = [ + "state", + "title", + "artist", + "album", + "volume", + "muted", + "duration", + "source", + "album_art_url", + ] + + for field in significant_fields: + if old.get(field) != new.get(field): + return True + + # Detect seeks - large position jumps that aren't normal playback + old_pos = old.get("position") or 0 + new_pos = new.get("position") or 0 + pos_diff = new_pos - old_pos + + # During playback, position should increase by ~0.5s (our poll interval) + # A seek is when position jumps backwards OR forward by more than expected + if new.get("state") == "playing": + # Backward seek or forward jump > 3s indicates seek + if pos_diff < -1.0 or pos_diff > 3.0: + return True + else: + # When paused, any significant position change is a seek + if abs(pos_diff) > 1.0: + return True + + return False + + async def start_status_monitor( + self, + get_status_func: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + """Start the background status monitoring loop.""" + if self._running: + return + + self._running = True + self._broadcast_task = asyncio.create_task( + self._status_monitor_loop(get_status_func) + ) + logger.info("WebSocket status monitor started") + + async def stop_status_monitor(self) -> None: + """Stop the background status monitoring loop.""" + self._running = False + if self._broadcast_task: + self._broadcast_task.cancel() + try: + await self._broadcast_task + except asyncio.CancelledError: + pass + logger.info("WebSocket status monitor stopped") + + async def _status_monitor_loop( + self, + get_status_func: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + """Background loop that polls for status changes and broadcasts.""" + while self._running: + try: + # Only poll if we have connected clients + async with self._lock: + has_clients = len(self._active_connections) > 0 + + if has_clients: + status = await get_status_func() + status_dict = status.model_dump() + + # Only broadcast on actual state changes + # Let HA handle position interpolation during playback + if self.status_changed(self._last_status, status_dict): + self._last_status = status_dict + self._last_broadcast_time = time.time() + await self.broadcast( + {"type": "status_update", "data": status_dict} + ) + logger.debug("Broadcast sent: status change") + else: + # Update cached status even without broadcast + self._last_status = status_dict + else: + # Still update cache for when clients connect + status = await get_status_func() + self._last_status = status.model_dump() + + await asyncio.sleep(self._poll_interval) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in status monitor: %s", e) + await asyncio.sleep(self._poll_interval) + + @property + def client_count(self) -> int: + """Return the number of connected clients.""" + return len(self._active_connections) + + +# Global instance +ws_manager = ConnectionManager()