package api import ( "context" "crypto/rand" "encoding/hex" "log/slog" "net" "net/http" "os" "runtime/debug" "strings" "sync" "time" "github.com/alexei/tinyforge/internal/metrics" ) // requestIDKey is the context key under which the generated/forwarded // X-Request-ID is stored. Exported indirectly via RequestIDFromContext // so handlers and services downstream of the API layer can thread it // into their own slog calls without re-extracting from headers. type requestIDKeyType struct{} var requestIDKey = requestIDKeyType{} // RequestIDFromContext returns the correlation ID for the request, or // "" when called outside the API request path. func RequestIDFromContext(ctx context.Context) string { if v, ok := ctx.Value(requestIDKey).(string); ok { return v } return "" } // requestID middleware ensures every request has a stable correlation // ID. Honors a caller-supplied X-Request-ID when the request comes from // a trusted proxy AND the value matches a safe character set; otherwise // generates a fresh 128-bit ID. The ID is echoed back as X-Request-ID // and stitched into every subsequent slog call via the context value // the `logging` middleware reads. // // Format clamp: a compromised reverse proxy (or one that mis-parses an // untrusted header) could forward an ID containing newlines, semicolons, // or other separator characters. Those would corrupt structured log // parsers that assume one record per line / key-value. Restricting to // `[A-Za-z0-9._-]{1,64}` covers UUIDs, hex IDs, and trace-context IDs // without any sharp edges. 
func requestID(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rid := r.Header.Get("X-Request-ID") if rid == "" || !isTrustedPeer(r) || !isValidRequestID(rid) { rid = newRequestID() } w.Header().Set("X-Request-ID", rid) ctx := context.WithValue(r.Context(), requestIDKey, rid) next.ServeHTTP(w, r.WithContext(ctx)) }) } // isValidRequestID enforces `[A-Za-z0-9._-]{1,64}` without compiling a // regex on the request path. Single linear scan, no allocations. func isValidRequestID(s string) bool { if len(s) == 0 || len(s) > 64 { return false } for i := 0; i < len(s); i++ { c := s[i] switch { case c >= 'A' && c <= 'Z': case c >= 'a' && c <= 'z': case c >= '0' && c <= '9': case c == '.' || c == '_' || c == '-': default: return false } } return true } // isTrustedPeer is a thin wrapper around the TRUSTED_PROXY_CIDRS allow- // list — we honor a forwarded request-id only from upstreams we already // trust for X-Forwarded-For. Otherwise an internet client could spam // log files with attacker-chosen IDs. func isTrustedPeer(r *http.Request) bool { peer := r.RemoteAddr if host, _, err := net.SplitHostPort(peer); err == nil { peer = host } if len(trustedProxyCIDRs) == 0 { return false } ip := net.ParseIP(peer) if ip == nil { return false } for _, n := range trustedProxyCIDRs { if n.Contains(ip) { return true } } return false } func newRequestID() string { var b [16]byte if _, err := rand.Read(b[:]); err != nil { // Fall back to time-based suffix if crypto/rand is unavailable // — extremely unlikely outside of broken environments, but the // ID is for tracing not security, so a deterministic fallback // is preferable to a panic. return "ts-" + time.Now().UTC().Format("20060102T150405.000000000") } return hex.EncodeToString(b[:]) } // logging is an HTTP middleware that logs every request with method, path, // status code, and duration. 
Webhook URLs are redacted before being logged // because the secret is the only authenticator — leaking it to log // aggregators is equivalent to leaking the credential. func logging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() wrapped := &statusRecorder{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(wrapped, r) fields := []any{ "method", r.Method, "path", redactPath(r.URL.Path), "status", wrapped.status, "duration", time.Since(start).String(), } if rq := redactQuery(r.URL.RawQuery); rq != "" { fields = append(fields, "query", rq) } if rid := RequestIDFromContext(r.Context()); rid != "" { fields = append(fields, "request_id", rid) } slog.Info("http request", fields...) // Lightweight per-request counter. Bucket by status class so // the cardinality stays at 5 × #methods regardless of how many // distinct response codes we emit. metrics.HTTPRequestsTotal.Inc(bucketMethod(r.Method), statusClass(wrapped.status)) }) } // bucketMethod normalises HTTP method names against the standard set // so a malicious client cannot spam arbitrary method tokens (RFC 7230 // allows any token) and inflate the metrics map. Anything off the // allow-list collapses to "other". func bucketMethod(m string) string { switch m { case "GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS", "CONNECT", "TRACE": return m } return "other" } // statusClass buckets a status code into "1xx".."5xx" / "other". Keeps // metrics cardinality bounded so a chatty endpoint can't explode the // metrics map with one series per distinct response code. func statusClass(code int) string { switch { case code >= 100 && code < 200: return "1xx" case code >= 200 && code < 300: return "2xx" case code >= 300 && code < 400: return "3xx" case code >= 400 && code < 500: return "4xx" case code >= 500 && code < 600: return "5xx" } return "other" } // redactPath strips secrets from URL paths that carry them in segments. 
// Only the canonical /api/webhook/triggers/{secret} surface remains after // the hard cutover. func redactPath(path string) string { const triggerPrefix = "/api/webhook/triggers/" if strings.HasPrefix(path, triggerPrefix) { return triggerPrefix + "***" } return path } // redactQueryKeys is the case-insensitive set of query-parameter names whose // values are masked before a URL lands in the request log. `token` is used by // SSE/EventSource when a custom header can't be set; the rest are // defence-in-depth against sensitive values ever appearing in a query string. var redactQueryKeys = map[string]struct{}{ "token": {}, "secret": {}, "password": {}, "passwd": {}, "api_key": {}, "apikey": {}, "access_token": {}, "client_secret": {}, "sig": {}, "signature": {}, } // redactQuery masks the values of sensitive query parameters (see // redactQueryKeys) in a URL's raw query before it lands in the request log. // Key matching is case-insensitive. Returns the input unchanged when there is // nothing to redact so a malformed URL surfaces naturally. func redactQuery(rawQuery string) string { if rawQuery == "" { return "" } parts := strings.Split(rawQuery, "&") for i, p := range parts { eq := strings.IndexByte(p, '=') if eq < 0 { continue } key := strings.ToLower(p[:eq]) if _, ok := redactQueryKeys[key]; ok { parts[i] = p[:eq+1] + "***" } } return strings.Join(parts, "&") } // recovery is an HTTP middleware that catches panics and returns a 500 response. func recovery(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if err := recover(); err != nil { slog.Error("panic recovered", "error", err, "stack", string(debug.Stack())) respondError(w, http.StatusInternalServerError, "internal server error") } }() next.ServeHTTP(w, r) }) } // securityHeaders sets standard security headers on all responses. // // Strict-Transport-Security is emitted only when the request arrived // over HTTPS (direct TLS or forwarded). 
// Sending HSTS over plain HTTP would simply be ignored by compliant
// browsers. However, scanners flag it as an issue and some reverse proxies
// get confused, so the header is gated on isHTTPS.
//
// The CSP still allows 'unsafe-inline' because SvelteKit injects inline
// boot scripts and styles. Removing it needs a nonce-based strategy
// threaded through the SvelteKit handle hook; this is tracked as a
// follow-up and documented in the security report.
func securityHeaders(next http.Handler) http.Handler {
	const csp = "default-src 'self'; " +
		"script-src 'self' 'unsafe-inline'; " +
		"style-src 'self' 'unsafe-inline'; " +
		"img-src 'self' data:; " +
		"connect-src 'self'; " +
		"font-src 'self'; " +
		"frame-ancestors 'none'; " +
		"base-uri 'self'; " +
		"form-action 'self'"

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		h.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=(), payment=()")
		h.Set("Content-Security-Policy", csp)
		if isHTTPS(r) {
			h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}
		next.ServeHTTP(w, r)
	})
}

