888f8fd16e
ruff --select UP007,UP045 --fix converted ~1760 sites across the backend: `Optional[T]` → `T | None`, `Union[X, Y]` → `X | Y`. The remaining module-level alias targets that ruff conservatively skips (BindableFloatInput, ColorList, DeviceConfig) were converted by hand earlier in the pass. The result was then formatted with black so the wider unions fit cleanly within the 100-char line limit. pyproject.toml now sets [tool.ruff.lint] extend-select = ["UP007", "UP045"], so any reintroduced legacy `Optional`/`Union` annotation fails the lint check in CI on every push. The pre-commit ruff hook was bumped from v0.8.0 to v0.15.12 so that it recognises UP045 (split off from UP007 in v0.13).
219 lines
7.9 KiB
Python
219 lines
7.9 KiB
Python
"""Tests for WebSocket first-message authentication."""
|
|
|
|
import json

import pytest
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.testclient import TestClient

import ledgrab.config as config_mod
from ledgrab.config import AuthConfig, Config, ServerConfig, StorageConfig
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Minimal app with a single WS endpoint using verify_ws_auth
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_app() -> FastAPI:
    """Build a minimal FastAPI app with a single ``/ws`` endpoint guarded by ``verify_ws_auth``.

    The endpoint accepts the socket and runs first-message authentication.
    If ``verify_ws_auth`` returns ``None``, the socket is closed with
    ``WS_AUTH_CLOSE_CODE``. Otherwise the endpoint sends
    ``{"echo": "hello", "label": <label>}``. It then reads and discards
    client frames until the client disconnects.

    Returns:
        A fresh ``FastAPI`` instance exposing the ``/ws`` route.
    """
    app = FastAPI()

    @app.websocket("/ws")
    async def ws_endpoint(websocket: WebSocket):
        # Imported at request time, i.e. after the config fixtures have run.
        # NOTE(review): presumably deliberate so the auth module sees the patched
        # config; confirm against ledgrab.api.auth.
        from ledgrab.api.auth import WS_AUTH_CLOSE_CODE, verify_ws_auth

        await websocket.accept()
        label = await verify_ws_auth(websocket)
        if label is None:
            await websocket.close(code=WS_AUTH_CLOSE_CODE)
            return
        await websocket.send_json({"echo": "hello", "label": label})
        # Keep the session open until the client goes away. Only the expected
        # disconnect is swallowed; any other server-side error now propagates to
        # the TestClient instead of being silently hidden.
        try:
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass

    return app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture()
def app():
    """Provide a newly built ``/ws`` test app from ``_make_app()``.

    The fixture is function-scoped (pytest's default), so each test gets its own instance.
    """
    built = _make_app()
    return built
|
|
|
|
|
|
def _install_test_config(monkeypatch, tmp_path, api_keys: dict[str, str]) -> None:
    """Replace ``ledgrab.config.config`` with a throwaway config for one test.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture; the attribute patch is
            undone automatically at test teardown.
        tmp_path: per-test temporary directory. The database file is placed
            at ``tmp_path / "data" / "t.db"``.
        api_keys: mapping passed to ``AuthConfig(api_keys=...)``. An empty dict
            selects the no-keys (loopback-only) mode.
    """
    data_dir = tmp_path / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
    cfg = Config(
        server=ServerConfig(host="127.0.0.1", port=9999),
        auth=AuthConfig(api_keys=api_keys),
        storage=StorageConfig(database_file=str(data_dir / "t.db")),
    )
    monkeypatch.setattr(config_mod, "config", cfg)


@pytest.fixture()
def _patch_config_with_keys(monkeypatch, tmp_path):
    """Patch global config to have a test API key (``dev`` -> ``secret-key-abc``)."""
    _install_test_config(monkeypatch, tmp_path, {"dev": "secret-key-abc"})


