Files
ledgrab/server/tests/storage/test_activity_log_repository.py
T
alexei.dolgolyov 1ac4a0f66d feat(activity-log): phase 1 - storage model, migration, repository
- ActivityLogEntry dataclass + ActivityCategory/ActivitySeverity + ActivityLogFilters
- additive idempotent migration 002_add_activity_log (indexed activity_log table, seq keyset tiebreaker)
- ActivityLogRepository (record/query/count/prune/clear/iter_export), keyset pagination, parameterized SQL
- 102 unit + adversarial tests (SQL-injection, pagination, prune, codec, migration idempotency)
2026-06-09 17:40:37 +03:00

610 lines
23 KiB
Python

"""Unit tests for ActivityLogRepository (Phase 1 — storage layer).
Coverage
--------
* round-trip: record + read back, including metadata JSON and UTC ts
* filter by each dimension: category, severity, actor, entity_type/id, date range, message_like
* keyset pagination stability with same-ts rows (seq tiebreaker)
* prune by age (before_ts) and by max_entries
* clear; count (filtered + unfiltered); export iterator
* migration idempotency: constructing the repo twice does not re-run the migration
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
import pytest
from ledgrab.storage.activity_log import (
ActivityCategory,
ActivityLogEntry,
ActivityLogFilters,
ActivitySeverity,
)
from ledgrab.storage.activity_log_repository import ActivityLogRepository
from ledgrab.storage.database import Database
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _now() -> datetime:
return datetime.now(timezone.utc)
_SENTINEL = object()  # marks "entity_id kwarg omitted", as opposed to an explicit ``entity_id=None``


def _entry(
    *,
    id: str | None = None,
    ts: datetime | None = None,
    category: str = ActivityCategory.ENTITY,
    action: str = "entity.created",
    severity: str = ActivitySeverity.INFO,
    actor: str = "test_actor",
    entity_type: str | None = "output_target",
    entity_id: object = _SENTINEL,
    entity_name: str | None = "My Target",
    message: str = "Created output target",
    metadata: dict | None = None,
) -> ActivityLogEntry:
    """Construct an ``ActivityLogEntry`` populated with test-friendly defaults.

    All arguments are keyword-only.

    * ``entity_id`` omitted: a fresh random ``ot_…`` id is generated.
    * ``entity_id=None`` passed explicitly: ``None`` is kept on the entry.
    * falsy ``id``: replaced by a random ``al_…`` id.
    * falsy ``ts``: replaced by the current UTC time.
    * ``metadata=None``: replaced by a new empty dict.
    """
    if entity_id is _SENTINEL:
        resolved_entity_id: str | None = f"ot_{uuid.uuid4().hex[:8]}"
    else:
        resolved_entity_id = entity_id  # type: ignore[assignment]

    entry_id = id if id else f"al_{uuid.uuid4().hex[:8]}"
    stamp = ts if ts else _now()
    payload = {} if metadata is None else metadata

    return ActivityLogEntry(
        id=entry_id,
        ts=stamp,
        category=category,
        action=action,
        severity=severity,
        actor=actor,
        entity_type=entity_type,
        entity_id=resolved_entity_id,
        entity_name=entity_name,
        message=message,
        metadata=payload,
    )
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def repo(tmp_db: Database) -> ActivityLogRepository:
    """Provide a new ``ActivityLogRepository`` built on the ``tmp_db`` fixture."""
    repository = ActivityLogRepository(tmp_db)
    return repository
# ---------------------------------------------------------------------------
# Round-trip
# ---------------------------------------------------------------------------
class TestRoundTrip:
    """Entries written with ``record`` come back unchanged from ``query``."""

    def test_record_and_read_back(self, repo: ActivityLogRepository) -> None:
        """Every scalar field of a recorded entry is returned as written."""
        e = _entry(message="Hello world", metadata={"key": "value", "n": 42})
        repo.record(e)
        page = repo.query(ActivityLogFilters(), limit=10)
        assert len(page) == 1
        got = page[0]
        assert got.id == e.id
        assert got.category == e.category
        assert got.action == e.action
        assert got.severity == e.severity
        assert got.actor == e.actor
        assert got.entity_type == e.entity_type
        assert got.entity_id == e.entity_id
        assert got.entity_name == e.entity_name
        assert got.message == e.message

    def test_metadata_json_round_trip(self, repo: ActivityLogRepository) -> None:
        """Nested metadata (str, int, bool, dict values) survives storage."""
        meta = {"device": "wled_01", "led_count": 150, "nested": {"x": True}}
        e = _entry(metadata=meta)
        repo.record(e)
        got = repo.query(ActivityLogFilters(), limit=1)[0]
        assert got.metadata == meta

    def test_utc_timestamp_preserved(self, repo: ActivityLogRepository) -> None:
        """The stored timestamp denotes the same instant that was recorded."""
        ts = datetime(2026, 1, 15, 12, 30, 45, tzinfo=timezone.utc)
        e = _entry(ts=ts)
        repo.record(e)
        got = repo.query(ActivityLogFilters(), limit=1)[0]
        # A naive result is read as UTC. An aware result is compared as-is so
        # its own offset is honoured: overwriting tzinfo with replace() would
        # hide a shifted instant and would reject an equivalent non-UTC
        # rendering of the same moment.
        got_ts = got.ts if got.ts.tzinfo is not None else got.ts.replace(tzinfo=timezone.utc)
        assert got_ts == ts

    def test_none_optional_fields_preserved(self, repo: ActivityLogRepository) -> None:
        """``None`` in the optional entity fields is returned as ``None``."""
        e = _entry(entity_type=None, entity_id=None, entity_name=None)
        repo.record(e)
        got = repo.query(ActivityLogFilters(), limit=1)[0]
        assert got.entity_type is None
        assert got.entity_id is None
        assert got.entity_name is None

    def test_empty_metadata_default(self, repo: ActivityLogRepository) -> None:
        """An empty metadata dict is returned as an empty dict."""
        e = _entry(metadata={})
        repo.record(e)
        got = repo.query(ActivityLogFilters(), limit=1)[0]
        assert got.metadata == {}
# ---------------------------------------------------------------------------
# Filters
# ---------------------------------------------------------------------------
class TestFilters:
    """Each ``ActivityLogFilters`` dimension narrows ``query`` results as expected."""

    def test_filter_by_category(self, repo: ActivityLogRepository) -> None:
        """A single category selects only rows of that category."""
        repo.record(_entry(category=ActivityCategory.AUTH, action="auth.rejected"))
        repo.record(_entry(category=ActivityCategory.DEVICE, action="device.connected"))
        repo.record(_entry(category=ActivityCategory.ENTITY, action="entity.deleted"))
        results = repo.query(ActivityLogFilters(categories=[ActivityCategory.AUTH]), limit=10)
        assert len(results) == 1
        assert results[0].category == ActivityCategory.AUTH

    def test_filter_by_multiple_categories(self, repo: ActivityLogRepository) -> None:
        """Several categories select the union of their rows."""
        repo.record(_entry(category=ActivityCategory.AUTH))
        repo.record(_entry(category=ActivityCategory.DEVICE))
        repo.record(_entry(category=ActivityCategory.SYSTEM))
        results = repo.query(
            ActivityLogFilters(categories=[ActivityCategory.AUTH, ActivityCategory.DEVICE]),
            limit=10,
        )
        assert len(results) == 2
        cats = {r.category for r in results}
        assert cats == {ActivityCategory.AUTH, ActivityCategory.DEVICE}

    def test_filter_by_severity(self, repo: ActivityLogRepository) -> None:
        """A single severity selects only rows of that severity."""
        repo.record(_entry(severity=ActivitySeverity.INFO))
        repo.record(_entry(severity=ActivitySeverity.WARNING))
        repo.record(_entry(severity=ActivitySeverity.ERROR))
        results = repo.query(ActivityLogFilters(severities=[ActivitySeverity.ERROR]), limit=10)
        assert len(results) == 1
        assert results[0].severity == ActivitySeverity.ERROR

    def test_filter_by_multiple_severities(self, repo: ActivityLogRepository) -> None:
        """Several severities select the union of their rows."""
        repo.record(_entry(severity=ActivitySeverity.INFO))
        repo.record(_entry(severity=ActivitySeverity.WARNING))
        repo.record(_entry(severity=ActivitySeverity.ERROR))
        results = repo.query(
            ActivityLogFilters(severities=[ActivitySeverity.WARNING, ActivitySeverity.ERROR]),
            limit=10,
        )
        assert len(results) == 2
        # Check which rows came back, not only how many.
        sevs = {r.severity for r in results}
        assert sevs == {ActivitySeverity.WARNING, ActivitySeverity.ERROR}

    def test_filter_by_actor(self, repo: ActivityLogRepository) -> None:
        """The actor filter returns only that actor's rows."""
        repo.record(_entry(actor="alice"))
        repo.record(_entry(actor="bob"))
        repo.record(_entry(actor="alice"))
        results = repo.query(ActivityLogFilters(actor="alice"), limit=10)
        assert len(results) == 2
        assert all(r.actor == "alice" for r in results)

    def test_filter_by_entity_type(self, repo: ActivityLogRepository) -> None:
        """The entity_type filter returns only rows of that entity type."""
        repo.record(_entry(entity_type="output_target"))
        repo.record(_entry(entity_type="device"))
        repo.record(_entry(entity_type="output_target"))
        results = repo.query(ActivityLogFilters(entity_type="output_target"), limit=10)
        assert len(results) == 2
        assert all(r.entity_type == "output_target" for r in results)

    def test_filter_by_entity_id(self, repo: ActivityLogRepository) -> None:
        """The entity_id filter returns only rows for that entity."""
        repo.record(_entry(entity_id="ot_aabbccdd"))
        repo.record(_entry(entity_id="ot_11223344"))
        repo.record(_entry(entity_id="ot_aabbccdd"))
        results = repo.query(ActivityLogFilters(entity_id="ot_aabbccdd"), limit=10)
        assert len(results) == 2
        assert all(r.entity_id == "ot_aabbccdd" for r in results)

    def test_filter_by_since(self, repo: ActivityLogRepository) -> None:
        """``since`` keeps the boundary row and everything after it."""
        base = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
        repo.record(_entry(ts=base - timedelta(hours=2), message="old"))
        repo.record(_entry(ts=base, message="boundary"))
        repo.record(_entry(ts=base + timedelta(hours=1), message="new"))
        results = repo.query(ActivityLogFilters(since=base), limit=10)
        assert len(results) == 2
        messages = {r.message for r in results}
        assert messages == {"boundary", "new"}

    def test_filter_by_until(self, repo: ActivityLogRepository) -> None:
        """``until`` keeps the boundary row and everything before it."""
        base = datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
        repo.record(_entry(ts=base - timedelta(hours=1), message="old"))
        repo.record(_entry(ts=base, message="boundary"))
        repo.record(_entry(ts=base + timedelta(hours=2), message="new"))
        results = repo.query(ActivityLogFilters(until=base), limit=10)
        assert len(results) == 2
        messages = {r.message for r in results}
        assert messages == {"old", "boundary"}

    def test_filter_message_like_substring(self, repo: ActivityLogRepository) -> None:
        """``message_like`` matches a substring anywhere in the message."""
        repo.record(_entry(message="Created output target MyStrip"))
        repo.record(_entry(message="Deleted device sensor-01"))
        repo.record(_entry(message="Updated output target MyStrip"))
        results = repo.query(ActivityLogFilters(message_like="output target"), limit=10)
        assert len(results) == 2

    def test_filter_message_like_escapes_percent(self, repo: ActivityLogRepository) -> None:
        """A literal % in message_like should not act as a SQL wildcard."""
        repo.record(_entry(message="100% done"))
        repo.record(_entry(message="partial done"))
        # Decoy: contains "100" but no literal "%". If the "%" in the filter
        # were passed through as a wildcard this row would match as well;
        # without it the test cannot tell escaping from no escaping.
        repo.record(_entry(message="100 items done"))
        results = repo.query(ActivityLogFilters(message_like="100%"), limit=10)
        assert len(results) == 1
        assert results[0].message == "100% done"

    def test_combined_filters(self, repo: ActivityLogRepository) -> None:
        """Filters on different dimensions are combined with AND."""
        repo.record(
            _entry(
                actor="alice",
                category=ActivityCategory.ENTITY,
                severity=ActivitySeverity.INFO,
            )
        )
        repo.record(
            _entry(
                actor="alice",
                category=ActivityCategory.AUTH,
                severity=ActivitySeverity.WARNING,
            )
        )
        repo.record(
            _entry(
                actor="bob",
                category=ActivityCategory.ENTITY,
                severity=ActivitySeverity.INFO,
            )
        )
        results = repo.query(
            ActivityLogFilters(
                actor="alice",
                categories=[ActivityCategory.ENTITY],
                severities=[ActivitySeverity.INFO],
            ),
            limit=10,
        )
        assert len(results) == 1
        assert results[0].actor == "alice"
        assert results[0].category == ActivityCategory.ENTITY
