"""SQLite database connection wrapper.

Provides a thread-safe, WAL-mode SQLite connection shared by all stores.
Each entity table uses the same schema: indexed columns for common queries
plus a JSON blob for the full entity data.
"""

import json
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, List, Tuple

from wled_controller.utils import get_logger

logger = get_logger(__name__)

# When True, all database writes are suppressed. Set by the restore flow
# to prevent the old server process from overwriting freshly-restored data
# with stale in-memory state before the restart completes.
# NOTE: only the entity/settings helpers (upsert, delete_row, delete_all,
# bulk_insert, set_setting) consult this flag; the raw execute(),
# execute_many() and transaction() entry points do not.
_writes_frozen = False


def freeze_writes() -> None:
    """Block all database writes until the process exits (used after restore)."""
    global _writes_frozen
    _writes_frozen = True
    logger.info("Database writes frozen - awaiting server restart")


def is_writes_frozen() -> bool:
    """Check whether writes are currently frozen."""
    return _writes_frozen


# Schema version — bump when tables change
_SCHEMA_VERSION = 1

# All entity tables share this structure
_ENTITY_TABLES = [
    "devices",
    "capture_templates",
    "postprocessing_templates",
    "picture_sources",
    "output_targets",
    "pattern_templates",
    "color_strip_sources",
    "audio_sources",
    "audio_templates",
    "value_sources",
    "automations",
    "scene_presets",
    "sync_clocks",
    "color_strip_processing_templates",
    "gradients",
    "weather_sources",
    "assets",
]

# Table names are interpolated into SQL (they cannot be bound as parameters),
# so every helper validates against this whitelist first.
_VALID_TABLES = frozenset(_ENTITY_TABLES) | {"settings", "schema_version"}


def _check_table(table: str) -> None:
    """Raise ValueError if *table* is not a known entity table."""
    if table not in _VALID_TABLES:
        raise ValueError(f"Invalid table name: {table!r}")