// isHTTPS reports whether the request reached us over TLS directly, or
// carries an X-Forwarded-Proto header whose value is exactly "https".
func isHTTPS(r *http.Request) bool {
	return r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https"
}

// cors is the (deliberately minimal) CORS middleware. The frontend is
// served from the same origin, so cross-origin requests are not expected.
// Every OPTIONS request is answered with 204 No Content without reaching
// the next handler. All other methods pass straight through.
//
// The Origin header is intentionally NOT reflected back: doing so would let
// any website make credentialed requests. If cross-origin support is ever
// needed, add an explicit allow-list of trusted origins here.
func cors(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	})
}

// maxBodySize limits request body sizes to prevent memory exhaustion.
const maxBodySize = 1 << 20 // 1 MB

// limitBody caps every request body at maxBodySize bytes using
// http.MaxBytesReader. Reads past the cap fail instead of buffering
// unbounded input.
func limitBody(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capped := http.MaxBytesReader(w, r.Body, maxBodySize)
		r.Body = capped
		next.ServeHTTP(w, r)
	})
}

// rateLimiter is a mutex-guarded sliding-window limiter keyed by client IP,
// used for login endpoints. attempts holds, per IP, the timestamps of
// accepted requests. lastCleanup records the last full sweep of stale IPs.
type rateLimiter struct {
	mu          sync.Mutex
	attempts    map[string][]time.Time
	lastCleanup time.Time
}

// newRateLimiter returns an empty limiter whose sweep clock starts now.
func newRateLimiter() *rateLimiter {
	rl := new(rateLimiter)
	rl.attempts = map[string][]time.Time{}
	rl.lastCleanup = time.Now()
	return rl
}

