"""Authentication middleware and utilities.""" import secrets from contextvars import ContextVar from typing import Optional from fastapi import Depends, HTTPException, Query, Request, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from .config import settings security = HTTPBearer(auto_error=False) # Context variable to store current request's token label token_label_var: ContextVar[str] = ContextVar("token_label", default="unknown") def get_token_label(token: str) -> Optional[str]: """Get the label for a token. Returns None if token is invalid. Args: token: The token to look up Returns: The label for the token, or None if invalid """ for label, stored_token in settings.api_tokens.items(): if secrets.compare_digest(stored_token, token): return label return None async def verify_token( request: Request, credentials: HTTPAuthorizationCredentials = Depends(security), ) -> str: """Verify the API token from the Authorization header. Args: request: The incoming request credentials: The bearer token credentials Returns: The token label Raises: HTTPException: If the token is missing or invalid """ if credentials is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing authentication token", headers={"WWW-Authenticate": "Bearer"}, ) label = get_token_label(credentials.credentials) if label is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token", headers={"WWW-Authenticate": "Bearer"}, ) # Set label in context for logging token_label_var.set(label) return label class TokenAuth: """Dependency class for token authentication.""" def __init__(self, auto_error: bool = True): self.auto_error = auto_error async def __call__( self, request: Request, credentials: HTTPAuthorizationCredentials = Depends(security), ) -> str | None: """Verify the token and return the label or raise an exception.""" if credentials is None: if self.auto_error: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing authentication token", headers={"WWW-Authenticate": "Bearer"}, ) return None label = get_token_label(credentials.credentials) if label is None: if self.auto_error: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token", headers={"WWW-Authenticate": "Bearer"}, ) return None # Set label in context for logging token_label_var.set(label) return label async def verify_token_or_query( credentials: HTTPAuthorizationCredentials = Depends(security), token: Optional[str] = Query(None, description="API token as query parameter"), ) -> str: """Verify the API token from header or query parameter. Useful for endpoints that need to be accessed via URL (like images). Args: credentials: The bearer token credentials from header token: Token from query parameter Returns: The token label Raises: HTTPException: If the token is missing or invalid """ label = None # Try header first if credentials is not None: label = get_token_label(credentials.credentials) # Try query parameter if label is None and token is not None: label = get_token_label(token) if label is None: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing or invalid authentication token", headers={"WWW-Authenticate": "Bearer"}, ) # Set label in context for logging token_label_var.set(label) return label