"""WebSocket connection manager for Key Colors target color streams.""" from __future__ import annotations import asyncio import contextlib import json import logging from collections.abc import Callable from typing import Any import aiohttp from homeassistant.core import HomeAssistant from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import WS_RECONNECT_DELAY, WS_MAX_RECONNECT_DELAY _LOGGER = logging.getLogger(__name__) class KeyColorsWebSocketManager: """Manages WebSocket connections for Key Colors target color streams.""" def __init__( self, hass: HomeAssistant, server_url: str, api_key: str, ) -> None: self._hass = hass self._server_url = server_url self._api_key = api_key self._connections: dict[str, asyncio.Task] = {} self._callbacks: dict[str, list[Callable]] = {} self._latest_colors: dict[str, dict[str, dict[str, int]]] = {} self._shutting_down = False def _get_ws_url(self, target_id: str) -> str: """Build WebSocket URL for a target.""" ws_base = self._server_url.replace("http://", "ws://").replace( "https://", "wss://" ) return f"{ws_base}/api/v1/output-targets/{target_id}/ws?token={self._api_key}" async def start_listening(self, target_id: str) -> None: """Start WebSocket connection for a target.""" if target_id in self._connections: return task = self._hass.async_create_background_task( self._ws_loop(target_id), f"wled_screen_controller_ws_{target_id}", ) self._connections[target_id] = task async def stop_listening(self, target_id: str) -> None: """Stop WebSocket connection for a target.""" task = self._connections.pop(target_id, None) if task: task.cancel() with contextlib.suppress(asyncio.CancelledError): await task self._latest_colors.pop(target_id, None) def register_callback( self, target_id: str, callback: Callable ) -> Callable[[], None]: """Register a callback for color updates. Returns unregister function.""" self._callbacks.setdefault(target_id, []).append(callback) def unregister() -> None: cbs = self._callbacks.get(target_id) if cbs and callback in cbs: cbs.remove(callback) return unregister def get_latest_colors(self, target_id: str) -> dict[str, dict[str, int]]: """Get latest colors for a target.""" return self._latest_colors.get(target_id, {}) async def _ws_loop(self, target_id: str) -> None: """WebSocket connection loop with reconnection.""" delay = WS_RECONNECT_DELAY session = async_get_clientsession(self._hass) while not self._shutting_down: try: url = self._get_ws_url(target_id) async with session.ws_connect(url) as ws: delay = WS_RECONNECT_DELAY # reset on successful connect _LOGGER.debug("WS connected for target %s", target_id) async for msg in ws: if msg.type == aiohttp.WSMsgType.TEXT: self._handle_message(target_id, msg.data) elif msg.type in ( aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR, ): break except asyncio.CancelledError: raise except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as err: _LOGGER.debug("WS connection error for %s: %s", target_id, err) except Exception as err: _LOGGER.error("Unexpected WS error for %s: %s", target_id, err) if self._shutting_down: break await asyncio.sleep(delay) delay = min(delay * 2, WS_MAX_RECONNECT_DELAY) def _handle_message(self, target_id: str, raw: str) -> None: """Handle incoming WebSocket message.""" try: data = json.loads(raw) except json.JSONDecodeError: return if data.get("type") != "colors_update": return colors: dict[str, Any] = data.get("colors", {}) self._latest_colors[target_id] = colors for cb in self._callbacks.get(target_id, []): try: cb(colors) except Exception: _LOGGER.exception("Error in WS color callback for %s", target_id) async def shutdown(self) -> None: """Stop all WebSocket connections.""" self._shutting_down = True for target_id in list(self._connections): await self.stop_listening(target_id)