# ---------------------------------------------------------------------------
# Keyset pagination
# ---------------------------------------------------------------------------
class TestKeysetPagination:
    """Paging backwards through the log with the ``before_seq`` cursor."""

    def test_basic_pagination(self, repo: ActivityLogRepository) -> None:
        """Three pages cover all ten records exactly once, with no overlap."""
        for i in range(10):
            repo.record(_entry(message=f"entry {i}"))
        page1 = repo.query(ActivityLogFilters(), limit=4)
        assert len(page1) == 4
        # The cursor for the next (older) page is the seq of the first entry
        # on the current page. These tests rely on that entry carrying the
        # smallest seq of its page.
        first_seq = self._get_seq(repo, page1[0].id)
        page2 = repo.query(ActivityLogFilters(), before_seq=first_seq, limit=4)
        assert len(page2) == 4
        ids1 = {e.id for e in page1}
        ids2 = {e.id for e in page2}
        assert ids1.isdisjoint(ids2), "Pages must not overlap"
        page3 = repo.query(
            ActivityLogFilters(),
            before_seq=self._get_seq(repo, page2[0].id),
            limit=4,
        )
        assert len(page3) == 2  # 10 total; 4 + 4 + 2
        # The final page must be checked too: it may neither repeat earlier
        # rows nor leave any of the ten rows unvisited.
        ids3 = {e.id for e in page3}
        assert ids3.isdisjoint(ids1 | ids2), "Pages must not overlap"
        assert len(ids1 | ids2 | ids3) == 10, "All rows covered exactly once"

    def test_same_ts_stability(self, repo: ActivityLogRepository) -> None:
        """Rows with identical ts are ordered by seq, not ts — no duplicates across pages."""
        # Insert 6 rows all sharing the exact same timestamp
        same_ts = datetime(2026, 3, 10, 15, 0, 0, tzinfo=timezone.utc)
        entries = [_entry(ts=same_ts, message=f"same-ts {i}") for i in range(6)]
        for e in entries:
            repo.record(e)
        page1 = repo.query(ActivityLogFilters(), limit=3)
        first_seq = self._get_seq(repo, page1[0].id)
        page2 = repo.query(ActivityLogFilters(), before_seq=first_seq, limit=3)
        ids1 = {e.id for e in page1}
        ids2 = {e.id for e in page2}
        assert ids1.isdisjoint(ids2), "Same-ts rows leaked across page boundary"
        assert ids1 | ids2 == {e.id for e in entries}, "All rows covered exactly once"

    def test_empty_page_at_end(self, repo: ActivityLogRepository) -> None:
        """Requesting a page beyond the last entry returns an empty list."""
        for i in range(3):
            repo.record(_entry(message=f"e{i}"))
        page1 = repo.query(ActivityLogFilters(), limit=3)
        assert len(page1) == 3
        first_seq = self._get_seq(repo, page1[0].id)
        page2 = repo.query(ActivityLogFilters(), before_seq=first_seq, limit=3)
        assert page2 == []

    @staticmethod
    def _get_seq(repo: ActivityLogRepository, entry_id: str) -> int:
        """Helper: retrieve the seq for an entry by its application id.

        Reads the ``activity_log`` table directly through the repository's
        ``_db`` handle, since ``seq`` is not read from the entry objects here.
        """
        cursor = repo._db.execute("SELECT seq FROM activity_log WHERE id = ?", (entry_id,))
        row = cursor.fetchone()
        assert row is not None, f"No row found for id={entry_id!r}"
        return int(row["seq"])