// allow reports whether ip may make another request. The limit is 10
// accepted requests within the trailing minute. Rejected calls are not
// recorded, so hammering a blocked IP does not extend its lockout. At most
// once every five minutes, allow also sweeps the whole map and drops IPs
// whose timestamps have all expired, so memory does not grow without bound.
func (rl *rateLimiter) allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-time.Minute)

	// keepRecent compacts ts in place, keeping only timestamps after
	// cutoff. Reusing the backing array is safe because rl.mu is held.
	keepRecent := func(ts []time.Time) []time.Time {
		kept := ts[:0]
		for _, stamp := range ts {
			if stamp.After(cutoff) {
				kept = append(kept, stamp)
			}
		}
		return kept
	}

	if now.Sub(rl.lastCleanup) > 5*time.Minute {
		for key, ts := range rl.attempts {
			recent := keepRecent(ts)
			if len(recent) == 0 {
				delete(rl.attempts, key)
				continue
			}
			rl.attempts[key] = recent
		}
		rl.lastCleanup = now
	}

	recent := keepRecent(rl.attempts[ip])
	rl.attempts[ip] = recent
	if len(recent) >= 10 {
		return false
	}
	rl.attempts[ip] = append(recent, now)
	return true
}

// jsonContentType is middleware that defaults the response Content-Type to
// application/json. Downstream handlers may still override it.
func jsonContentType(next http.Handler) http.Handler {
	setJSON := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		next.ServeHTTP(w, r)
	}
	return http.HandlerFunc(setJSON)
}

// rateLimitMiddleware wraps a handler with per-IP rate limiting using the
// supplied limiter.
Requests over the limit get 429. func rateLimitMiddleware(rl *rateLimiter) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := clientIP(r) if !rl.allow(ip) { respondError(w, http.StatusTooManyRequests, "rate limit exceeded") return } next.ServeHTTP(w, r) }) } } // trustedProxyCIDRs is the parsed allow-list of upstream proxy networks // whose X-Forwarded-For header we honor. Set TRUSTED_PROXY_CIDRS to a // comma-separated list of CIDRs (e.g. "127.0.0.1/32,10.0.0.0/8") to // enable. When unset (the default) X-Forwarded-For is ignored entirely // and rate limiting + audit logging use r.RemoteAddr — preventing a // remote attacker from spoofing the header to bypass per-IP limiters. var trustedProxyCIDRs = parseTrustedProxyCIDRs(os.Getenv("TRUSTED_PROXY_CIDRS")) func parseTrustedProxyCIDRs(raw string) []*net.IPNet { raw = strings.TrimSpace(raw) if raw == "" { return nil } var nets []*net.IPNet for _, p := range strings.Split(raw, ",") { p = strings.TrimSpace(p) if p == "" { continue } // Allow bare IPs as /32 (IPv4) or /128 (IPv6). if !strings.Contains(p, "/") { if ip := net.ParseIP(p); ip != nil { if ip.To4() != nil { p += "/32" } else { p += "/128" } } } _, n, err := net.ParseCIDR(p) if err != nil { slog.Warn("ignoring invalid TRUSTED_PROXY_CIDRS entry", "value", p, "error", err) continue } nets = append(nets, n) } return nets } // clientIP returns the per-request "client" address used for rate-limit // keying and audit attribution. X-Forwarded-For is honored ONLY when the // direct peer (r.RemoteAddr) belongs to a configured trusted-proxy CIDR; // otherwise the header is ignored to prevent header-spoofing bypasses. 
func clientIP(r *http.Request) string { peer := r.RemoteAddr if host, _, err := net.SplitHostPort(peer); err == nil { peer = host } if len(trustedProxyCIDRs) == 0 { return peer } peerIP := net.ParseIP(peer) if peerIP == nil || !isTrustedProxy(peerIP) { return peer } fwd := r.Header.Get("X-Forwarded-For") if fwd == "" { return peer } // Walk X-Forwarded-For from the RIGHTMOST entry (the address closest to // us, appended by our trusted peer) leftward, skipping entries that are // themselves trusted proxies, and return the first untrusted address. // The LEFTMOST entry is fully client-controlled — trusting it (as a // naive `fwd[:firstComma]` does) lets an attacker spoof their rate-limit // and audit identity by prepending a forged value, defeating the per-IP // login limiter. parts := strings.Split(fwd, ",") for i := len(parts) - 1; i >= 0; i-- { candidate := strings.TrimSpace(parts[i]) ip := net.ParseIP(candidate) if ip == nil { continue } if isTrustedProxy(ip) { continue } return candidate } // Every forwarded entry was a trusted proxy (or unparseable) — fall back // to the direct peer. return peer } // isTrustedProxy reports whether ip falls within a configured // trusted-proxy CIDR. func isTrustedProxy(ip net.IP) bool { for _, n := range trustedProxyCIDRs { if n.Contains(ip) { return true } } return false } // statusRecorder wraps http.ResponseWriter to capture the status code. type statusRecorder struct { http.ResponseWriter status int } func (r *statusRecorder) WriteHeader(code int) { r.status = code r.ResponseWriter.WriteHeader(code) } // Flush delegates to the underlying ResponseWriter if it supports http.Flusher (needed for SSE). func (r *statusRecorder) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() } }