Add WebSocket support for real-time media status updates

Replace HTTP polling with WebSocket push notifications for instant
state change responses. Server broadcasts updates only when significant
changes occur (state, track, volume, etc.) while letting Home Assistant
interpolate position during playback. Includes seek detection for
timeline updates and automatic fallback to HTTP polling if WebSocket
disconnects.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-04 14:02:53 +03:00
parent 67a89e8349
commit 5519e449cd
8 changed files with 521 additions and 8 deletions

View File

@@ -138,6 +138,12 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if unload_ok:
# Close client and remove data
data = hass.data[DOMAIN].pop(entry.entry_id)
# Shutdown coordinator (WebSocket cleanup)
if "coordinator" in data:
await data["coordinator"].async_shutdown()
# Close HTTP client
await data["client"].close()
# Remove services if this was the last entry

View File

@@ -2,7 +2,10 @@
from __future__ import annotations
import asyncio
import hashlib
import logging
from collections.abc import Callable
from typing import Any
import aiohttp
@@ -265,3 +268,140 @@ class MediaServerClient:
endpoint = f"{API_SCRIPTS_EXECUTE}/{script_name}"
json_data = {"args": args or []}
return await self._request("POST", endpoint, json_data)
class MediaServerWebSocket:
"""WebSocket client for real-time media status updates."""
def __init__(
self,
host: str,
port: int,
token: str,
on_status_update: Callable[[dict[str, Any]], None],
on_disconnect: Callable[[], None] | None = None,
) -> None:
"""Initialize the WebSocket client.
Args:
host: Server hostname or IP
port: Server port
token: API authentication token
on_status_update: Callback when status update received
on_disconnect: Callback when connection lost
"""
self._host = host
self._port = int(port)
self._token = token
self._on_status_update = on_status_update
self._on_disconnect = on_disconnect
self._ws_url = f"ws://{host}:{self._port}/api/media/ws?token={token}"
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._receive_task: asyncio.Task | None = None
self._running = False
async def connect(self) -> bool:
"""Establish WebSocket connection.
Returns:
True if connection successful
"""
try:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(
self._ws_url,
heartbeat=30,
timeout=aiohttp.ClientTimeout(total=10),
)
self._running = True
# Start receive loop
self._receive_task = asyncio.create_task(self._receive_loop())
_LOGGER.info("WebSocket connected to %s:%s", self._host, self._port)
return True
except Exception as err:
_LOGGER.warning("WebSocket connection failed: %s", err)
return False
async def disconnect(self) -> None:
"""Close WebSocket connection."""
self._running = False
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
self._receive_task = None
if self._ws and not self._ws.closed:
await self._ws.close()
self._ws = None
if self._session and not self._session.closed:
await self._session.close()
self._session = None
_LOGGER.debug("WebSocket disconnected")
async def _receive_loop(self) -> None:
"""Background loop to receive WebSocket messages."""
while self._running and self._ws and not self._ws.closed:
try:
msg = await self._ws.receive(timeout=60)
if msg.type == aiohttp.WSMsgType.TEXT:
data = msg.json()
msg_type = data.get("type")
if msg_type in ("status", "status_update"):
status_data = data.get("data", {})
# Convert album art URL to absolute
if (
status_data.get("album_art_url")
and status_data["album_art_url"].startswith("/")
):
track_id = f"{status_data.get('title', '')}-{status_data.get('artist', '')}"
track_hash = hashlib.md5(track_id.encode()).hexdigest()[:8]
status_data["album_art_url"] = (
f"http://{self._host}:{self._port}"
f"{status_data['album_art_url']}?token={self._token}&t={track_hash}"
)
self._on_status_update(status_data)
elif msg_type == "pong":
_LOGGER.debug("Received pong")
elif msg.type == aiohttp.WSMsgType.CLOSED:
_LOGGER.warning("WebSocket closed by server")
break
elif msg.type == aiohttp.WSMsgType.ERROR:
_LOGGER.error("WebSocket error: %s", self._ws.exception())
break
except asyncio.TimeoutError:
# Send ping to keep connection alive
if self._ws and not self._ws.closed:
try:
await self._ws.send_json({"type": "ping"})
except Exception:
break
except asyncio.CancelledError:
break
except Exception as err:
_LOGGER.error("WebSocket receive error: %s", err)
break
# Connection lost, notify callback
if self._on_disconnect:
self._on_disconnect()
@property
def is_connected(self) -> bool:
"""Return True if WebSocket is connected."""
return self._ws is not None and not self._ws.closed

View File