# ---------------------------------------------------------------------------
# Prune
# ---------------------------------------------------------------------------
class TestPrune:
    """``prune`` deletes by age (``before_ts``) and/or by row cap (``max_entries``)."""

    def test_prune_by_age(self, repo: ActivityLogRepository) -> None:
        """Rows strictly older than ``before_ts`` are deleted; the boundary row stays."""
        base = datetime(2026, 2, 1, 12, 0, 0, tzinfo=timezone.utc)
        repo.record(_entry(ts=base - timedelta(days=10), message="old1"))
        repo.record(_entry(ts=base - timedelta(days=5), message="old2"))
        repo.record(_entry(ts=base, message="boundary"))
        repo.record(_entry(ts=base + timedelta(days=1), message="new"))
        # Prune everything strictly older than base
        deleted = repo.prune(before_ts=base)
        assert deleted == 2
        remaining = repo.query(ActivityLogFilters(), limit=20)
        messages = {r.message for r in remaining}
        assert "old1" not in messages
        assert "old2" not in messages
        assert "boundary" in messages
        assert "new" in messages

    def test_prune_by_max_entries(self, repo: ActivityLogRepository) -> None:
        """With ten rows and ``max_entries=3``, seven are deleted and three remain."""
        for i in range(10):
            repo.record(_entry(message=f"entry {i}"))
        deleted = repo.prune(max_entries=3)
        assert deleted == 7
        assert repo.count() == 3

    def test_prune_keeps_newest_on_max_entries(self, repo: ActivityLogRepository) -> None:
        """The rows that survive a ``max_entries`` prune are exactly the newest ones."""
        base = datetime(2026, 1, 1, tzinfo=timezone.utc)
        ids = []
        for i in range(5):
            e = _entry(ts=base + timedelta(hours=i), message=f"entry {i}")
            ids.append(e.id)
            repo.record(e)
        # Keep only the 2 newest
        repo.prune(max_entries=2)
        remaining = repo.query(ActivityLogFilters(), limit=10)
        remaining_ids = {r.id for r in remaining}
        # The 2 newest are the last 2 inserted (highest seq). Comparing the
        # whole set also catches a prune that leaves extra older rows behind.
        assert remaining_ids == {ids[3], ids[4]}

    def test_prune_both_predicates(self, repo: ActivityLogRepository) -> None:
        """Age and cap predicates both apply; the return value counts all deletions."""
        base = datetime(2026, 2, 1, 0, 0, 0, tzinfo=timezone.utc)
        # Insert 6 entries: 3 old, 3 recent
        for i in range(3):
            repo.record(_entry(ts=base - timedelta(days=i + 1), message=f"old{i}"))
        for i in range(3):
            repo.record(_entry(ts=base + timedelta(hours=i), message=f"new{i}"))
        # Prune old entries AND keep at most 2 of the remaining
        deleted = repo.prune(before_ts=base, max_entries=2)
        # 3 age-pruned + 1 count-pruned = 4
        assert deleted == 4
        assert repo.count() == 2

    def test_prune_max_entries_zero_clears_all(self, repo: ActivityLogRepository) -> None:
        """``max_entries=0`` leaves the table empty."""
        for _ in range(5):
            repo.record(_entry())
        repo.prune(max_entries=0)
        assert repo.count() == 0

    def test_prune_no_op_when_below_max(self, repo: ActivityLogRepository) -> None:
        """Nothing is deleted when the row count is already under the cap."""
        for _ in range(3):
            repo.record(_entry())
        deleted = repo.prune(max_entries=10)
        assert deleted == 0
        assert repo.count() == 3