class Database:
    """Thread-safe SQLite connection wrapper with WAL mode.

    All stores share a single Database instance. The connection uses
    WAL journaling for concurrent read access and a single writer lock.

    Every write helper runs inside :meth:`transaction`, so a statement that
    fails is rolled back immediately instead of leaving a half-applied
    implicit transaction open on the shared connection (where a later,
    unrelated commit would otherwise persist it).
    """

    def __init__(self, db_path: str | Path):
        """Open (creating if needed) the database at *db_path*.

        Creates the parent directory, enables WAL journaling and a 5 s busy
        timeout, and ensures all tables exist.

        Raises:
            sqlite3.Error: if the file cannot be opened or the schema cannot
                be created. The connection is closed before re-raising.
        """
        self._path = Path(db_path)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(
            str(self._path),
            check_same_thread=False,
        )
        self._lock = threading.RLock()
        try:
            self._conn.row_factory = sqlite3.Row
            # PRAGMA journal_mode reports the mode actually in effect; SQLite
            # keeps the previous mode (without raising) if WAL can't be used.
            mode = self._conn.execute("PRAGMA journal_mode=WAL").fetchone()[0]
            if str(mode).lower() != "wal":
                logger.warning(
                    "WAL journal mode not enabled for %s (journal_mode=%s)",
                    self._path,
                    mode,
                )
            self._conn.execute("PRAGMA busy_timeout=5000")
            self._ensure_schema()
        except BaseException:
            # Don't leak the open connection when initialisation fails.
            self._conn.close()
            raise
        logger.info("Database opened: %s", self._path)

    # -- Schema management ---------------------------------------------------

    def _ensure_schema(self) -> None:
        """Create the bookkeeping and entity tables if they don't exist.

        Also records ``_SCHEMA_VERSION`` in ``schema_version`` the first time
        it is seen. Every statement is ``IF NOT EXISTS`` / ``OR IGNORE``, so
        this is safe to run on every startup.
        """
        with self.transaction() as conn:
            # Schema version tracking
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS schema_version (
                    version INTEGER PRIMARY KEY,
                    applied_at TEXT NOT NULL
                )
                """
            )

            # Key-value settings table
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS settings (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
                """
            )

            # Create entity tables
            for table in _ENTITY_TABLES:
                conn.execute(
                    f"""
                    CREATE TABLE IF NOT EXISTS [{table}] (
                        id TEXT PRIMARY KEY,
                        name TEXT NOT NULL DEFAULT '',
                        data TEXT NOT NULL
                    )
                    """
                )
                conn.execute(
                    f"CREATE INDEX IF NOT EXISTS idx_{table}_name ON [{table}](name)"
                )

            # Record schema version
            existing = conn.execute(
                "SELECT version FROM schema_version WHERE version = ?",
                (_SCHEMA_VERSION,),
            ).fetchone()
            if not existing:
                conn.execute(
                    "INSERT OR IGNORE INTO schema_version (version, applied_at) VALUES (?, ?)",
                    (_SCHEMA_VERSION, datetime.now(timezone.utc).isoformat()),
                )

    # -- Low-level operations ------------------------------------------------

    def execute(self, sql: str, params: Tuple = ()) -> sqlite3.Cursor:
        """Execute a single SQL statement (auto-commits).

        The statement is rolled back if it (or the commit) fails.
        Not gated by :func:`freeze_writes`.

        Returns:
            The cursor produced by the statement.
        """
        with self.transaction() as conn:
            return conn.execute(sql, params)

    def execute_many(self, sql: str, params_list: List[Tuple]) -> None:
        """Execute a parameterized statement for each params tuple.

        All rows are committed together; if any row fails, the rows already
        applied are rolled back. Not gated by :func:`freeze_writes`.
        """
        with self.transaction() as conn:
            conn.executemany(sql, params_list)

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Connection]:
        """Context manager for multi-statement transactions.

        Holds the connection lock for the whole block. Not gated by
        :func:`freeze_writes`.

        Usage::

            with db.transaction() as conn:
                conn.execute("INSERT ...", (...))
                conn.execute("DELETE ...", (...))
            # auto-committed on exit, rolled back on exception
        """
        with self._lock:
            try:
                yield self._conn
                self._conn.commit()
            except BaseException:
                # BaseException (not just Exception) so that cancellation or
                # KeyboardInterrupt inside the block also rolls back instead of
                # leaving the shared connection mid-transaction.
                self._conn.rollback()
                raise

    # -- Entity helpers (used by BaseSqliteStore) ----------------------------

    def load_all(self, table: str) -> List[Dict[str, Any]]:
        """Load all rows from an entity table.

        Returns list of dicts parsed from the ``data`` JSON column. Rows whose
        JSON cannot be parsed are logged and skipped.

        Raises:
            ValueError: if *table* is not a known table.
        """
        _check_table(table)
        with self._lock:
            rows = self._conn.execute(f"SELECT id, data FROM [{table}]").fetchall()
        result = []
        for row in rows:
            try:
                result.append(json.loads(row["data"]))
            except json.JSONDecodeError as e:
                logger.error("Corrupt JSON in %s/%s: %s", table, row["id"], e)
        return result

    def upsert(self, table: str, item_id: str, name: str, data: dict) -> None:
        """Insert or replace a single entity row.

        Skipped silently when writes are frozen.

        Raises:
            ValueError: if *table* is not a known table.
        """
        _check_table(table)
        if _writes_frozen:
            return
        json_data = json.dumps(data, ensure_ascii=False)
        with self.transaction() as conn:
            conn.execute(
                f"INSERT OR REPLACE INTO [{table}] (id, name, data) VALUES (?, ?, ?)",
                (item_id, name, json_data),
            )

    def delete_row(self, table: str, item_id: str) -> None:
        """Delete a single entity row.

        Skipped silently when writes are frozen.

        Raises:
            ValueError: if *table* is not a known table.
        """
        _check_table(table)
        if _writes_frozen:
            return
        with self.transaction() as conn:
            conn.execute(f"DELETE FROM [{table}] WHERE id = ?", (item_id,))

    def delete_all(self, table: str) -> None:
        """Delete all rows from an entity table.

        Skipped silently when writes are frozen.

        Raises:
            ValueError: if *table* is not a known table.
        """
        _check_table(table)
        if _writes_frozen:
            return
        with self.transaction() as conn:
            conn.execute(f"DELETE FROM [{table}]")

    def bulk_insert(self, table: str, items: List[Tuple[str, str, str]]) -> None:
        """Bulk insert rows: list of (id, name, data_json) tuples.

        All-or-nothing: if any row fails, the rows already inserted by this
        call are rolled back rather than left pending on the connection.
        Skipped silently when writes are frozen.

        Raises:
            ValueError: if *table* is not a known table.
        """
        _check_table(table)
        if _writes_frozen:
            return
        with self.transaction() as conn:
            conn.executemany(
                f"INSERT OR REPLACE INTO [{table}] (id, name, data) VALUES (?, ?, ?)",
                items,
            )

    def count(self, table: str) -> int:
        """Count rows in an entity table."""
        _check_table(table)
        with self._lock:
            row = self._conn.execute(
                f"SELECT COUNT(*) as cnt FROM [{table}]"
            ).fetchone()
        return row["cnt"]

    def table_exists_with_data(self, table: str) -> bool:
        """Check if a table exists and has at least one row."""
        _check_table(table)
        with self._lock:
            try:
                row = self._conn.execute(
                    f"SELECT COUNT(*) as cnt FROM [{table}]"
                ).fetchone()
                return row["cnt"] > 0
            except sqlite3.OperationalError as e:
                logger.debug("Table %s does not exist or is inaccessible: %s", table, e)
                return False

    # -- Settings (key-value) ------------------------------------------------

    def get_setting(self, key: str) -> dict | None:
        """Read a setting by key. Returns parsed JSON dict, or None if not found.

        A stored value that is not valid JSON is logged and treated as missing.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT value FROM settings WHERE key = ?", (key,)
            ).fetchone()
        if row is None:
            return None
        try:
            return json.loads(row["value"])
        except json.JSONDecodeError as e:
            logger.warning("Corrupt JSON in setting '%s': %s", key, e)
            return None

    def set_setting(self, key: str, value: dict) -> None:
        """Write a setting (upsert).

        Skipped when writes are frozen."""
        if _writes_frozen:
            return
        json_value = json.dumps(value, ensure_ascii=False)
        with self.transaction() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
                (key, json_value),
            )

    # -- Backup --------------------------------------------------------------

    def backup_to(self, dest_path: str | Path) -> None:
        """Create a consistent snapshot of the database using SQLite's backup API.

        Safe to call while the database is in use — SQLite handles locking.
        """
        dest_path = Path(dest_path)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        with self._lock:
            dest = sqlite3.connect(str(dest_path))
            try:
                self._conn.backup(dest)
            finally:
                dest.close()

    def restore_from(self, src_path: str | Path) -> None:
        """Replace the database contents from a backup file.

        The caller must restart the server after calling this — in-memory
        caches in stores will be stale.

        Raises:
            FileNotFoundError: if *src_path* does not exist.
        """
        src_path = Path(src_path)
        if not src_path.exists():
            raise FileNotFoundError(f"Backup file not found: {src_path}")
        with self._lock:
            src = sqlite3.connect(str(src_path))
            try:
                src.backup(self._conn)
            finally:
                src.close()

    # -- Lifecycle -----------------------------------------------------------

    def close(self) -> None:
        """Close the database connection."""
        with self._lock:
            self._conn.close()
        logger.info("Database connection closed")