package auth

import (
	"context"
	"net/http"
	"strings"
)

// contextKey is the type for context value keys used by the auth package.
// Being an unexported named type, it cannot collide with context keys defined
// in other packages, even when the underlying string is the same.
type contextKey string

// claimsKey is the context key under which Middleware stores validated claims.
const claimsKey contextKey = "auth_claims"

// Middleware returns an HTTP middleware that protects routes by requiring a valid JWT.
//
// The token is obtained with ExtractToken: the Authorization header (Bearer scheme)
// or the "token" query parameter (for SSE connections). A request is rejected with
// 401 and a JSON error body when:
//   - no token is present,
//   - la.ValidateToken returns an error, or
//   - la.IsRevoked reports the token as revoked.
//
// Every 401 carries a WWW-Authenticate challenge. On success, the claims returned
// by ValidateToken are stored in the request context, where ClaimsFromContext can
// retrieve them.
func Middleware(la *LocalAuth) func(http.Handler) http.Handler {
	// RFC 6750 §3.1 challenge for a token that was presented but not accepted.
	const invalidTokenChallenge = `Bearer error="invalid_token"`

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			tokenStr := ExtractToken(r)
			if tokenStr == "" {
				// A 401 response must include a WWW-Authenticate challenge (RFC 7235 §3.1).
				w.Header().Set("WWW-Authenticate", "Bearer")
				writeAuthError(w, http.StatusUnauthorized, `{"success":false,"error":"authentication required"}`)
				return
			}

			claims, err := la.ValidateToken(tokenStr)
			if err != nil {
				w.Header().Set("WWW-Authenticate", invalidTokenChallenge)
				writeAuthError(w, http.StatusUnauthorized, `{"success":false,"error":"invalid or expired token"}`)
				return
			}

			// The revocation lookup runs only for tokens that passed validation.
			if la.IsRevoked(tokenStr) {
				w.Header().Set("WWW-Authenticate", invalidTokenChallenge)
				writeAuthError(w, http.StatusUnauthorized, `{"success":false,"error":"token has been revoked"}`)
				return
			}

			ctx := context.WithValue(r.Context(), claimsKey, claims)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// AdminOnly returns an HTTP middleware that requires the authenticated user to have
// the "admin" role.
//
// It reads the claims that Middleware stores in the request context, so it must be
// installed inside Middleware. Requests with no claims in the context, or whose
// Role is not "admin", are rejected with 403 and a JSON error body.
func AdminOnly(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		claims, ok := ClaimsFromContext(r.Context())
		if !ok || claims.Role != "admin" {
			writeAuthError(w, http.StatusForbidden, `{"success":false,"error":"admin access required"}`)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// ClaimsFromContext retrieves the authenticated user's claims from ctx, as stored
// by Middleware. The boolean result is false when ctx holds no Claims value under
// the auth package's key.
func ClaimsFromContext(ctx context.Context) (Claims, bool) {
	claims, ok := ctx.Value(claimsKey).(Claims)
	return claims, ok
}

// writeAuthError writes body as an error response with the given status code.
//
// It mirrors http.Error: it drops any Content-Length, sets
// X-Content-Type-Options: nosniff and terminates the body with a newline.
// Unlike http.Error, it labels the payload application/json, because every
// body this package emits is a JSON envelope.
func writeAuthError(w http.ResponseWriter, status int, body string) {
	h := w.Header()
	h.Del("Content-Length")
	h.Set("Content-Type", "application/json; charset=utf-8")
	h.Set("X-Content-Type-Options", "nosniff")
	w.WriteHeader(status)
	// A failed write means the client is gone; there is nothing useful left to do.
	_, _ = w.Write([]byte(body + "\n"))
}

// ExtractToken returns the JWT carried by r. It uses the credential of an
// Authorization header that uses the Bearer scheme, or otherwise the "token" query
// parameter (used by SSE and browser-based connections). It returns "" when
// neither source provides a token.
//
// NOTE: tokens passed in the query string can end up in access logs, proxy logs
// and browser history; prefer the header whenever the client is able to set it.
func ExtractToken(r *http.Request) string {
	// The Authorization header wins when it uses the Bearer scheme.
	// The scheme name is case-insensitive (RFC 7235 §2.1).
	// One or more spaces may separate it from the credential (RFC 6750 §2.1).
	// So "bearer  abc" yields "abc" rather than being ignored or returning " abc".
	// "Bearer " with an empty credential still yields "" without consulting the query.
	const scheme = "Bearer"
	if h := r.Header.Get("Authorization"); len(h) > len(scheme) &&
		strings.EqualFold(h[:len(scheme)], scheme) && h[len(scheme)] == ' ' {
		return strings.TrimSpace(h[len(scheme)+1:])
	}

	// Fall back to the query parameter (used by SSE and browser-based connections).
	return r.URL.Query().Get("token")
}