# ---------------------------------------------------------------------------
# Clear
# ---------------------------------------------------------------------------
class TestClear:
    """``clear`` removes every row and reports how many it removed."""

    def test_clear_returns_row_count(self, repo: ActivityLogRepository) -> None:
        """The return value equals the number of rows present before the call."""
        row_total = 5
        for _ in range(row_total):
            repo.record(_entry())
        assert repo.clear() == 5

    def test_clear_empties_table(self, repo: ActivityLogRepository) -> None:
        """No rows are counted once ``clear`` has run."""
        for _ in range(3):
            repo.record(_entry())
        repo.clear()
        rows_left = repo.count()
        assert rows_left == 0

    def test_clear_on_empty_table(self, repo: ActivityLogRepository) -> None:
        """Clearing a table with no rows reports zero deletions."""
        assert repo.clear() == 0
# ---------------------------------------------------------------------------
# Count
# ---------------------------------------------------------------------------
class TestCount:
    """``count`` reports row totals, optionally narrowed by filters."""

    def test_count_all(self, repo: ActivityLogRepository) -> None:
        """Without filters, every recorded row is counted."""
        for _ in range(7):
            repo.record(_entry())
        total = repo.count()
        assert total == 7

    def test_count_filtered(self, repo: ActivityLogRepository) -> None:
        """With a category filter, only matching rows are counted."""
        for cat in (ActivityCategory.AUTH, ActivityCategory.DEVICE, ActivityCategory.AUTH):
            repo.record(_entry(category=cat))
        auth_only = ActivityLogFilters(categories=[ActivityCategory.AUTH])
        assert repo.count(auth_only) == 2

    def test_count_empty(self, repo: ActivityLogRepository) -> None:
        """A table with no rows counts as zero."""
        total = repo.count()
        assert total == 0

    def test_count_no_match(self, repo: ActivityLogRepository) -> None:
        """A filter that matches no row yields zero."""
        repo.record(_entry(category=ActivityCategory.ENTITY))
        auth_only = ActivityLogFilters(categories=[ActivityCategory.AUTH])
        assert repo.count(auth_only) == 0
