"""In-process token-bucket rate limiter. Light enough for a single-process app: one dict keyed by ``(bucket, peer)`` guarded by a thread lock. No extra dependency, no Redis. Good enough for defeating credential-stuffing and runaway clients on a LAN; not a substitute for an upstream WAF in a public deployment. Buckets: auth — failed-auth attempts, 5/min/peer (used in auth middleware) execute — script + callback execute calls, 10/min/peer (LAN-friendly) default — generic POST/DELETE writes, 60/min/peer """ from __future__ import annotations import logging import threading import time from dataclasses import dataclass from typing import Optional logger = logging.getLogger(__name__) @dataclass class BucketConfig: capacity: float # max tokens (= burst size) refill_per_sec: float # tokens added per second # Defaults — tuned for "trusted LAN" use; operator can override via Settings. BUCKETS: dict[str, BucketConfig] = { "auth": BucketConfig(capacity=5, refill_per_sec=5 / 60), # 5/min "execute": BucketConfig(capacity=10, refill_per_sec=10 / 60), # 10/min "default": BucketConfig(capacity=60, refill_per_sec=60 / 60), # 60/min } _state: dict[tuple[str, str], tuple[float, float]] = {} _lock = threading.Lock() _LAST_CLEANUP = 0.0 def _evict_stale_locked(now: float) -> None: """Drop entries whose buckets are full (= idle for capacity / refill seconds).""" global _LAST_CLEANUP if now - _LAST_CLEANUP < 60: return _LAST_CLEANUP = now stale = [] for key, (tokens, last) in _state.items(): bucket = BUCKETS.get(key[0]) if bucket is None: continue if tokens >= bucket.capacity and (now - last) > 3600: stale.append(key) for key in stale: _state.pop(key, None) def check(bucket: str, peer: str) -> tuple[bool, Optional[float]]: """Try to consume one token from ``(bucket, peer)``. Returns: (allowed, retry_after_seconds). When allowed=True retry_after is None. When allowed=False, retry_after is the seconds to wait for one more token. """ cfg = BUCKETS.get(bucket) or BUCKETS["default"] now = time.monotonic() with _lock: _evict_stale_locked(now) tokens, last = _state.get((bucket, peer), (cfg.capacity, now)) elapsed = max(0.0, now - last) tokens = min(cfg.capacity, tokens + elapsed * cfg.refill_per_sec) if tokens >= 1: tokens -= 1 _state[(bucket, peer)] = (tokens, now) return True, None deficit = 1 - tokens retry = deficit / cfg.refill_per_sec if cfg.refill_per_sec > 0 else 60 _state[(bucket, peer)] = (tokens, now) return False, retry def get_peer(request) -> str: """Best-effort peer identifier from a Starlette request. Honors X-Forwarded-For (only when settings.proxy_headers is True, which is already enforced by uvicorn's middleware) so a reverse-proxied install still rate-limits per real client. """ client = getattr(request, "client", None) if client and client.host: return client.host return "unknown"