@@ -8,11 +8,14 @@ CONF_PORT = "port"
CONF_TOKEN = "token"
CONF_POLL_INTERVAL = "poll_interval"
CONF_NAME = "name"
CONF_USE_WEBSOCKET = "use_websocket"
# Default values
DEFAULT_PORT = 8765
DEFAULT_POLL_INTERVAL = 5
DEFAULT_NAME = "Remote Media Player"
DEFAULT_USE_WEBSOCKET = True
DEFAULT_RECONNECT_INTERVAL = 30
# API endpoints
API_HEALTH = "/api/health"
@@ -27,6 +30,7 @@ API_MUTE = "/api/media/mute"
API_SEEK = "/api/media/seek"
API_SCRIPTS_LIST = "/api/scripts/list"
API_SCRIPTS_EXECUTE = "/api/scripts/execute"
API_WEBSOCKET = "/api/media/ws"
# Service names
SERVICE_EXECUTE_SCRIPT = "execute_script"

View File

@@ -6,7 +6,7 @@
"dependencies": [],
"documentation": "https://github.com/your-username/haos-integration-media-player",
"integration_type": "device",
"iot_class": "local_polling",
"iot_class": "local_push",
"requirements": ["aiohttp>=3.8.0"],
"version": "1.0.0"
}

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any
@@ -14,7 +15,7 @@ from homeassistant.components.media_player import (
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_NAME
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity import DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import (
@@ -23,14 +24,18 @@ from homeassistant.helpers.update_coordinator import (
UpdateFailed,
)
from .api_client import MediaServerClient, MediaServerError
from .api_client import MediaServerClient, MediaServerError, MediaServerWebSocket
from .const import (
DOMAIN,
CONF_HOST,
CONF_PORT,
CONF_TOKEN,
CONF_POLL_INTERVAL,
CONF_USE_WEBSOCKET,
DEFAULT_POLL_INTERVAL,
DEFAULT_NAME,
DEFAULT_USE_WEBSOCKET,
DEFAULT_RECONNECT_INTERVAL,
)
_LOGGER = logging.getLogger(__name__)
@@ -62,13 +67,26 @@ async def async_setup_entry(
entry.data.get(CONF_POLL_INTERVAL, DEFAULT_POLL_INTERVAL),
)
# Create update coordinator
# Get WebSocket setting from options or data
use_websocket = entry.options.get(
CONF_USE_WEBSOCKET,
entry.data.get(CONF_USE_WEBSOCKET, DEFAULT_USE_WEBSOCKET),
)
# Create update coordinator with WebSocket support
coordinator = MediaPlayerCoordinator(
hass,
client,
poll_interval,
host=entry.data[CONF_HOST],
port=entry.data[CONF_PORT],
token=entry.data[CONF_TOKEN],
use_websocket=use_websocket,
)
# Set up WebSocket connection if enabled
await coordinator.async_setup()
# Fetch initial data - don't fail setup if this fails
try:
await coordinator.async_config_entry_first_refresh()
@@ -76,6 +94,9 @@ async def async_setup_entry(
_LOGGER.warning("Initial data fetch failed, will retry: %s", err)
# Continue anyway - the coordinator will retry
# Store coordinator for cleanup
hass.data[DOMAIN][entry.entry_id]["coordinator"] = coordinator
# Create and add entity
entity = RemoteMediaPlayerEntity(
coordinator,
@@ -86,13 +107,17 @@ async def async_setup_entry(
class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]):
"""Coordinator for fetching media player data."""
"""Coordinator for fetching media player data with WebSocket support."""
def __init__(
self,
hass: HomeAssistant,
client: MediaServerClient,
poll_interval: int,
host: str,
port: int,
token: str,
use_websocket: bool = True,
) -> None:
"""Initialize the coordinator.
@@ -100,6 +125,10 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]):
hass: Home Assistant instance
client: Media Server API client
poll_interval: Update interval in seconds
host: Server hostname
port: Server port
token: API token
use_websocket: Whether to use WebSocket for updates
"""
super().__init__(
hass,
@@ -108,9 +137,76 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]):
update_interval=timedelta(seconds=poll_interval),
)
self.client = client
self._host = host
self._port = port
self._token = token
self._use_websocket = use_websocket
self._ws_client: MediaServerWebSocket | None = None
self._ws_connected = False
self._reconnect_task: asyncio.Task | None = None
self._poll_interval = poll_interval
async def async_setup(self) -> None:
"""Set up the coordinator with WebSocket if enabled."""
if self._use_websocket:
await self._connect_websocket()
async def _connect_websocket(self) -> None:
"""Establish WebSocket connection."""
if self._ws_client:
await self._ws_client.disconnect()
self._ws_client = MediaServerWebSocket(
host=self._host,
port=self._port,
token=self._token,
on_status_update=self._handle_ws_status_update,
on_disconnect=self._handle_ws_disconnect,
)
if await self._ws_client.connect():
self._ws_connected = True
# Disable polling - WebSocket handles all updates including position
self.update_interval = None
_LOGGER.info("WebSocket connected, polling disabled")
else:
self._ws_connected = False
# Keep polling as fallback
self.update_interval = timedelta(seconds=self._poll_interval)
_LOGGER.warning("WebSocket failed, falling back to polling")
# Schedule reconnect attempt
self._schedule_reconnect()
@callback
def _handle_ws_status_update(self, status_data: dict[str, Any]) -> None:
"""Handle status update from WebSocket."""
self.async_set_updated_data(status_data)
@callback
def _handle_ws_disconnect(self) -> None:
"""Handle WebSocket disconnection."""
self._ws_connected = False
# Re-enable polling as fallback
self.update_interval = timedelta(seconds=self._poll_interval)
_LOGGER.warning("WebSocket disconnected, falling back to polling")
# Schedule reconnect attempt
self._schedule_reconnect()
def _schedule_reconnect(self) -> None:
"""Schedule a WebSocket reconnection attempt."""
if self._reconnect_task and not self._reconnect_task.done():
return # Already scheduled
async def reconnect() -> None:
await asyncio.sleep(DEFAULT_RECONNECT_INTERVAL)
if self._use_websocket and not self._ws_connected:
_LOGGER.info("Attempting WebSocket reconnect...")
await self._connect_websocket()
self._reconnect_task = self.hass.async_create_task(reconnect())
async def _async_update_data(self) -> dict[str, Any]:
"""Fetch data from the API.
"""Fetch data from the API (fallback when WebSocket unavailable).
Returns:
Media status data
@@ -120,7 +216,7 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]):
"""
try:
data = await self.client.get_status()
_LOGGER.debug("Received media status: %s", data)
_LOGGER.debug("HTTP poll received status: %s", data.get("state"))
return data
except MediaServerError as err:
raise UpdateFailed(f"Error communicating with server: {err}") from err
@@ -128,6 +224,17 @@ class MediaPlayerCoordinator(DataUpdateCoordinator[dict[str, Any]]):
_LOGGER.exception("Unexpected error fetching media status")
raise UpdateFailed(f"Unexpected error: {err}") from err
async def async_shutdown(self) -> None:
"""Clean up resources."""
if self._reconnect_task:
self._reconnect_task.cancel()
try:
await self._reconnect_task
except asyncio.CancelledError:
pass
if self._ws_client:
await self._ws_client.disconnect()
class RemoteMediaPlayerEntity(CoordinatorEntity[MediaPlayerCoordinator], MediaPlayerEntity):
"""Representation of a Remote Media Player."""

View File

@@ -11,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
from .config import settings, generate_default_config, get_config_dir
from .routes import health_router, media_router, scripts_router
from .services import get_media_controller
from .services.websocket_manager import ws_manager
def setup_logging():
@@ -29,7 +31,16 @@ async def lifespan(app: FastAPI):
logger = logging.getLogger(__name__)
logger.info(f"Media Server starting on {settings.host}:{settings.port}")
logger.info(f"API Token: {settings.api_token[:8]}...")
# Start WebSocket status monitor
controller = get_media_controller()
await ws_manager.start_status_monitor(controller.get_status)
logger.info("WebSocket status monitor started")
yield
# Stop WebSocket status monitor
await ws_manager.stop_status_monitor()
logger.info("Media Server shutting down")

View File

@@ -1,11 +1,18 @@
"""Media control API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, status
import logging
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from fastapi import status
from fastapi.responses import Response
from ..auth import verify_token, verify_token_or_query
from ..config import settings
from ..models import MediaStatus, VolumeRequest, SeekRequest
from ..services import get_media_controller, get_current_album_art
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/media", tags=["media"])
@@ -184,3 +191,52 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response:
content_type = "image/webp"
return Response(content=art_bytes, media_type=content_type)
@router.websocket("/ws")
async def websocket_endpoint(
websocket: WebSocket,
token: str = Query(..., description="API authentication token"),
) -> None:
"""WebSocket endpoint for real-time media status updates.
Authentication is done via query parameter since WebSocket
doesn't support custom headers in the browser.
Messages sent to client:
- {"type": "status", "data": {...}} - Initial status on connect
- {"type": "status_update", "data": {...}} - Status changes
- {"type": "error", "message": "..."} - Error messages
Client can send:
- {"type": "ping"} - Keepalive, server responds with {"type": "pong"}
- {"type": "get_status"} - Request current status
"""
# Verify token
if token != settings.api_token:
await websocket.close(code=4001, reason="Invalid authentication token")
return
await ws_manager.connect(websocket)
try:
while True:
# Wait for messages from client (for keepalive/ping)
data = await websocket.receive_json()
if data.get("type") == "ping":
await websocket.send_json({"type": "pong"})
elif data.get("type") == "get_status":
# Allow manual status request
controller = get_media_controller()
status_data = await controller.get_status()
await websocket.send_json({
"type": "status",
"data": status_data.model_dump(),
})
except WebSocketDisconnect:
await ws_manager.disconnect(websocket)
except Exception as e:
logger.error("WebSocket error: %s", e)
await ws_manager.disconnect(websocket)