# ---------------------------------------------------------------------------
# Export iterator
# ---------------------------------------------------------------------------
class TestExportIterator:
    """``iter_export`` yields matching rows lazily, in ascending seq order."""

    def test_iter_export_yields_all_rows(self, repo: ActivityLogRepository) -> None:
        """Without filters, every recorded row is exported exactly once."""
        recorded_ids = set()
        for i in range(5):
            item = _entry(message=f"e{i}")
            recorded_ids.add(item.id)
            repo.record(item)
        exported = list(repo.iter_export())
        assert len(exported) == 5
        assert {row.id for row in exported} == recorded_ids

    def test_iter_export_ascending_order(self, repo: ActivityLogRepository) -> None:
        """The seq values of the exported rows never decrease."""
        base = datetime(2026, 4, 1, tzinfo=timezone.utc)
        for i in range(5):
            repo.record(_entry(ts=base + timedelta(seconds=i), message=f"e{i}"))
        exported = list(repo.iter_export())
        # Look each exported row's seq up directly in the table.
        seqs = []
        for row in exported:
            found = repo._db.execute(
                "SELECT seq FROM activity_log WHERE id = ?", (row.id,)
            ).fetchone()
            seqs.append(found["seq"])
        assert seqs == sorted(seqs), "iter_export must yield rows in ascending seq order"

    def test_iter_export_with_filter(self, repo: ActivityLogRepository) -> None:
        """A category filter limits the export to matching rows."""
        for cat in (ActivityCategory.AUTH, ActivityCategory.ENTITY, ActivityCategory.AUTH):
            repo.record(_entry(category=cat))
        auth_only = ActivityLogFilters(categories=[ActivityCategory.AUTH])
        exported = list(repo.iter_export(auth_only))
        assert len(exported) == 2
        assert all(row.category == ActivityCategory.AUTH for row in exported)

    def test_iter_export_streaming_not_all_in_memory(self, repo: ActivityLogRepository) -> None:
        """Verify iter_export is a generator (lazy), not a pre-loaded list."""
        from types import GeneratorType

        for _ in range(3):
            repo.record(_entry())
        stream = repo.iter_export()
        assert isinstance(stream, GeneratorType)
