"""OAuth2 login (Google, Authentik): authorization URL, callback handling, user/session provisioning."""

from datetime import datetime, timedelta, timezone

from authlib.integrations.httpx_client import AsyncOAuth2Client
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.core.security import create_access_token, generate_refresh_token, hash_refresh_token
from app.models.session import Session
from app.models.user import User

# --- Provider configs ---

# Endpoint configuration for providers whose URLs are fixed (Authentik's depend on settings).
PROVIDERS = {
    "google": {
        "authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
        "token_url": "https://oauth2.googleapis.com/token",
        "userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
        "scope": "openid email profile",
    },
}


def _get_authentik_config():
    """Build Authentik endpoint URLs from ``settings.AUTHENTIK_BASE_URL``."""
    base = settings.AUTHENTIK_BASE_URL.rstrip("/")
    return {
        "authorize_url": f"{base}/application/o/authorize/",
        "token_url": f"{base}/application/o/token/",
        "userinfo_url": f"{base}/application/o/userinfo/",
        "scope": "openid email profile",
    }


def _get_provider_config(provider: str) -> dict:
    """Return the endpoint/scope config for ``provider``.

    Raises:
        HTTPException(400): the provider is unknown, or Authentik is requested but
            ``AUTHENTIK_BASE_URL`` is not set.
    """
    if provider == "google":
        return PROVIDERS["google"]
    elif provider == "authentik":
        if not settings.AUTHENTIK_BASE_URL:
            raise HTTPException(status_code=400, detail="Authentik not configured")
        return _get_authentik_config()
    raise HTTPException(status_code=400, detail=f"Unsupported OAuth provider: {provider}")


def _get_client_credentials(provider: str) -> tuple[str, str, str]:
    """Return ``(client_id, client_secret, redirect_uri)`` for ``provider`` from settings.

    Raises:
        HTTPException(400): the provider is unknown.
    """
    if provider == "google":
        return settings.GOOGLE_CLIENT_ID, settings.GOOGLE_CLIENT_SECRET, settings.GOOGLE_REDIRECT_URI
    elif provider == "authentik":
        return settings.AUTHENTIK_CLIENT_ID, settings.AUTHENTIK_CLIENT_SECRET, settings.AUTHENTIK_REDIRECT_URI
    raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")


def _get_client(provider: str) -> AsyncOAuth2Client:
    """Create a new OAuth2 client for ``provider``.

    The returned client owns an httpx connection pool; callers must close it
    (``async with _get_client(...) as client:``).
    """
    config = _get_provider_config(provider)
    client_id, client_secret, redirect_uri = _get_client_credentials(provider)
    return AsyncOAuth2Client(
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri,
        scope=config["scope"],
    )


async def get_authorize_url(provider: str) -> str:
    """Return the provider's authorization URL that starts the OAuth login flow.

    Raises:
        HTTPException(400): unknown or unconfigured provider.
    """
    config = _get_provider_config(provider)
    # The client only builds a URL here, but it still holds a connection pool,
    # so close it deterministically instead of leaking it per request.
    async with _get_client(provider) as client:
        # NOTE(review): the generated ``state`` is discarded and never verified in
        # handle_callback, so the callback has no CSRF protection. Fixing this needs
        # the state persisted (cookie/server session) and both signatures extended.
        url, _state = client.create_authorization_url(config["authorize_url"])
    return url


async def _fetch_userinfo(provider: str, code: str) -> dict:
    """Exchange ``code`` for a token and return the provider's userinfo claims.

    Raises:
        HTTPException(400): the code exchange fails, or the userinfo request fails,
            returns a non-200 status, or returns a body that is not a JSON object.
    """
    config = _get_provider_config(provider)
    userinfo_error = f"Failed to get user info from {provider}"
    async with _get_client(provider) as client:
        try:
            await client.fetch_token(config["token_url"], code=code)
        except Exception as exc:
            raise HTTPException(status_code=400, detail="Failed to exchange OAuth code") from exc

        try:
            resp = await client.get(config["userinfo_url"])
        except Exception as exc:
            raise HTTPException(status_code=400, detail=userinfo_error) from exc
        if resp.status_code != 200:
            raise HTTPException(status_code=400, detail=userinfo_error)
        try:
            userinfo = resp.json()
        except ValueError as exc:  # malformed JSON body
            raise HTTPException(status_code=400, detail=userinfo_error) from exc

    if not isinstance(userinfo, dict):
        raise HTTPException(status_code=400, detail=userinfo_error)
    return userinfo


