"""Tests for WebSocket first-message authentication."""

import json

import pytest
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.testclient import TestClient

import ledgrab.config as config_mod
from ledgrab.config import AuthConfig, Config, ServerConfig, StorageConfig

# ---------------------------------------------------------------------------
# Minimal app with a single WS endpoint using verify_ws_auth
# ---------------------------------------------------------------------------


def _make_app() -> FastAPI:
    """Build a FastAPI app exposing ``/ws`` guarded by ``verify_ws_auth``.

    The endpoint accepts the socket and runs first-message auth. If
    ``verify_ws_auth`` returns ``None`` it closes with ``WS_AUTH_CLOSE_CODE``.
    Otherwise it sends ``{"echo": "hello", "label": <label>}`` once and then
    reads text frames until the client disconnects.
    """
    app = FastAPI()

    @app.websocket("/ws")
    async def ws_endpoint(websocket: WebSocket):
        # Imported inside the handler, so ledgrab.api.auth is only loaded
        # when a connection is actually handled.
        from ledgrab.api.auth import WS_AUTH_CLOSE_CODE, verify_ws_auth

        await websocket.accept()
        label = await verify_ws_auth(websocket)
        if label is None:
            await websocket.close(code=WS_AUTH_CLOSE_CODE)
            return
        await websocket.send_json({"echo": "hello", "label": label})
        # Keep alive until the client disconnects. Only a disconnect ends the
        # loop quietly; any other server-side error propagates to the test.
        try:
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass

    return app


def _assert_closed_with_auth_code(ws) -> None:
    """Assert the server's next frame is a close carrying ``WS_AUTH_CLOSE_CODE``.

    ``_make_app``'s endpoint sends that close only when ``verify_ws_auth``
    returned ``None``, so this confirms the auth helper rejected the client
    rather than merely emitting an error frame.
    """
    from ledgrab.api.auth import WS_AUTH_CLOSE_CODE

    with pytest.raises(WebSocketDisconnect) as exc_info:
        ws.receive_text()
    assert exc_info.value.code == WS_AUTH_CLOSE_CODE


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def app():
    """A fresh app from ``_make_app`` for each test."""
    return _make_app()


def _install_test_config(monkeypatch, tmp_path, api_keys: dict) -> None:
    """Replace the global ``ledgrab.config.config`` with a test configuration.

    Args:
        monkeypatch: pytest's monkeypatch fixture; the patch is undone after
            the test.
        tmp_path: per-test temporary directory; the database file is placed
            under ``tmp_path / "data"``.
        api_keys: mapping passed to ``AuthConfig(api_keys=...)``.
    """
    data_dir = tmp_path / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
    cfg = Config(
        server=ServerConfig(host="127.0.0.1", port=9999),
        auth=AuthConfig(api_keys=api_keys),
        storage=StorageConfig(database_file=str(data_dir / "t.db")),
    )
    monkeypatch.setattr(config_mod, "config", cfg)


@pytest.fixture()
def _patch_config_with_keys(monkeypatch, tmp_path):
    """Patch global config to have a test API key."""
    _install_test_config(monkeypatch, tmp_path, {"dev": "secret-key-abc"})


@pytest.fixture()
def _patch_config_no_keys(monkeypatch, tmp_path):
    """Patch global config with empty api_keys (loopback-only mode)."""
    _install_test_config(monkeypatch, tmp_path, {})


# ---------------------------------------------------------------------------
# Tests — keys configured
# ---------------------------------------------------------------------------


@pytest.mark.usefixtures("_patch_config_with_keys")
class TestWsAuthWithKeys:
    """WS auth when api_keys are configured."""

    def test_valid_token(self, app):
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "secret-key-abc"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_ok"
            data = json.loads(ws.receive_text())
            assert data["echo"] == "hello"
            assert data["label"] == "dev"

    def test_invalid_token(self, app):
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "wrong"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
            assert "invalid" in resp["reason"].lower()
            _assert_closed_with_auth_code(ws)

    def test_missing_token(self, app):
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
            _assert_closed_with_auth_code(ws)

    def test_non_auth_first_message(self, app):
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "ping"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
            assert "auth" in resp["reason"].lower()
            _assert_closed_with_auth_code(ws)

    def test_invalid_json(self, app):
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text("not json at all")
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
            assert "json" in resp["reason"].lower()
            _assert_closed_with_auth_code(ws)
# ---------------------------------------------------------------------------
# Tests — no keys (loopback anonymous)
# ---------------------------------------------------------------------------


@pytest.mark.usefixtures("_patch_config_no_keys")
class TestWsAuthLoopbackAnonymous:
    """WS auth when api_keys is empty — loopback clients get anonymous access.

    The Starlette TestClient reports the client host as "testclient", which is
    in the _LOOPBACK_HOSTS set.
    """

    def test_anonymous_with_auth_message(self, app):
        """Sending an auth message on loopback with no keys is a no-op — still succeeds."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": None}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_ok"
            data = json.loads(ws.receive_text())
            assert data["label"] == "anonymous"

    def test_anonymous_with_token(self, app):
        """Sending a token on loopback with no keys is also fine."""
        client = TestClient(app)
        with client.websocket_connect("/ws") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "anything"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_ok"
            data = json.loads(ws.receive_text())
            assert data["label"] == "anonymous"


# ---------------------------------------------------------------------------
# Tests — accept_and_authenticate_ws helper
# ---------------------------------------------------------------------------


@pytest.mark.usefixtures("_patch_config_with_keys")
class TestAcceptAndAuthenticateWs:
    """Test the accept_and_authenticate_ws convenience wrapper."""

    def test_accept_and_auth_success(self):
        app = FastAPI()

        @app.websocket("/ws2")
        async def ws2(websocket: WebSocket):
            from fastapi import WebSocketDisconnect

            from ledgrab.api.auth import accept_and_authenticate_ws

            label = await accept_and_authenticate_ws(websocket)
            if label is None:
                return
            await websocket.send_json({"ok": True, "label": label})
            # Keep alive until the client disconnects; any other error
            # propagates instead of being silently swallowed.
            try:
                while True:
                    await websocket.receive_text()
            except WebSocketDisconnect:
                pass

        client = TestClient(app)
        with client.websocket_connect("/ws2") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "secret-key-abc"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_ok"
            data = json.loads(ws.receive_text())
            assert data["ok"] is True
            assert data["label"] == "dev"

    def test_accept_and_auth_failure_closes(self):
        app = FastAPI()
        # Appended to by the endpoint if the wrapper lets a failed auth
        # through; checked once the session has ended.
        reached_post_auth = []

        @app.websocket("/ws3")
        async def ws3(websocket: WebSocket):
            from ledgrab.api.auth import accept_and_authenticate_ws

            label = await accept_and_authenticate_ws(websocket)
            if label is None:
                return
            # Should not reach here
            reached_post_auth.append(label)
            await websocket.send_json({"should": "not happen"})

        client = TestClient(app)
        with client.websocket_connect("/ws3") as ws:
            ws.send_text(json.dumps({"type": "auth", "token": "wrong-key"}))
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "auth_error"
        assert reached_post_auth == []