# ---------------------------------------------------------------------------
# Migration idempotency
# ---------------------------------------------------------------------------
class TestMigrationIdempotency:
    """The activity-log migration applies once and is safe to request again."""

    def test_construct_repo_twice_is_noop(self, tmp_db: Database) -> None:
        """A second repository on the same DB keeps existing rows and adds no migration record."""
        first = ActivityLogRepository(tmp_db)
        first.record(_entry(message="before second construction"))
        # Building another repository must neither raise nor lose the row.
        second = ActivityLogRepository(tmp_db)
        assert second.count() == 1
        # The migration name must be recorded exactly once.
        row = tmp_db.execute(
            "SELECT COUNT(*) AS cnt FROM data_migrations WHERE name = ?",
            ("002_add_activity_log",),
        ).fetchone()
        assert row["cnt"] == 1

    def test_running_migrations_twice_is_noop(self, tmp_db: Database) -> None:
        """MigrationRunner.run is idempotent for AddActivityLogTableMigration."""
        from ledgrab.storage.data_migrations import (
            AddActivityLogTableMigration,
            MigrationRunner,
        )

        runner = MigrationRunner(tmp_db)
        pending = [AddActivityLogTableMigration()]
        applied_first = runner.run(pending)
        assert len(applied_first) == 1
        assert applied_first[0].name == "002_add_activity_log"
        applied_again = runner.run(pending)
        assert applied_again == [], "Second run must be a no-op"

    def test_activity_log_table_exists_after_migration(self, tmp_db: Database) -> None:
        """The activity_log table is present after the migration runs."""
        ActivityLogRepository(tmp_db)
        table_row = tmp_db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='activity_log'"
        ).fetchone()
        assert table_row is not None

    def test_activity_log_indexes_exist_after_migration(self, tmp_db: Database) -> None:
        """All declared indexes are present after migration."""
        ActivityLogRepository(tmp_db)
        rows = tmp_db.execute(
            "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='activity_log'"
        ).fetchall()
        index_names = {row["name"] for row in rows}
        expected = {
            "idx_activity_log_ts_seq",
            "idx_activity_log_category",
            "idx_activity_log_severity",
            "idx_activity_log_actor",
            "idx_activity_log_entity",
        }
        # Extra indexes are allowed; every expected one must be present.
        assert expected <= index_names