"""WebSocket connection manager and status broadcaster.""" import asyncio import json import logging import time from typing import Any, Callable, Coroutine from fastapi import WebSocket logger = logging.getLogger(__name__) class ConnectionManager: """Manages WebSocket connections and broadcasts status updates.""" def __init__(self) -> None: """Initialize the connection manager.""" self._active_connections: set[WebSocket] = set() self._lock = asyncio.Lock() self._last_status: dict[str, Any] | None = None self._broadcast_task: asyncio.Task | None = None self._poll_interval: float = 0.5 # Internal poll interval for change detection self._position_broadcast_interval: float = 5.0 # Send position updates every 5s during playback self._last_broadcast_time: float = 0.0 self._running: bool = False # Audio visualizer self._visualizer_subscribers: set[WebSocket] = set() self._audio_task: asyncio.Task | None = None self._audio_analyzer = None async def connect(self, websocket: WebSocket) -> None: """Accept a new WebSocket connection.""" await websocket.accept() async with self._lock: self._active_connections.add(websocket) logger.info( "WebSocket client connected. Total: %d", len(self._active_connections) ) # Send current status immediately upon connection if self._last_status: try: await websocket.send_json({"type": "status", "data": self._last_status}) except Exception as e: logger.debug("Failed to send initial status: %s", e) async def disconnect(self, websocket: WebSocket) -> None: """Remove a WebSocket connection.""" async with self._lock: self._active_connections.discard(websocket) self._visualizer_subscribers.discard(websocket) logger.info( "WebSocket client disconnected. Total: %d", len(self._active_connections) ) async def broadcast(self, message: dict[str, Any]) -> None: """Broadcast a message to all connected clients concurrently.""" async with self._lock: connections = list(self._active_connections) if not connections: return async def _send(ws: WebSocket) -> WebSocket | None: try: await ws.send_json(message) return None except Exception as e: logger.debug("Failed to send to client: %s", e) return ws results = await asyncio.gather(*(_send(ws) for ws in connections)) # Clean up disconnected clients for ws in results: if ws is not None: await self.disconnect(ws) async def broadcast_scripts_changed(self) -> None: """Notify all connected clients that scripts have changed.""" message = {"type": "scripts_changed", "data": {}} await self.broadcast(message) logger.info("Broadcast sent: scripts_changed") async def broadcast_links_changed(self) -> None: """Notify all connected clients that links have changed.""" message = {"type": "links_changed", "data": {}} await self.broadcast(message) logger.info("Broadcast sent: links_changed") async def subscribe_visualizer(self, websocket: WebSocket) -> None: """Subscribe a client to audio visualizer data.""" async with self._lock: self._visualizer_subscribers.add(websocket) logger.debug("Visualizer subscriber added. Total: %d", len(self._visualizer_subscribers)) async def unsubscribe_visualizer(self, websocket: WebSocket) -> None: """Unsubscribe a client from audio visualizer data.""" async with self._lock: self._visualizer_subscribers.discard(websocket) logger.debug("Visualizer subscriber removed. Total: %d", len(self._visualizer_subscribers)) async def start_audio_monitor(self, analyzer) -> None: """Start audio frequency broadcasting if analyzer is available.""" self._audio_analyzer = analyzer if analyzer and analyzer.running: self._audio_task = asyncio.create_task(self._audio_broadcast_loop()) logger.info("Audio visualizer broadcast started") async def stop_audio_monitor(self) -> None: """Stop audio frequency broadcasting.""" if self._audio_task: self._audio_task.cancel() try: await self._audio_task except asyncio.CancelledError: pass self._audio_task = None async def _audio_broadcast_loop(self) -> None: """Background loop: read frequency data from analyzer and broadcast to subscribers.""" from ..config import settings interval = 1.0 / settings.visualizer_fps _last_data = None while True: try: async with self._lock: subscribers = list(self._visualizer_subscribers) if not subscribers or not self._audio_analyzer or not self._audio_analyzer.running: await asyncio.sleep(interval) continue data = self._audio_analyzer.get_frequency_data() if data is None or data is _last_data: await asyncio.sleep(interval) continue _last_data = data # Pre-serialize once for all subscribers (avoids per-client JSON encoding) text = json.dumps({"type": "audio_data", "data": data}, separators=(',', ':')) async def _send(ws: WebSocket) -> WebSocket | None: try: await ws.send_text(text) return None except Exception: return ws results = await asyncio.gather(*(_send(ws) for ws in subscribers)) failed = [ws for ws in results if ws is not None] if failed: async with self._lock: for ws in failed: self._visualizer_subscribers.discard(ws) for ws in failed: await self.disconnect(ws) await asyncio.sleep(interval) except asyncio.CancelledError: break except Exception as e: logger.error("Error in audio broadcast: %s", e) await asyncio.sleep(interval) def status_changed( self, old: dict[str, Any] | None, new: dict[str, Any] ) -> bool: """Detect if media status has meaningfully changed. Position is NOT included for normal playback (let HA interpolate). But seeks (large unexpected jumps) are detected. """ if old is None: return True # Fields to compare for changes (NO position - let HA interpolate) significant_fields = [ "state", "title", "artist", "album", "volume", "muted", "duration", "source", "album_art_url", ] for field in significant_fields: if old.get(field) != new.get(field): return True # Detect seeks - large position jumps that aren't normal playback old_pos = old.get("position") or 0 new_pos = new.get("position") or 0 pos_diff = new_pos - old_pos # During playback, position should increase by ~0.5s (our poll interval) # A seek is when position jumps backwards OR forward by more than expected if new.get("state") == "playing": # Backward seek or forward jump > 3s indicates seek if pos_diff < -1.0 or pos_diff > 3.0: return True else: # When paused, any significant position change is a seek if abs(pos_diff) > 1.0: return True return False async def start_status_monitor( self, get_status_func: Callable[[], Coroutine[Any, Any, Any]], ) -> None: """Start the background status monitoring loop.""" if self._running: return self._running = True self._broadcast_task = asyncio.create_task( self._status_monitor_loop(get_status_func) ) logger.info("WebSocket status monitor started") async def stop_status_monitor(self) -> None: """Stop the background status monitoring loop.""" self._running = False if self._broadcast_task: self._broadcast_task.cancel() try: await self._broadcast_task except asyncio.CancelledError: pass logger.info("WebSocket status monitor stopped") async def _status_monitor_loop( self, get_status_func: Callable[[], Coroutine[Any, Any, Any]], ) -> None: """Background loop that polls for status changes and broadcasts.""" while self._running: try: # Only poll if we have connected clients async with self._lock: has_clients = len(self._active_connections) > 0 if not has_clients: await asyncio.sleep(self._poll_interval) continue status = await get_status_func() status_dict = status.model_dump() # Only broadcast on actual state changes # Let HA handle position interpolation during playback if self.status_changed(self._last_status, status_dict): self._last_status = status_dict self._last_broadcast_time = time.time() await self.broadcast( {"type": "status_update", "data": status_dict} ) logger.debug("Broadcast sent: status change") else: # Update cached status even without broadcast self._last_status = status_dict await asyncio.sleep(self._poll_interval) except asyncio.CancelledError: break except Exception as e: logger.error("Error in status monitor: %s", e) await asyncio.sleep(self._poll_interval) @property def client_count(self) -> int: """Return the number of connected clients.""" return len(self._active_connections) # Global instance ws_manager = ConnectionManager()