diff --git a/custom_components/remote_media_player/__init__.py b/custom_components/remote_media_player/__init__.py index ab976cf..f71f3bb 100644 --- a/custom_components/remote_media_player/__init__.py +++ b/custom_components/remote_media_player/__init__.py @@ -138,6 +138,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: if unload_ok: # Close client and remove data data = hass.data[DOMAIN].pop(entry.entry_id) + + # Shutdown coordinator (WebSocket cleanup) + if "coordinator" in data: + await data["coordinator"].async_shutdown() + + # Close HTTP client await data["client"].close() # Remove services if this was the last entry diff --git a/custom_components/remote_media_player/api_client.py b/custom_components/remote_media_player/api_client.py index 812ed0e..7dbcfc0 100644 --- a/custom_components/remote_media_player/api_client.py +++ b/custom_components/remote_media_player/api_client.py @@ -2,7 +2,10 @@ from __future__ import annotations +import asyncio +import hashlib import logging +from collections.abc import Callable from typing import Any import aiohttp @@ -265,3 +268,140 @@ class MediaServerClient: endpoint = f"{API_SCRIPTS_EXECUTE}/{script_name}" json_data = {"args": args or []} return await self._request("POST", endpoint, json_data) + + +class MediaServerWebSocket: + """WebSocket client for real-time media status updates.""" + + def __init__( + self, + host: str, + port: int, + token: str, + on_status_update: Callable[[dict[str, Any]], None], + on_disconnect: Callable[[], None] | None = None, + ) -> None: + """Initialize the WebSocket client. + + Args: + host: Server hostname or IP + port: Server port + token: API authentication token + on_status_update: Callback when status update received + on_disconnect: Callback when connection lost + """ + self._host = host + self._port = int(port) + self._token = token + self._on_status_update = on_status_update + self._on_disconnect = on_disconnect + self._ws_url = f"ws://{host}:{self._port}/api/media/ws?token={token}" + self._session: aiohttp.ClientSession | None = None + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._receive_task: asyncio.Task | None = None + self._running = False + + async def connect(self) -> bool: + """Establish WebSocket connection. + + Returns: + True if connection successful + """ + try: + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + + self._ws = await self._session.ws_connect( + self._ws_url, + heartbeat=30, + timeout=aiohttp.ClientTimeout(total=10), + ) + self._running = True + + # Start receive loop + self._receive_task = asyncio.create_task(self._receive_loop()) + + _LOGGER.info("WebSocket connected to %s:%s", self._host, self._port) + return True + + except Exception as err: + _LOGGER.warning("WebSocket connection failed: %s", err) + return False + + async def disconnect(self) -> None: + """Close WebSocket connection.""" + self._running = False + + if self._receive_task: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + self._receive_task = None + + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + _LOGGER.debug("WebSocket disconnected") + + async def _receive_loop(self) -> None: + """Background loop to receive WebSocket messages.""" + while self._running and self._ws and not self._ws.closed: + try: + msg = await self._ws.receive(timeout=60) + + if msg.type == aiohttp.WSMsgType.TEXT: + data = msg.json() + msg_type = data.get("type") + + if msg_type in ("status", "status_update"): + status_data = data.get("data", {}) + # Convert album art URL to absolute + if ( + status_data.get("album_art_url") + and status_data["album_art_url"].startswith("/") + ): + track_id = f"{status_data.get('title', '')}-{status_data.get('artist', '')}" + track_hash = hashlib.md5(track_id.encode()).hexdigest()[:8] + status_data["album_art_url"] = ( + f"http://{self._host}:{self._port}" + f"{status_data['album_art_url']}?token={self._token}&t={track_hash}" + ) + self._on_status_update(status_data) + elif msg_type == "pong": + _LOGGER.debug("Received pong") + + elif msg.type == aiohttp.WSMsgType.CLOSED: + _LOGGER.warning("WebSocket closed by server") + break + elif msg.type == aiohttp.WSMsgType.ERROR: + _LOGGER.error("WebSocket error: %s", self._ws.exception()) + break + + except asyncio.TimeoutError: + # Send ping to keep connection alive + if self._ws and not self._ws.closed: + try: + await self._ws.send_json({"type": "ping"}) + except Exception: + break + except asyncio.CancelledError: + break + except Exception as err: + _LOGGER.error("WebSocket receive error: %s", err) + break + + # Connection lost, notify callback + if self._on_disconnect: + self._on_disconnect() + + @property + def is_connected(self) -> bool: + """Return True if WebSocket is connected.""" + return self._ws is not None and not self._ws.closed diff --git a/custom_components/remote_media_player/const.py b/custom_components/remote_media_player/const.py index c437701..c9dafd5 100644 --- a/custom_components/remote_media_player/const.py +++ b/custom_components/remote_media_player/const.py @@ -8,11 +8,14 @@ CONF_PORT = "port" CONF_TOKEN = "token" CONF_POLL_INTERVAL = "poll_interval" CONF_NAME = "name" +CONF_USE_WEBSOCKET = "use_websocket" # Default values DEFAULT_PORT = 8765 DEFAULT_POLL_INTERVAL = 5 DEFAULT_NAME = "Remote Media Player" +DEFAULT_USE_WEBSOCKET = True +DEFAULT_RECONNECT_INTERVAL = 30 # API endpoints API_HEALTH = "/api/health" @@ -27,6 +30,7 @@ API_MUTE = "/api/media/mute" API_SEEK = "/api/media/seek" API_SCRIPTS_LIST = "/api/scripts/list" API_SCRIPTS_EXECUTE = "/api/scripts/execute" +API_WEBSOCKET = "/api/media/ws" # Service names SERVICE_EXECUTE_SCRIPT = "execute_script" diff --git a/custom_components/remote_media_player/manifest.json b/custom_components/remote_media_player/manifest.json index edaab2b..be908fe 100644 --- a/custom_components/remote_media_player/manifest.json +++ b/custom_components/remote_media_player/manifest.json @@ -6,7 +6,7 @@ "dependencies": [], "documentation": "https://github.com/your-username/haos-integration-media-player", "integration_type": "device", - "iot_class": "local_polling", + "iot_class": "local_push", "requirements": ["aiohttp>=3.8.0"], "version": "1.0.0" } diff --git a/custom_components/remote_media_player/media_player.py b/custom_components/remote_media_player/media_player.py index 2a5385c..319afb7 100644 --- a/custom_components/remote_media_player/media_player.py +++ b/custom_components/remote_media_player/media_player.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import logging from datetime import datetime, timedelta from typing import Any @@ -14,7 +15,7 @@ from homeassistant.components.media_player import ( ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_NAME -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import ( @@ -23,14 +24,18 @@ from homeassistant.helpers.update_coordinator import ( UpdateFailed, ) -from .api_client import MediaServerClient, MediaServerError +from .api_client import MediaServerClient, MediaServerError, MediaServerWebSocket from .const import ( DOMAIN, CONF_HOST, CONF_PORT, + CONF_TOKEN, CONF_POLL_INTERVAL, + CONF_USE_WEBSOCKET, DEFAULT_POLL_INTERVAL, DEFAULT_NAME, + DEFAULT_USE_WEBSOCKET, + DEFAULT_RECONNECT_INTERVAL, ) _LOGGER = logging.getLogger(__name__) @@ -62,13 +67,26 @@ async def async_setup_entry( entry.data.get(CONF_POLL_INTERVAL, DEFAULT_POLL_INTERVAL), ) - # Create update coordinator + # Get WebSocket setting from options or data + use_websocket = entry.options.get( + CONF_USE_WEBSOCKET, + entry.data.get(CONF_USE_WEBSOCKET, DEFAULT_USE_WEBSOCKET), + ) + + # Create update coordinator with WebSocket support coordinator = MediaPlayerCoordinator( hass, client, poll_interval, + host=entry.data[CONF_HOST], + port=entry.data[CONF_PORT], + token=entry.data[CONF_TOKEN], + use_websocket=use_websocket, ) + # Set up WebSocket connection if enabled + await coordinator.async_setup() + # Fetch initial data - don't fail setup if this fails try: await coordinator.async_config_entry_first_refresh() @@ -76,6 +94,9 @@ async def async_setup_entry( _LOGGER.warning("Initial data fetch failed, will retry: %s", err) # Continue anyway - the coordinator will retry + # Store coordinator for cleanup + hass.data[DOMAIN][entry.entry_id]["coordinator"] = coordinator + # Create and add entity entity = RemoteMediaPlayerEntity( coordinator, @@ -86,13 +107,17 @@ async def async_setup_entry( class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): - """Coordinator for fetching media player data.""" + """Coordinator for fetching media player data with WebSocket support.""" def __init__( self, hass: HomeAssistant, client: MediaServerClient, poll_interval: int, + host: str, + port: int, + token: str, + use_websocket: bool = True, ) -> None: """Initialize the coordinator. @@ -100,6 +125,10 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): hass: Home Assistant instance client: Media Server API client poll_interval: Update interval in seconds + host: Server hostname + port: Server port + token: API token + use_websocket: Whether to use WebSocket for updates """ super().__init__( hass, @@ -108,9 +137,76 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): update_interval=timedelta(seconds=poll_interval), ) self.client = client + self._host = host + self._port = port + self._token = token + self._use_websocket = use_websocket + self._ws_client: MediaServerWebSocket | None = None + self._ws_connected = False + self._reconnect_task: asyncio.Task | None = None + self._poll_interval = poll_interval + + async def async_setup(self) -> None: + """Set up the coordinator with WebSocket if enabled.""" + if self._use_websocket: + await self._connect_websocket() + + async def _connect_websocket(self) -> None: + """Establish WebSocket connection.""" + if self._ws_client: + await self._ws_client.disconnect() + + self._ws_client = MediaServerWebSocket( + host=self._host, + port=self._port, + token=self._token, + on_status_update=self._handle_ws_status_update, + on_disconnect=self._handle_ws_disconnect, + ) + + if await self._ws_client.connect(): + self._ws_connected = True + # Disable polling - WebSocket handles all updates including position + self.update_interval = None + _LOGGER.info("WebSocket connected, polling disabled") + else: + self._ws_connected = False + # Keep polling as fallback + self.update_interval = timedelta(seconds=self._poll_interval) + _LOGGER.warning("WebSocket failed, falling back to polling") + # Schedule reconnect attempt + self._schedule_reconnect() + + @callback + def _handle_ws_status_update(self, status_data: dict[str, Any]) -> None: + """Handle status update from WebSocket.""" + self.async_set_updated_data(status_data) + + @callback + def _handle_ws_disconnect(self) -> None: + """Handle WebSocket disconnection.""" + self._ws_connected = False + # Re-enable polling as fallback + self.update_interval = timedelta(seconds=self._poll_interval) + _LOGGER.warning("WebSocket disconnected, falling back to polling") + # Schedule reconnect attempt + self._schedule_reconnect() + + def _schedule_reconnect(self) -> None: + """Schedule a WebSocket reconnection attempt.""" + if self._reconnect_task and not self._reconnect_task.done(): + return # Already scheduled + + async def reconnect() -> None: + await asyncio.sleep(DEFAULT_RECONNECT_INTERVAL) + if self._use_websocket and not self._ws_connected: + _LOGGER.info("Attempting WebSocket reconnect...") + await self._connect_websocket() + + self._reconnect_task = self.hass.async_create_task(reconnect()) async def _async_update_data(self) -> dict[str, Any]: - """Fetch data from the API. + """Fetch data from the API (fallback when WebSocket unavailable). Returns: Media status data @@ -120,7 +216,7 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): """ try: data = await self.client.get_status() - _LOGGER.debug("Received media status: %s", data) + _LOGGER.debug("HTTP poll received status: %s", data.get("state")) return data except MediaServerError as err: raise UpdateFailed(f"Error communicating with server: {err}") from err @@ -128,6 +224,17 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]): _LOGGER.exception("Unexpected error fetching media status") raise UpdateFailed(f"Unexpected error: {err}") from err + async def async_shutdown(self) -> None: + """Clean up resources.""" + if self._reconnect_task: + self._reconnect_task.cancel() + try: + await self._reconnect_task + except asyncio.CancelledError: + pass + if self._ws_client: + await self._ws_client.disconnect() + class RemoteMediaPlayerEntity(CoordinatorEntity[MediaPlayerCoordinator], MediaPlayerEntity): """Representation of a Remote Media Player.""" diff --git a/media_server/main.py b/media_server/main.py index 1f57e66..1cecaa6 100644 --- a/media_server/main.py +++ b/media_server/main.py @@ -11,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware from .config import settings, generate_default_config, get_config_dir from .routes import health_router, media_router, scripts_router +from .services import get_media_controller +from .services.websocket_manager import ws_manager def setup_logging(): @@ -29,7 +31,16 @@ async def lifespan(app: FastAPI): logger = logging.getLogger(__name__) logger.info(f"Media Server starting on {settings.host}:{settings.port}") logger.info(f"API Token: {settings.api_token[:8]}...") + + # Start WebSocket status monitor + controller = get_media_controller() + await ws_manager.start_status_monitor(controller.get_status) + logger.info("WebSocket status monitor started") + yield + + # Stop WebSocket status monitor + await ws_manager.stop_status_monitor() logger.info("Media Server shutting down") diff --git a/media_server/routes/media.py b/media_server/routes/media.py index 107ea3f..618d0aa 100644 --- a/media_server/routes/media.py +++ b/media_server/routes/media.py @@ -1,11 +1,18 @@ """Media control API endpoints.""" -from fastapi import APIRouter, Depends, HTTPException, status +import logging + +from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect +from fastapi import status from fastapi.responses import Response from ..auth import verify_token, verify_token_or_query +from ..config import settings from ..models import MediaStatus, VolumeRequest, SeekRequest from ..services import get_media_controller, get_current_album_art +from ..services.websocket_manager import ws_manager + +logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/media", tags=["media"]) @@ -184,3 +191,52 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response: content_type = "image/webp" return Response(content=art_bytes, media_type=content_type) + + +@router.websocket("/ws") +async def websocket_endpoint( + websocket: WebSocket, + token: str = Query(..., description="API authentication token"), +) -> None: + """WebSocket endpoint for real-time media status updates. + + Authentication is done via query parameter since WebSocket + doesn't support custom headers in the browser. + + Messages sent to client: + - {"type": "status", "data": {...}} - Initial status on connect + - {"type": "status_update", "data": {...}} - Status changes + - {"type": "error", "message": "..."} - Error messages + + Client can send: + - {"type": "ping"} - Keepalive, server responds with {"type": "pong"} + - {"type": "get_status"} - Request current status + """ + # Verify token + if token != settings.api_token: + await websocket.close(code=4001, reason="Invalid authentication token") + return + + await ws_manager.connect(websocket) + + try: + while True: + # Wait for messages from client (for keepalive/ping) + data = await websocket.receive_json() + + if data.get("type") == "ping": + await websocket.send_json({"type": "pong"}) + elif data.get("type") == "get_status": + # Allow manual status request + controller = get_media_controller() + status_data = await controller.get_status() + await websocket.send_json({ + "type": "status", + "data": status_data.model_dump(), + }) + + except WebSocketDisconnect: + await ws_manager.disconnect(websocket) + except Exception as e: + logger.error("WebSocket error: %s", e) + await ws_manager.disconnect(websocket) diff --git a/media_server/services/websocket_manager.py b/media_server/services/websocket_manager.py new file mode 100644 index 0000000..7289cb7 --- /dev/null +++ b/media_server/services/websocket_manager.py @@ -0,0 +1,189 @@ +"""WebSocket connection manager and status broadcaster.""" + +import asyncio +import logging +import time +from typing import Any, Callable, Coroutine + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + + +class ConnectionManager: + """Manages WebSocket connections and broadcasts status updates.""" + + def __init__(self) -> None: + """Initialize the connection manager.""" + self._active_connections: set[WebSocket] = set() + self._lock = asyncio.Lock() + self._last_status: dict[str, Any] | None = None + self._broadcast_task: asyncio.Task | None = None + self._poll_interval: float = 0.5 # Internal poll interval for change detection + self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback + self._last_broadcast_time: float = 0.0 + self._running: bool = False + + async def connect(self, websocket: WebSocket) -> None: + """Accept a new WebSocket connection.""" + await websocket.accept() + async with self._lock: + self._active_connections.add(websocket) + logger.info( + "WebSocket client connected. Total: %d", len(self._active_connections) + ) + + # Send current status immediately upon connection + if self._last_status: + try: + await websocket.send_json({"type": "status", "data": self._last_status}) + except Exception as e: + logger.debug("Failed to send initial status: %s", e) + + async def disconnect(self, websocket: WebSocket) -> None: + """Remove a WebSocket connection.""" + async with self._lock: + self._active_connections.discard(websocket) + logger.info( + "WebSocket client disconnected. Total: %d", len(self._active_connections) + ) + + async def broadcast(self, message: dict[str, Any]) -> None: + """Broadcast a message to all connected clients.""" + async with self._lock: + connections = list(self._active_connections) + + if not connections: + return + + disconnected = [] + for websocket in connections: + try: + await websocket.send_json(message) + except Exception as e: + logger.debug("Failed to send to client: %s", e) + disconnected.append(websocket) + + # Clean up disconnected clients + for ws in disconnected: + await self.disconnect(ws) + + def status_changed( + self, old: dict[str, Any] | None, new: dict[str, Any] + ) -> bool: + """Detect if media status has meaningfully changed. + + Position is NOT included for normal playback (let HA interpolate). + But seeks (large unexpected jumps) are detected. + """ + if old is None: + return True + + # Fields to compare for changes (NO position - let HA interpolate) + significant_fields = [ + "state", + "title", + "artist", + "album", + "volume", + "muted", + "duration", + "source", + "album_art_url", + ] + + for field in significant_fields: + if old.get(field) != new.get(field): + return True + + # Detect seeks - large position jumps that aren't normal playback + old_pos = old.get("position") or 0 + new_pos = new.get("position") or 0 + pos_diff = new_pos - old_pos + + # During playback, position should increase by ~0.5s (our poll interval) + # A seek is when position jumps backwards OR forward by more than expected + if new.get("state") == "playing": + # Backward seek or forward jump > 3s indicates seek + if pos_diff < -1.0 or pos_diff > 3.0: + return True + else: + # When paused, any significant position change is a seek + if abs(pos_diff) > 1.0: + return True + + return False + + async def start_status_monitor( + self, + get_status_func: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + """Start the background status monitoring loop.""" + if self._running: + return + + self._running = True + self._broadcast_task = asyncio.create_task( + self._status_monitor_loop(get_status_func) + ) + logger.info("WebSocket status monitor started") + + async def stop_status_monitor(self) -> None: + """Stop the background status monitoring loop.""" + self._running = False + if self._broadcast_task: + self._broadcast_task.cancel() + try: + await self._broadcast_task + except asyncio.CancelledError: + pass + logger.info("WebSocket status monitor stopped") + + async def _status_monitor_loop( + self, + get_status_func: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + """Background loop that polls for status changes and broadcasts.""" + while self._running: + try: + # Only poll if we have connected clients + async with self._lock: + has_clients = len(self._active_connections) > 0 + + if has_clients: + status = await get_status_func() + status_dict = status.model_dump() + + # Only broadcast on actual state changes + # Let HA handle position interpolation during playback + if self.status_changed(self._last_status, status_dict): + self._last_status = status_dict + self._last_broadcast_time = time.time() + await self.broadcast( + {"type": "status_update", "data": status_dict} + ) + logger.debug("Broadcast sent: status change") + else: + # Update cached status even without broadcast + self._last_status = status_dict + else: + # Still update cache for when clients connect + status = await get_status_func() + self._last_status = status.model_dump() + + await asyncio.sleep(self._poll_interval) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Error in status monitor: %s", e) + await asyncio.sleep(self._poll_interval) + + @property + def client_count(self) -> int: + """Return the number of connected clients.""" + return len(self._active_connections) + + +# Global instance +ws_manager = ConnectionManager()