Add WebSocket support for real-time media status updates
Replace HTTP polling with WebSocket push notifications for instant state change responses. Server broadcasts updates only when significant changes occur (state, track, volume, etc.) while letting Home Assistant interpolate position during playback. Includes seek detection for timeline updates and automatic fallback to HTTP polling if WebSocket disconnects. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .config import settings, generate_default_config, get_config_dir
|
||||
from .routes import health_router, media_router, scripts_router
|
||||
from .services import get_media_controller
|
||||
from .services.websocket_manager import ws_manager
|
||||
|
||||
|
||||
def setup_logging():
|
||||
@@ -29,7 +31,16 @@ async def lifespan(app: FastAPI):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Media Server starting on {settings.host}:{settings.port}")
|
||||
logger.info(f"API Token: {settings.api_token[:8]}...")
|
||||
|
||||
# Start WebSocket status monitor
|
||||
controller = get_media_controller()
|
||||
await ws_manager.start_status_monitor(controller.get_status)
|
||||
logger.info("WebSocket status monitor started")
|
||||
|
||||
yield
|
||||
|
||||
# Stop WebSocket status monitor
|
||||
await ws_manager.stop_status_monitor()
|
||||
logger.info("Media Server shutting down")
|
||||
|
||||
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
"""Media control API endpoints."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi import status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ..auth import verify_token, verify_token_or_query
|
||||
from ..config import settings
|
||||
from ..models import MediaStatus, VolumeRequest, SeekRequest
|
||||
from ..services import get_media_controller, get_current_album_art
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/media", tags=["media"])
|
||||
|
||||
@@ -184,3 +191,52 @@ async def get_artwork(_: str = Depends(verify_token_or_query)) -> Response:
|
||||
content_type = "image/webp"
|
||||
|
||||
return Response(content=art_bytes, media_type=content_type)
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="API authentication token"),
|
||||
) -> None:
|
||||
"""WebSocket endpoint for real-time media status updates.
|
||||
|
||||
Authentication is done via query parameter since WebSocket
|
||||
doesn't support custom headers in the browser.
|
||||
|
||||
Messages sent to client:
|
||||
- {"type": "status", "data": {...}} - Initial status on connect
|
||||
- {"type": "status_update", "data": {...}} - Status changes
|
||||
- {"type": "error", "message": "..."} - Error messages
|
||||
|
||||
Client can send:
|
||||
- {"type": "ping"} - Keepalive, server responds with {"type": "pong"}
|
||||
- {"type": "get_status"} - Request current status
|
||||
"""
|
||||
# Verify token
|
||||
if token != settings.api_token:
|
||||
await websocket.close(code=4001, reason="Invalid authentication token")
|
||||
return
|
||||
|
||||
await ws_manager.connect(websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Wait for messages from client (for keepalive/ping)
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "ping":
|
||||
await websocket.send_json({"type": "pong"})
|
||||
elif data.get("type") == "get_status":
|
||||
# Allow manual status request
|
||||
controller = get_media_controller()
|
||||
status_data = await controller.get_status()
|
||||
await websocket.send_json({
|
||||
"type": "status",
|
||||
"data": status_data.model_dump(),
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
await ws_manager.disconnect(websocket)
|
||||
except Exception as e:
|
||||
logger.error("WebSocket error: %s", e)
|
||||
await ws_manager.disconnect(websocket)
|
||||
|
||||
189
media_server/services/websocket_manager.py
Normal file
189
media_server/services/websocket_manager.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""WebSocket connection manager and status broadcaster."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections and broadcasts status updates."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the connection manager."""
|
||||
self._active_connections: set[WebSocket] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._last_status: dict[str, Any] | None = None
|
||||
self._broadcast_task: asyncio.Task | None = None
|
||||
self._poll_interval: float = 0.5 # Internal poll interval for change detection
|
||||
self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback
|
||||
self._last_broadcast_time: float = 0.0
|
||||
self._running: bool = False
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
async with self._lock:
|
||||
self._active_connections.add(websocket)
|
||||
logger.info(
|
||||
"WebSocket client connected. Total: %d", len(self._active_connections)
|
||||
)
|
||||
|
||||
# Send current status immediately upon connection
|
||||
if self._last_status:
|
||||
try:
|
||||
await websocket.send_json({"type": "status", "data": self._last_status})
|
||||
except Exception as e:
|
||||
logger.debug("Failed to send initial status: %s", e)
|
||||
|
||||
async def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
async with self._lock:
|
||||
self._active_connections.discard(websocket)
|
||||
logger.info(
|
||||
"WebSocket client disconnected. Total: %d", len(self._active_connections)
|
||||
)
|
||||
|
||||
async def broadcast(self, message: dict[str, Any]) -> None:
|
||||
"""Broadcast a message to all connected clients."""
|
||||
async with self._lock:
|
||||
connections = list(self._active_connections)
|
||||
|
||||
if not connections:
|
||||
return
|
||||
|
||||
disconnected = []
|
||||
for websocket in connections:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to send to client: %s", e)
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected clients
|
||||
for ws in disconnected:
|
||||
await self.disconnect(ws)
|
||||
|
||||
def status_changed(
|
||||
self, old: dict[str, Any] | None, new: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Detect if media status has meaningfully changed.
|
||||
|
||||
Position is NOT included for normal playback (let HA interpolate).
|
||||
But seeks (large unexpected jumps) are detected.
|
||||
"""
|
||||
if old is None:
|
||||
return True
|
||||
|
||||
# Fields to compare for changes (NO position - let HA interpolate)
|
||||
significant_fields = [
|
||||
"state",
|
||||
"title",
|
||||
"artist",
|
||||
"album",
|
||||
"volume",
|
||||
"muted",
|
||||
"duration",
|
||||
"source",
|
||||
"album_art_url",
|
||||
]
|
||||
|
||||
for field in significant_fields:
|
||||
if old.get(field) != new.get(field):
|
||||
return True
|
||||
|
||||
# Detect seeks - large position jumps that aren't normal playback
|
||||
old_pos = old.get("position") or 0
|
||||
new_pos = new.get("position") or 0
|
||||
pos_diff = new_pos - old_pos
|
||||
|
||||
# During playback, position should increase by ~0.5s (our poll interval)
|
||||
# A seek is when position jumps backwards OR forward by more than expected
|
||||
if new.get("state") == "playing":
|
||||
# Backward seek or forward jump > 3s indicates seek
|
||||
if pos_diff < -1.0 or pos_diff > 3.0:
|
||||
return True
|
||||
else:
|
||||
# When paused, any significant position change is a seek
|
||||
if abs(pos_diff) > 1.0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def start_status_monitor(
|
||||
self,
|
||||
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
"""Start the background status monitoring loop."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._broadcast_task = asyncio.create_task(
|
||||
self._status_monitor_loop(get_status_func)
|
||||
)
|
||||
logger.info("WebSocket status monitor started")
|
||||
|
||||
async def stop_status_monitor(self) -> None:
|
||||
"""Stop the background status monitoring loop."""
|
||||
self._running = False
|
||||
if self._broadcast_task:
|
||||
self._broadcast_task.cancel()
|
||||
try:
|
||||
await self._broadcast_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("WebSocket status monitor stopped")
|
||||
|
||||
async def _status_monitor_loop(
|
||||
self,
|
||||
get_status_func: Callable[[], Coroutine[Any, Any, Any]],
|
||||
) -> None:
|
||||
"""Background loop that polls for status changes and broadcasts."""
|
||||
while self._running:
|
||||
try:
|
||||
# Only poll if we have connected clients
|
||||
async with self._lock:
|
||||
has_clients = len(self._active_connections) > 0
|
||||
|
||||
if has_clients:
|
||||
status = await get_status_func()
|
||||
status_dict = status.model_dump()
|
||||
|
||||
# Only broadcast on actual state changes
|
||||
# Let HA handle position interpolation during playback
|
||||
if self.status_changed(self._last_status, status_dict):
|
||||
self._last_status = status_dict
|
||||
self._last_broadcast_time = time.time()
|
||||
await self.broadcast(
|
||||
{"type": "status_update", "data": status_dict}
|
||||
)
|
||||
logger.debug("Broadcast sent: status change")
|
||||
else:
|
||||
# Update cached status even without broadcast
|
||||
self._last_status = status_dict
|
||||
else:
|
||||
# Still update cache for when clients connect
|
||||
status = await get_status_func()
|
||||
self._last_status = status.model_dump()
|
||||
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in status monitor: %s", e)
|
||||
await asyncio.sleep(self._poll_interval)
|
||||
|
||||
@property
|
||||
def client_count(self) -> int:
|
||||
"""Return the number of connected clients."""
|
||||
return len(self._active_connections)
|
||||
|
||||
|
||||
# Global instance
|
||||
ws_manager = ConnectionManager()
|
||||
Reference in New Issue
Block a user