@pytest.fixture()
def _patch_config_no_keys(monkeypatch, tmp_path):
    """Patch global config with empty api_keys (loopback-only mode)."""
    _install_test_config(monkeypatch, tmp_path, {})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests — keys configured
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.usefixtures("_patch_config_with_keys")
class TestWsAuthWithKeys:
    """First-message WebSocket auth while ``api_keys`` is populated.

    Every test in this class runs under the ``_patch_config_with_keys`` fixture,
    which installs the single key ``"dev" -> "secret-key-abc"``.
    """

    @staticmethod
    def _exchange(ws, frame):
        """Send ``frame`` as a text message and return the JSON-decoded reply."""
        ws.send_text(frame)
        return json.loads(ws.receive_text())

    def test_valid_token(self, app):
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            reply = self._exchange(ws, json.dumps({"type": "auth", "token": "secret-key-abc"}))
            assert reply["type"] == "auth_ok"
            # Once authenticated, the endpoint echoes back the matched key's label.
            echo = json.loads(ws.receive_text())
            assert echo["label"] == "dev"

    def test_invalid_token(self, app):
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            reply = self._exchange(ws, json.dumps({"type": "auth", "token": "wrong"}))
            assert reply["type"] == "auth_error"
            assert "invalid" in reply["reason"].lower()

    def test_missing_token(self, app):
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            reply = self._exchange(ws, json.dumps({"type": "auth"}))
            assert reply["type"] == "auth_error"

    def test_non_auth_first_message(self, app):
        # A first frame whose type is not "auth" must be rejected.
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            reply = self._exchange(ws, json.dumps({"type": "ping"}))
            assert reply["type"] == "auth_error"
            assert "auth" in reply["reason"].lower()

    def test_invalid_json(self, app):
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            reply = self._exchange(ws, "not json at all")
            assert reply["type"] == "auth_error"
            assert "json" in reply["reason"].lower()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests — no keys (loopback anonymous)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.usefixtures("_patch_config_no_keys")
class TestWsAuthLoopbackAnonymous:
    """With an empty ``api_keys`` map, loopback clients are admitted anonymously.

    Starlette's TestClient reports the client host as ``"testclient"``, which
    appears in the auth module's ``_LOOPBACK_HOSTS`` set, so these connections
    count as loopback.
    """

    @staticmethod
    def _assert_anonymous_session(app, token):
        """Authenticate with ``token`` and check the session is labelled ``anonymous``."""
        tc = TestClient(app)
        with tc.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": token}))
            reply = json.loads(ws.receive_text())
            assert reply["type"] == "auth_ok"
            echo = json.loads(ws.receive_text())
            assert echo["label"] == "anonymous"

    def test_anonymous_with_auth_message(self, app):
        """An auth message with a null token on loopback with no keys still succeeds."""
        self._assert_anonymous_session(app, None)

    def test_anonymous_with_token(self, app):
        """Any token supplied on loopback with no keys is accepted as anonymous."""
        self._assert_anonymous_session(app, "anything")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests — accept_and_authenticate_ws helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAcceptAndAuthenticateWs:
    """Tests for ``accept_and_authenticate_ws``, the accept-plus-auth convenience wrapper."""

    @pytest.mark.usefixtures("_patch_config_with_keys")
    def test_accept_and_auth_success(self):
        app = FastAPI()

        @app.websocket("/ws2")
        async def ws2(websocket: WebSocket):
            from ledgrab.api.auth import accept_and_authenticate_ws

            # No explicit accept() here: the wrapper is expected to accept the
            # socket itself (as its name says) before authenticating.
            label = await accept_and_authenticate_ws(websocket)
            if label is None:
                return
            await websocket.send_json({"ok": True, "label": label})
            # Hold the session open until the client disconnects. Only the
            # expected disconnect is swallowed, so real server errors surface.
            try:
                while True:
                    await websocket.receive_text()
            except WebSocketDisconnect:
                pass

        client = TestClient(app)
        with client.websocket_connect("/ws2") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "secret-key-abc"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_ok"
            data = json.loads(ws.receive_text())
            assert data["ok"] is True
            assert data["label"] == "dev"

    @pytest.mark.usefixtures("_patch_config_with_keys")
    def test_accept_and_auth_failure_closes(self):
        from ledgrab.api.auth import WS_AUTH_CLOSE_CODE

        app = FastAPI()

        @app.websocket("/ws3")
        async def ws3(websocket: WebSocket):
            from ledgrab.api.auth import accept_and_authenticate_ws

            label = await accept_and_authenticate_ws(websocket)
            if label is None:
                # The endpoint does not close the socket itself; the test below
                # checks that the wrapper did.
                return
            # Should not reach here
            await websocket.send_json({"should": "not happen"})

        client = TestClient(app)
        with client.websocket_connect("/ws3") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "wrong-key"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
            # The test name promises the socket is closed on failure. Verify that
            # the next read sees a server-side close with the auth close code
            # rather than more data.
            with pytest.raises(WebSocketDisconnect) as exc_info:
                ws.receive_text()
            assert exc_info.value.code == WS_AUTH_CLOSE_CODE
|