"""WebSocket connection manager and status broadcaster."""

import asyncio
import json
import logging
import time
from typing import Any, Callable, Coroutine

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionManager:
    """Manages WebSocket connections and broadcasts status updates.

    Three kinds of push traffic go through this manager:

    * media status (``status`` on connect, ``status_update`` on change),
      produced by a background polling loop;
    * foreground-process snapshots (``foreground`` on connect,
      ``foreground_update`` on change), polled by the same loop;
    * audio visualizer frames (``audio_data``), sent only to clients that
      subscribed via :meth:`subscribe_visualizer`.
    """

    def __init__(self) -> None:
        """Initialize the connection manager with no clients and no loops."""
        self._active_connections: set[WebSocket] = set()
        self._lock = asyncio.Lock()
        # Last status / foreground snapshots; the monitor loop diffs new polls
        # against these to decide whether to broadcast.
        self._last_status: dict[str, Any] | None = None
        self._last_foreground: dict[str, Any] | None = None
        self._foreground_poll_interval: float = 1.0
        self._last_foreground_poll: float = 0.0
        self._get_status_func: Callable[[], Coroutine[Any, Any, Any]] | None = None
        self._broadcast_task: asyncio.Task | None = None
        self._poll_interval: float = 0.5  # Internal poll interval for change detection
        # NOTE(review): not read anywhere in this module; periodic position
        # broadcasts are not implemented here (clients interpolate instead).
        self._position_broadcast_interval: float = 5.0
        # Written when a status change is broadcast; not read in this module.
        self._last_broadcast_time: float = 0.0
        self._running: bool = False

        # Audio visualizer
        self._visualizer_subscribers: set[WebSocket] = set()
        self._audio_task: asyncio.Task | None = None
        self._audio_analyzer = None

    async def connect(self, websocket: WebSocket, already_accepted: bool = False) -> None:
        """Accept a new WebSocket connection.

        ``already_accepted=True`` is for callers that needed to call
        ``websocket.accept(subprotocol=...)`` themselves (token-via-subprotocol
        auth path).

        After registering the client, the current media status and a fresh
        foreground snapshot are sent to this client only. Both sends are best
        effort: failures are logged at debug level and never raised.
        """
        if not already_accepted:
            await websocket.accept()
        async with self._lock:
            self._active_connections.add(websocket)
        logger.info(
            "WebSocket client connected. Total: %d", len(self._active_connections)
        )

        # Send current status immediately upon connection
        status = self._last_status
        if not status and self._get_status_func:
            try:
                result = await self._get_status_func()
                status = result.model_dump()
                self._last_status = status
            except Exception as e:
                logger.debug("Failed to fetch initial status: %s", e)
        if status:
            try:
                await websocket.send_json({"type": "status", "data": status})
            except Exception as e:
                logger.debug("Failed to send initial status: %s", e)

        # Push a fresh foreground snapshot on connect so the UI can render
        # the tile immediately instead of waiting for the next change.
        try:
            from .foreground_service import get_foreground_info

            fg = await asyncio.to_thread(get_foreground_info)
            fg_dict = fg.to_dict()
            # Only seed the monitor's baseline when it has none. Overwriting an
            # existing baseline would hide a change that happened since the
            # last poll from the monitor loop, and already-connected clients
            # would never receive the corresponding foreground_update.
            if self._last_foreground is None:
                self._last_foreground = fg_dict
            await websocket.send_json({"type": "foreground", "data": fg_dict})
        except Exception as e:
            logger.debug("Failed to send initial foreground snapshot: %s", e)

    async def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket connection. Stops audio capture if last visualizer subscriber.

        Safe to call for a socket that is not (or no longer) registered.
        """
        should_stop = False
        async with self._lock:
            self._active_connections.discard(websocket)
            was_subscriber = websocket in self._visualizer_subscribers
            self._visualizer_subscribers.discard(websocket)
            if was_subscriber and len(self._visualizer_subscribers) == 0:
                should_stop = True
        # Capture is stopped outside the lock: stopping runs in an executor
        # and must not block other connect/disconnect calls meanwhile.
        if should_stop:
            await self._maybe_stop_capture()
        logger.info(
            "WebSocket client disconnected. Total: %d", len(self._active_connections)
        )

    async def broadcast(self, message: dict[str, Any]) -> None:
        """Broadcast a message to all connected clients concurrently.

        The payload is serialized once and pushed via ``send_text`` to every
        client, instead of having Starlette/Pydantic encode it N times via
        ``send_json``. Clients whose send raises are disconnected afterwards.
        Values that ``json`` cannot encode natively are stringified
        (``default=str``).
        """
        async with self._lock:
            connections = list(self._active_connections)
        if not connections:
            return

        try:
            payload = json.dumps(message, default=str)
        except (TypeError, ValueError) as e:
            logger.error("Failed to encode broadcast message: %s", e)
            return

        async def _send(ws: WebSocket) -> WebSocket | None:
            # Returns the socket on failure so the caller can drop it.
            try:
                await ws.send_text(payload)
                return None
            except Exception as e:
                logger.debug("Failed to send to client: %s", e)
                return ws

        results = await asyncio.gather(*(_send(ws) for ws in connections))

        # Clean up disconnected clients
        for ws in results:
            if ws is not None:
                await self.disconnect(ws)

    async def broadcast_scripts_changed(self) -> None:
        """Notify all connected clients that scripts have changed."""
        message = {"type": "scripts_changed", "data": {}}
        await self.broadcast(message)
        logger.info("Broadcast sent: scripts_changed")

    async def broadcast_links_changed(self) -> None:
        """Notify all connected clients that links have changed."""
        message = {"type": "links_changed", "data": {}}
        await self.broadcast(message)
        logger.info("Broadcast sent: links_changed")

    def foreground_changed(
        self, old: dict[str, Any] | None, new: dict[str, Any]
    ) -> bool:
        """Detect a meaningful change in the foreground process snapshot.

        The probe also returns ``window_geometry`` which jitters on every
        pixel of cursor drag — comparing the whole dict would flood clients.
        We only diff the fields a user (or HA automation) would actually act
        on. ``window_geometry``/``monitor_geometry``/``started_at`` are still
        delivered in the payload, but they don't drive broadcast cadence.

        Returns True when there is no previous snapshot.
        """
        if old is None:
            return True
        diff_fields = (
            "pid",
            "process_name",
            "executable_path",
            "window_title",
            "is_fullscreen",
            "is_minimized",
            "monitor_id",
            "available",
            "error",
        )
        return any(old.get(f) != new.get(f) for f in diff_fields)

    async def subscribe_visualizer(self, websocket: WebSocket) -> None:
        """Subscribe a client to audio visualizer data. Starts capture on first subscriber."""
        should_start = False
        async with self._lock:
            self._visualizer_subscribers.add(websocket)
            if len(self._visualizer_subscribers) == 1 and self._audio_analyzer:
                should_start = True
        if should_start:
            await self._maybe_start_capture()
        logger.debug("Visualizer subscriber added. Total: %d", len(self._visualizer_subscribers))

    async def unsubscribe_visualizer(self, websocket: WebSocket) -> None:
        """Unsubscribe a client from audio visualizer data. Stops capture on last subscriber.

        Attempts a stop whenever the subscriber set ends up empty;
        ``_maybe_stop_capture`` is a no-op if capture is not running.
        """
        should_stop = False
        async with self._lock:
            self._visualizer_subscribers.discard(websocket)
            if len(self._visualizer_subscribers) == 0:
                should_stop = True
        if should_stop:
            await self._maybe_stop_capture()
        logger.debug("Visualizer subscriber removed. Total: %d", len(self._visualizer_subscribers))

    async def _maybe_start_capture(self) -> None:
        """Start audio capture if not already running (called on first subscriber)."""
        if self._audio_analyzer and not self._audio_analyzer.running:
            loop = asyncio.get_running_loop()
            # analyzer.start() is synchronous; run it off the event loop.
            started = await loop.run_in_executor(None, self._audio_analyzer.start)
            if started:
                logger.info("Audio capture started (first subscriber)")
            else:
                logger.warning("Audio capture failed to start")

    async def _maybe_stop_capture(self) -> None:
        """Stop audio capture if running (called when last subscriber leaves)."""
        if self._audio_analyzer and self._audio_analyzer.running:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._audio_analyzer.stop)
            logger.info("Audio capture stopped (no subscribers)")

    async def start_audio_monitor(self, analyzer) -> None:
        """Register the audio analyzer. Capture starts on-demand when clients subscribe.

        The broadcast loop is only started when ``analyzer.available`` is
        truthy. Calling this again while a loop is already running just swaps
        the analyzer: the existing loop re-reads ``self._audio_analyzer`` on
        every iteration, so no second loop is spawned.
        """
        self._audio_analyzer = analyzer
        if not (analyzer and analyzer.available):
            return
        if self._audio_task is not None and not self._audio_task.done():
            # A second loop would race the first on the same data_event and
            # could not be cancelled by stop_audio_monitor() (only the latest
            # task reference is kept).
            return
        self._audio_task = asyncio.create_task(self._audio_broadcast_loop())
        logger.info("Audio visualizer broadcast loop started (capture on-demand)")

    async def stop_audio_monitor(self) -> None:
        """Stop audio frequency broadcasting by cancelling the broadcast loop."""
        if self._audio_task:
            self._audio_task.cancel()
            try:
                await self._audio_task
            except asyncio.CancelledError:
                pass
            self._audio_task = None

    async def _audio_broadcast_loop(self) -> None:
        """Background loop: read frequency data from analyzer and broadcast to subscribers.

        Event-driven: blocks on the analyzer's data_event so it wakes up
        exactly once per produced frame, instead of polling on a timer.
        Backstop sleep applies when capture is idle / has no subscribers.
        """
        from ..config import settings

        idle_interval = 1.0 / max(1, settings.visualizer_fps)
        # Bounded wait so we still notice subscribe/unsubscribe transitions.
        wake_timeout = max(0.05, idle_interval)
        loop = asyncio.get_running_loop()
        last_seq = -1

        while True:
            try:
                async with self._lock:
                    subscribers = list(self._visualizer_subscribers)
                    analyzer = self._audio_analyzer

                if not subscribers or not analyzer or not analyzer.running:
                    await asyncio.sleep(idle_interval)
                    continue

                # Wait off-loop for a fresh frame. The capture thread sets
                # data_event after each FFT update; we clear it before the
                # next wait so we never burn a wake on stale data.
                ev = analyzer.data_event
                got = await loop.run_in_executor(None, ev.wait, wake_timeout)
                if not got:
                    # Timeout — loop around to re-check subscriber state.
                    continue
                ev.clear()

                data, seq = analyzer.get_frequency_data_versioned()
                # Skip empty data and frames we have already sent.
                if data is None or seq == last_seq:
                    continue
                last_seq = seq

                # Pre-serialize once for all subscribers (avoids per-client JSON encoding)
                text = json.dumps({"type": "audio_data", "data": data}, separators=(',', ':'))

                async def _send(ws: WebSocket) -> WebSocket | None:
                    # Returns the socket on failure so it can be disconnected.
                    try:
                        await ws.send_text(text)
                        return None
                    except Exception:
                        return ws

                results = await asyncio.gather(*(_send(ws) for ws in subscribers))
                failed = [ws for ws in results if ws is not None]
                for ws in failed:
                    await self.disconnect(ws)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in audio broadcast: %s", e)
                await asyncio.sleep(idle_interval)

    def status_changed(
        self, old: dict[str, Any] | None, new: dict[str, Any]
    ) -> bool:
        """Detect if media status has meaningfully changed.

        Position is NOT included for normal playback (let HA interpolate).
        But seeks (large unexpected jumps) are detected:

        * while ``new["state"] == "playing"``: a jump backwards by more than
          1s or forwards by more than 3s counts as a change;
        * otherwise: any position difference larger than 1s counts.

        Returns True when there is no previous snapshot.
        """
        if old is None:
            return True

        # Fields to compare for changes (NO position - let HA interpolate)
        significant_fields = [
            "state",
            "title",
            "artist",
            "album",
            "volume",
            "muted",
            "duration",
            "source",
            "album_art_url",
        ]
        for field in significant_fields:
            if old.get(field) != new.get(field):
                return True

        # Detect seeks - large position jumps that aren't normal playback
        old_pos = old.get("position") or 0
        new_pos = new.get("position") or 0
        pos_diff = new_pos - old_pos

        # During playback, position should increase by ~0.5s (our poll interval)
        # A seek is when position jumps backwards OR forward by more than expected
        if new.get("state") == "playing":
            # Backward seek or forward jump > 3s indicates seek
            if pos_diff < -1.0 or pos_diff > 3.0:
                return True
        else:
            # When paused, any significant position change is a seek
            if abs(pos_diff) > 1.0:
                return True

        return False

    async def start_status_monitor(
        self,
        get_status_func: Callable[[], Coroutine[Any, Any, Any]],
    ) -> None:
        """Start the background status monitoring loop.

        ``get_status_func`` is awaited on each poll; its result must provide
        ``model_dump()``. A no-op if the monitor is already running.
        """
        if self._running:
            return

        self._get_status_func = get_status_func
        self._running = True
        self._broadcast_task = asyncio.create_task(
            self._status_monitor_loop(get_status_func)
        )
        logger.info("WebSocket status monitor started")

    async def stop_status_monitor(self) -> None:
        """Stop the background status monitoring loop and drop its task reference."""
        self._running = False
        if self._broadcast_task:
            self._broadcast_task.cancel()
            try:
                await self._broadcast_task
            except asyncio.CancelledError:
                pass
            # Match stop_audio_monitor(): don't keep a reference to a dead task.
            self._broadcast_task = None
        logger.info("WebSocket status monitor stopped")

    async def _status_monitor_loop(
        self,
        get_status_func: Callable[[], Coroutine[Any, Any, Any]],
    ) -> None:
        """Background loop that polls for status changes and broadcasts.

        Media status is polled every ``_poll_interval`` seconds and the
        foreground process roughly every ``_foreground_poll_interval``
        seconds. Broadcasts are sent only when the respective ``*_changed``
        check reports a change. While no clients are connected the loop
        sleeps 2s between checks and polls nothing.
        """
        # Foreground tracker is imported lazily so unit tests of the WS
        # manager don't drag in platform-specific probe code.
        from .foreground_service import get_foreground_info

        while self._running:
            try:
                # Only poll if we have connected clients
                async with self._lock:
                    has_clients = len(self._active_connections) > 0

                if not has_clients:
                    await asyncio.sleep(2.0)  # Sleep longer when no clients connected
                    continue

                status = await get_status_func()
                status_dict = status.model_dump()

                # Only broadcast on actual state changes
                # Let HA handle position interpolation during playback
                if self.status_changed(self._last_status, status_dict):
                    self._last_status = status_dict
                    self._last_broadcast_time = time.time()
                    await self.broadcast(
                        {"type": "status_update", "data": status_dict}
                    )
                    logger.debug("Broadcast sent: status change")
                else:
                    # Update cached status even without broadcast, so seek
                    # detection compares consecutive polls.
                    self._last_status = status_dict

                # Foreground process — poll at a coarser interval than media
                # status. Broadcasts only fire on a real change, so a quiet
                # desktop costs nothing.
                now = time.time()
                if (now - self._last_foreground_poll) >= self._foreground_poll_interval:
                    self._last_foreground_poll = now
                    try:
                        fg = await asyncio.to_thread(get_foreground_info)
                        fg_dict = fg.to_dict()
                        if self.foreground_changed(self._last_foreground, fg_dict):
                            self._last_foreground = fg_dict
                            await self.broadcast(
                                {"type": "foreground_update", "data": fg_dict}
                            )
                            logger.debug("Broadcast sent: foreground change")
                        else:
                            self._last_foreground = fg_dict
                    except Exception as e:
                        logger.debug("Foreground poll failed: %s", e)

                await asyncio.sleep(self._poll_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Error in status monitor: %s", e)
                await asyncio.sleep(self._poll_interval)

    @property
    def client_count(self) -> int:
        """Return the number of connected clients."""
        return len(self._active_connections)


# Global instance
ws_manager = ConnectionManager()