View File

@@ -0,0 +1,189 @@
"""WebSocket connection manager and status broadcaster."""
import asyncio
import logging
import time
from typing import Any, Callable, Coroutine
from fastapi import WebSocket
logger = logging.getLogger(__name__)
class ConnectionManager:
"""Manages WebSocket connections and broadcasts status updates."""
def __init__(self) -> None:
"""Initialize the connection manager."""
self._active_connections: set[WebSocket] = set()
self._lock = asyncio.Lock()
self._last_status: dict[str, Any] | None = None
self._broadcast_task: asyncio.Task | None = None
self._poll_interval: float = 0.5 # Internal poll interval for change detection
self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback
self._last_broadcast_time: float = 0.0
self._running: bool = False
async def connect(self, websocket: WebSocket) -> None:
"""Accept a new WebSocket connection."""
await websocket.accept()
async with self._lock:
self._active_connections.add(websocket)
logger.info(
"WebSocket client connected. Total: %d", len(self._active_connections)
)
# Send current status immediately upon connection
if self._last_status:
try:
await websocket.send_json({"type": "status", "data": self._last_status})
except Exception as e:
logger.debug("Failed to send initial status: %s", e)
async def disconnect(self, websocket: WebSocket) -> None:
"""Remove a WebSocket connection."""
async with self._lock:
self._active_connections.discard(websocket)
logger.info(
"WebSocket client disconnected. Total: %d", len(self._active_connections)
)
async def broadcast(self, message: dict[str, Any]) -> None:
"""Broadcast a message to all connected clients."""
async with self._lock:
connections = list(self._active_connections)
if not connections:
return
disconnected = []
for websocket in connections:
try:
await websocket.send_json(message)
except Exception as e:
logger.debug("Failed to send to client: %s", e)
disconnected.append(websocket)
# Clean up disconnected clients
for ws in disconnected:
await self.disconnect(ws)
def status_changed(
self, old: dict[str, Any] | None, new: dict[str, Any]
) -> bool:
"""Detect if media status has meaningfully changed.
Position is NOT included for normal playback (let HA interpolate).
But seeks (large unexpected jumps) are detected.
"""
if old is None:
return True
# Fields to compare for changes (NO position - let HA interpolate)
significant_fields = [
"state",
"title",
"artist",
"album",
"volume",
"muted",
"duration",
"source",
"album_art_url",
]
for field in significant_fields:
if old.get(field) != new.get(field):
return True
# Detect seeks - large position jumps that aren't normal playback
old_pos = old.get("position") or 0
new_pos = new.get("position") or 0
pos_diff = new_pos - old_pos
# During playback, position should increase by ~0.5s (our poll interval)
# A seek is when position jumps backwards OR forward by more than expected
if new.get("state") == "playing":
# Backward seek or forward jump > 3s indicates seek
if pos_diff < -1.0 or pos_diff > 3.0:
return True
else:
# When paused, any significant position change is a seek
if abs(pos_diff) > 1.0:
return True
return False
async def start_status_monitor(
self,
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
"""Start the background status monitoring loop."""
if self._running:
return
self._running = True
self._broadcast_task = asyncio.create_task(
self._status_monitor_loop(get_status_func)
)
logger.info("WebSocket status monitor started")
async def stop_status_monitor(self) -> None:
"""Stop the background status monitoring loop."""
self._running = False
if self._broadcast_task:
self._broadcast_task.cancel()
try:
await self._broadcast_task
except asyncio.CancelledError:
pass
logger.info("WebSocket status monitor stopped")
async def _status_monitor_loop(
self,
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
) -> None:
"""Background loop that polls for status changes and broadcasts."""
while self._running:
try:
# Only poll if we have connected clients
async with self._lock:
has_clients = len(self._active_connections) > 0
if has_clients:
status = await get_status_func()
status_dict = status.model_dump()
# Only broadcast on actual state changes
# Let HA handle position interpolation during playback
if self.status_changed(self._last_status, status_dict):
self._last_status = status_dict
self._last_broadcast_time = time.time()
await self.broadcast(
{"type": "status_update", "data": status_dict}
)
logger.debug("Broadcast sent: status change")
else:
# Update cached status even without broadcast
self._last_status = status_dict
else:
# Still update cache for when clients connect
status = await get_status_func()
self._last_status = status.model_dump()
await asyncio.sleep(self._poll_interval)
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in status monitor: %s", e)
await asyncio.sleep(self._poll_interval)
@property
def client_count(self) -> int:
"""Return the number of connected clients."""
return len(self._active_connections)
# Global instance
ws_manager = ConnectionManager()