async def _unique_username(db: AsyncSession, email: str) -> str:
    """Derive a username from the email's local part, appending 1, 2, ... until unused.

    NOTE(review): check-then-insert is not atomic; two concurrent signups can pick the
    same name. Presumably a unique constraint on ``User.username`` catches that — confirm.
    """
    base = email.split("@")[0]
    username = base
    counter = 1
    while True:
        result = await db.execute(select(User).where(User.username == username))
        if not result.scalar_one_or_none():
            return username
        username = f"{base}{counter}"
        counter += 1


async def _find_or_create_user(
    db: AsyncSession,
    provider: str,
    provider_id: str,
    email: str,
    name: str | None,
    picture: str | None,
) -> User:
    """Return the user for this provider identity, linking by email or creating one if needed."""
    # 1) Already linked to this exact provider identity.
    result = await db.execute(
        select(User).where(User.oauth_provider == provider, User.oauth_provider_id == provider_id)
    )
    user = result.scalar_one_or_none()
    if user:
        return user

    # 2) Existing account with the same email: link it to this provider identity.
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()
    if user:
        # NOTE(review): linking by email alone trusts the provider's email claim; anyone who
        # can register this address at the provider can take over the local account.
        # Consider requiring userinfo["email_verified"] before linking.
        user.oauth_provider = provider
        user.oauth_provider_id = provider_id
        if picture:
            user.avatar_url = picture
        return user

    # 3) Brand-new OAuth-only user (no local password).
    user = User(
        email=email,
        username=await _unique_username(db, email),
        hashed_password=None,
        full_name=name,
        oauth_provider=provider,
        oauth_provider_id=provider_id,
        avatar_url=picture,
    )
    db.add(user)
    # Flush so the row exists in this transaction and user.id is populated before it is
    # used for tokens/sessions (assumes id is database-generated — confirm).
    await db.flush()
    return user


async def _create_session(
    db: AsyncSession,
    user: User,
    ip_address: str | None,
    device_info: str | None,
) -> str:
    """Persist a refresh-token session for ``user`` and return the raw refresh token."""
    refresh_token = generate_refresh_token()
    expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    session = Session(
        user_id=user.id,
        # Only the hash is stored; the raw token is handed back to the caller.
        refresh_token_hash=hash_refresh_token(refresh_token),
        device_info=device_info,
        ip_address=ip_address,
        expires_at=expires_at,
    )
    db.add(session)
    await db.flush()
    return refresh_token


async def handle_callback(
    provider: str,
    code: str,
    db: AsyncSession,
    ip_address: str | None = None,
    device_info: str | None = None,
) -> dict:
    """Exchange code, get user info, create/link user, return auth tokens.

    Args:
        provider: ``"google"`` or ``"authentik"``.
        code: authorization code received on the OAuth redirect.
        db: async session; changes are flushed but not committed here.
        ip_address: client IP recorded on the new session.
        device_info: client description recorded on the new session.

    Returns:
        ``{"user": User, "access_token": str, "refresh_token": str}``.

    Raises:
        HTTPException(400): unknown provider, failed code exchange/userinfo request,
            or a userinfo response without ``email`` or ``sub``.
        HTTPException(403): the matched user is deactivated.
    """
    userinfo = await _fetch_userinfo(provider, code)

    email = userinfo.get("email")
    name = userinfo.get("name") or userinfo.get("preferred_username")
    picture = userinfo.get("picture")
    provider_id = userinfo.get("sub")

    if not email:
        raise HTTPException(status_code=400, detail=f"{provider} account has no email")
    # Without a subject, ``oauth_provider_id == None`` would compile to IS NULL and
    # could match (or link) an unrelated account.
    if not provider_id:
        raise HTTPException(status_code=400, detail=f"{provider} account has no subject identifier")

    user = await _find_or_create_user(db, provider, provider_id, email, name, picture)

    if not user.is_active:
        raise HTTPException(status_code=403, detail="Account is deactivated")

    access_token = create_access_token(user.id, user.role)
    refresh_token = await _create_session(db, user, ip_address, device_info)

    return {"user": user, "access_token": access_token, "refresh_token": refresh_token}