"""Emby REST API client.""" from __future__ import annotations import logging import re from typing import Any import aiohttp from .const import ( ALLOWED_IMAGE_TYPES, COMMAND_DISPLAY_MESSAGE, COMMAND_MUTE, COMMAND_SET_REPEAT_MODE, COMMAND_SET_VOLUME, COMMAND_UNMUTE, DEFAULT_DEVICE_VERSION, DEFAULT_PORT, DEVICE_NAME, EMBY_ID_PATTERN, ENDPOINT_ARTISTS, ENDPOINT_ITEMS, ENDPOINT_LIBRARY_REFRESH, ENDPOINT_PREFIX_EMBY, ENDPOINT_PREFIX_NONE, ENDPOINT_SESSIONS, ENDPOINT_SYSTEM_INFO, ENDPOINT_USERS, ENDPOINT_USERS_PUBLIC, IMAGE_FETCH_TIMEOUT_SECONDS, PLAY_COMMAND_PLAY_NOW, PLAYBACK_COMMAND_NEXT_TRACK, PLAYBACK_COMMAND_PAUSE, PLAYBACK_COMMAND_PREVIOUS_TRACK, PLAYBACK_COMMAND_SEEK, PLAYBACK_COMMAND_STOP, PLAYBACK_COMMAND_UNPAUSE, ) _LOGGER = logging.getLogger(__name__) _REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15) _IMAGE_TIMEOUT = aiohttp.ClientTimeout(total=IMAGE_FETCH_TIMEOUT_SECONDS) _EMBY_ID_RE = re.compile(EMBY_ID_PATTERN) class EmbyApiError(Exception): """Base exception for Emby API errors.""" class EmbyConnectionError(EmbyApiError): """Exception for connection errors.""" class EmbyAuthenticationError(EmbyApiError): """Exception for authentication errors.""" def _validate_emby_id(value: str, field_name: str = "id") -> str: """Reject IDs that don't look like Emby identifiers.""" if not isinstance(value, str) or not _EMBY_ID_RE.fullmatch(value): raise EmbyApiError(f"Invalid Emby {field_name}: {value!r}") return value class EmbyApiClient: """Emby REST API client. The aiohttp session is owned by the caller (Home Assistant); this class never closes it. """ def __init__( self, host: str, api_key: str, session: aiohttp.ClientSession, device_id: str, port: int = DEFAULT_PORT, ssl: bool = False, verify_ssl: bool = True, client_version: str = DEFAULT_DEVICE_VERSION, ) -> None: """Initialize the Emby API client.""" if not host or not host.strip(): raise ValueError("host must not be empty") if not api_key: raise ValueError("api_key must not be empty") if not device_id: raise ValueError("device_id must not be empty") self._host = host.strip().rstrip("/") self._port = port self._api_key = api_key self._ssl = ssl self._verify_ssl = verify_ssl self._session = session self._device_id = device_id self._client_version = client_version protocol = "https" if ssl else "http" self._base_url = f"{protocol}://{self._host}:{port}" # Discovered at test_connection(); set to "" if server is configured # without the /emby prefix. self._prefix: str = ENDPOINT_PREFIX_EMBY @property def base_url(self) -> str: """Return the base URL (no API prefix).""" return self._base_url @property def prefix(self) -> str: """Return the working API path prefix (e.g. "/emby" or "").""" return self._prefix @property def device_id(self) -> str: """Return the device id used to identify this client to Emby.""" return self._device_id def _get_headers(self, *, content_json: bool = True) -> dict[str, str]: """Get headers for API requests.""" headers = { "X-Emby-Token": self._api_key, "X-Emby-Client": DEVICE_NAME, "X-Emby-Device-Name": DEVICE_NAME, "X-Emby-Device-Id": self._device_id, "X-Emby-Client-Version": self._client_version, "Accept": "application/json", } if content_json: headers["Content-Type"] = "application/json" return headers def _ssl_kwarg(self) -> dict[str, Any]: """Return the ssl kwarg for aiohttp depending on config.""" if not self._ssl: return {} return {"ssl": self._verify_ssl} async def _request( self, method: str, endpoint: str, params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, *, absolute: bool = False, ) -> Any: """Make an API request. ``endpoint`` is expected to begin with "/" (e.g. "/System/Info"). When ``absolute`` is False the discovered prefix ("/emby" by default) is prepended. """ path = endpoint if absolute else f"{self._prefix}{endpoint}" url = f"{self._base_url}{path}" _LOGGER.debug("Making %s request to %s", method, url) try: async with self._session.request( method, url, headers=self._get_headers(), params=params, json=data, timeout=_REQUEST_TIMEOUT, **self._ssl_kwarg(), ) as response: _LOGGER.debug("Response status: %s", response.status) if response.status == 401: raise EmbyAuthenticationError("Invalid API key") if response.status == 403: raise EmbyAuthenticationError("Access forbidden") if response.status >= 400: text = await response.text() _LOGGER.debug("API error %s: %s", response.status, text) raise EmbyApiError(f"API error {response.status}: {text}") if response.status == 204 or response.content_length == 0: return None content_type = response.headers.get("Content-Type", "") if "application/json" in content_type: return await response.json() return await response.text() except aiohttp.ClientError as err: _LOGGER.debug("Connection error to %s: %s", url, err) raise EmbyConnectionError(f"Connection error: {err}") from err except TimeoutError as err: _LOGGER.debug("Timeout connecting to %s", url) raise EmbyConnectionError(f"Connection timeout: {err}") from err async def _get( self, endpoint: str, params: dict[str, Any] | None = None ) -> Any: """Make a GET request.""" return await self._request("GET", endpoint, params=params) async def _post( self, endpoint: str, params: dict[str, Any] | None = None, data: dict[str, Any] | None = None, ) -> Any: """Make a POST request.""" return await self._request("POST", endpoint, params=params, data=data) # ------------------------------------------------------------------------- # Authentication & System # ------------------------------------------------------------------------- async def test_connection(self) -> dict[str, Any]: """Test the connection to the Emby server. Tries both /emby/System/Info and /System/Info and pins the working prefix for subsequent calls. """ last_error: Exception | None = None for prefix in (ENDPOINT_PREFIX_EMBY, ENDPOINT_PREFIX_NONE): url_path = f"{prefix}{ENDPOINT_SYSTEM_INFO}" try: _LOGGER.debug("Probing %s for connectivity", url_path) result = await self._request("GET", url_path, absolute=True) self._prefix = prefix _LOGGER.debug("Using API prefix %r", prefix or "") return result except EmbyAuthenticationError: raise except (EmbyConnectionError, EmbyApiError) as err: last_error = err continue raise EmbyConnectionError( f"Cannot connect to Emby server at {self._base_url}. " f"Last error: {last_error}" ) async def get_server_info(self) -> dict[str, Any]: """Get server information.""" return await self._get(ENDPOINT_SYSTEM_INFO) async def get_users(self) -> list[dict[str, Any]]: """Get list of users. Falls back to the public users endpoint if the API key is not an admin token (HTTP 401/403 on the authenticated endpoint). """ try: return await self._get(ENDPOINT_USERS) except EmbyAuthenticationError: _LOGGER.debug( "API key is not admin, falling back to /Users/Public" ) return await self._get(ENDPOINT_USERS_PUBLIC) # ------------------------------------------------------------------------- # Sessions # ------------------------------------------------------------------------- async def get_sessions(self) -> list[dict[str, Any]]: """Get all active sessions.""" return await self._get(ENDPOINT_SESSIONS) async def get_controllable_sessions( self, user_id: str | None = None ) -> list[dict[str, Any]]: """Get sessions that can be remotely controlled.""" params: dict[str, Any] = {} if user_id: params["ControllableByUserId"] = user_id sessions = await self._get(ENDPOINT_SESSIONS, params=params) return [s for s in sessions if s.get("SupportsRemoteControl")] # ------------------------------------------------------------------------- # Playback Control # ------------------------------------------------------------------------- async def play_media( self, session_id: str, item_ids: list[str], play_command: str = PLAY_COMMAND_PLAY_NOW, start_position_ticks: int = 0, ) -> None: """Send play command to a session.""" if not item_ids: raise EmbyApiError("item_ids are required") _validate_emby_id(session_id, "session_id") for item_id in item_ids: _validate_emby_id(item_id, "item_id") endpoint = f"{ENDPOINT_SESSIONS}/{session_id}/Playing" params: dict[str, Any] = { "ItemIds": ",".join(item_ids), "PlayCommand": play_command, } if start_position_ticks > 0: params["StartPositionTicks"] = start_position_ticks await self._post(endpoint, params=params) async def _playback_command(self, session_id: str, command: str) -> None: """Send a playback command to a session.""" _validate_emby_id(session_id, "session_id") endpoint = f"{ENDPOINT_SESSIONS}/{session_id}/Playing/{command}" await self._post(endpoint) async def play(self, session_id: str) -> None: """Resume playback.""" await self._playback_command(session_id, PLAYBACK_COMMAND_UNPAUSE) async def pause(self, session_id: str) -> None: """Pause playback.""" await self._playback_command(session_id, PLAYBACK_COMMAND_PAUSE) async def stop(self, session_id: str) -> None: """Stop playback.""" await self._playback_command(session_id, PLAYBACK_COMMAND_STOP) async def next_track(self, session_id: str) -> None: """Skip to next track.""" await self._playback_command(session_id, PLAYBACK_COMMAND_NEXT_TRACK) async def previous_track(self, session_id: str) -> None: """Skip to previous track.""" await self._playback_command(session_id, PLAYBACK_COMMAND_PREVIOUS_TRACK) async def seek(self, session_id: str, position_ticks: int) -> None: """Seek to a position.""" _validate_emby_id(session_id, "session_id") endpoint = f"{ENDPOINT_SESSIONS}/{session_id}/Playing/{PLAYBACK_COMMAND_SEEK}" await self._post(endpoint, params={"SeekPositionTicks": position_ticks}) # ------------------------------------------------------------------------- # General Commands # ------------------------------------------------------------------------- async def _send_command( self, session_id: str, command: str, arguments: dict[str, Any] | None = None, ) -> None: """Send a general command to a session. Emby's /Command endpoint accepts arguments as strings; numeric values are stringified here. """ _validate_emby_id(session_id, "session_id") endpoint = f"{ENDPOINT_SESSIONS}/{session_id}/Command" data: dict[str, Any] = {"Name": command} if arguments: data["Arguments"] = {k: str(v) for k, v in arguments.items()} await self._post(endpoint, data=data) async def set_volume(self, session_id: str, volume: int) -> None: """Set volume level (0-100).""" volume = max(0, min(100, int(volume))) await self._send_command( session_id, COMMAND_SET_VOLUME, {"Volume": volume} ) async def mute(self, session_id: str) -> None: """Mute the session.""" await self._send_command(session_id, COMMAND_MUTE) async def unmute(self, session_id: str) -> None: """Unmute the session.""" await self._send_command(session_id, COMMAND_UNMUTE) async def set_repeat_mode(self, session_id: str, mode: str) -> None: """Set the session's repeat mode (RepeatNone/RepeatOne/RepeatAll).""" await self._send_command( session_id, COMMAND_SET_REPEAT_MODE, {"RepeatMode": mode} ) async def display_message( self, session_id: str, text: str, header: str | None = None, timeout_ms: int | None = None, ) -> None: """Display a message on the client device.""" args: dict[str, Any] = {"Text": text} if header is not None: args["Header"] = header if timeout_ms is not None: args["TimeoutMs"] = int(timeout_ms) await self._send_command(session_id, COMMAND_DISPLAY_MESSAGE, args) # ------------------------------------------------------------------------- # Library Browsing # ------------------------------------------------------------------------- async def get_views(self, user_id: str) -> list[dict[str, Any]]: """Get user's library views (top-level folders).""" _validate_emby_id(user_id, "user_id") endpoint = f"{ENDPOINT_USERS}/{user_id}/Views" result = await self._get(endpoint) if not isinstance(result, dict): return [] items = result.get("Items", []) return items if isinstance(items, list) else [] async def get_items( self, user_id: str, parent_id: str | None = None, include_item_types: list[str] | None = None, recursive: bool = False, sort_by: str = "SortName", sort_order: str = "Ascending", start_index: int = 0, limit: int = 100, search_term: str | None = None, fields: list[str] | None = None, ) -> dict[str, Any]: """Get items from the library.""" _validate_emby_id(user_id, "user_id") if parent_id is not None: _validate_emby_id(parent_id, "parent_id") endpoint = f"{ENDPOINT_USERS}/{user_id}/Items" params: dict[str, Any] = { "SortBy": sort_by, "SortOrder": sort_order, "StartIndex": start_index, "Limit": limit, "Recursive": str(recursive).lower(), } if parent_id: params["ParentId"] = parent_id if include_item_types: params["IncludeItemTypes"] = ",".join(include_item_types) if search_term: params["SearchTerm"] = search_term if fields: params["Fields"] = ",".join(fields) return await self._get(endpoint, params=params) async def get_item(self, user_id: str, item_id: str) -> dict[str, Any]: """Get a single item by ID.""" _validate_emby_id(user_id, "user_id") _validate_emby_id(item_id, "item_id") endpoint = f"{ENDPOINT_USERS}/{user_id}/Items/{item_id}" return await self._get(endpoint) async def get_artists( self, user_id: str, parent_id: str | None = None, start_index: int = 0, limit: int = 100, ) -> dict[str, Any]: """Get artists.""" _validate_emby_id(user_id, "user_id") if parent_id is not None: _validate_emby_id(parent_id, "parent_id") params: dict[str, Any] = { "UserId": user_id, "StartIndex": start_index, "Limit": limit, "SortBy": "SortName", "SortOrder": "Ascending", } if parent_id: params["ParentId"] = parent_id return await self._get(ENDPOINT_ARTISTS, params=params) async def refresh_library(self) -> None: """Trigger a server-side library scan.""" await self._post(ENDPOINT_LIBRARY_REFRESH) # ------------------------------------------------------------------------- # Image fetching # ------------------------------------------------------------------------- def get_image_path( self, item_id: str, image_type: str = "Primary", max_width: int | None = None, max_height: int | None = None, ) -> tuple[str, dict[str, str]]: """Build the URL and query params for an item image. Returns ``(url, params)``. The API key is intentionally NOT included; the caller is responsible for sending the X-Emby-Token header. The item_id is validated to prevent path traversal. """ _validate_emby_id(item_id, "item_id") if image_type not in ALLOWED_IMAGE_TYPES: raise EmbyApiError(f"Invalid image_type: {image_type!r}") url = ( f"{self._base_url}{self._prefix}{ENDPOINT_ITEMS}" f"/{item_id}/Images/{image_type}" ) params: dict[str, str] = {} if max_width: params["maxWidth"] = str(int(max_width)) if max_height: params["maxHeight"] = str(int(max_height)) return url, params async def fetch_image( self, item_id: str, image_type: str = "Primary", max_width: int | None = None, max_height: int | None = None, ) -> tuple[bytes, str | None]: """Fetch an image. Returns ``(content, content_type)``. Uses the X-Emby-Token header so the API key is not exposed in URLs. """ url, params = self.get_image_path( item_id, image_type, max_width, max_height ) try: async with self._session.get( url, params=params, headers=self._get_headers(content_json=False), timeout=_IMAGE_TIMEOUT, **self._ssl_kwarg(), ) as response: if response.status >= 400: raise EmbyApiError( f"Image fetch failed with status {response.status}" ) content = await response.read() return content, response.headers.get("Content-Type") except aiohttp.ClientError as err: raise EmbyConnectionError(f"Image fetch error: {err}") from err except TimeoutError as err: raise EmbyConnectionError(f"Image fetch timeout: {err}") from err