Phase 9: OAuth & Account Switching — Google + Authentik, multi-account
Backend: - OAuth service with pluggable provider architecture (Google + Authentik) - Generic authorize/callback endpoints for any provider - Authentik OIDC integration (configurable base URL) - hashed_password made nullable for OAuth-only users - Migration 009: nullable password column - /auth/switch endpoint returns full AuthResponse for account switching - OAuth-only users get clear error on password login attempt - UserResponse includes oauth_provider + avatar_url Frontend: - OAuth buttons on login form (Google + Authentik) - OAuth callback handler (/auth/callback route) - Multi-account auth store (accounts array, addAccount, switchTo, removeAccount) - Account switcher dropdown in header (hover to see other accounts) - "Add another account" option - English + Russian translations Config: - GOOGLE_CLIENT_ID/SECRET/REDIRECT_URI - AUTHENTIK_CLIENT_ID/SECRET/BASE_URL/REDIRECT_URI Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
24
backend/alembic/versions/009_oauth_nullable_password.py
Normal file
24
backend/alembic/versions/009_oauth_nullable_password.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Make hashed_password nullable for OAuth users
|
||||
|
||||
Revision ID: 009
|
||||
Revises: 008
|
||||
Create Date: 2026-03-19
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "009"
|
||||
down_revision: Union[str, None] = "008"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("users", "hashed_password", nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("users", "hashed_password", nullable=False)
|
||||
@@ -77,3 +77,53 @@ async def logout(
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(user: Annotated[User, Depends(get_current_user)]):
|
||||
return UserResponse.model_validate(user)
|
||||
|
||||
|
||||
# --- OAuth ---
|
||||
|
||||
@router.get("/oauth/{provider}/authorize")
|
||||
async def oauth_authorize(provider: str):
|
||||
from app.services.oauth_service import get_authorize_url
|
||||
url = await get_authorize_url(provider)
|
||||
return {"authorize_url": url}
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}/callback")
|
||||
async def oauth_callback(
|
||||
provider: str,
|
||||
code: str,
|
||||
request: Request,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
from app.services.oauth_service import handle_callback
|
||||
result = await handle_callback(
|
||||
provider, code, db,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
device_info=request.headers.get("user-agent"),
|
||||
)
|
||||
# Redirect to frontend with tokens
|
||||
from fastapi.responses import RedirectResponse
|
||||
redirect_url = f"/auth/callback?access_token={result['access_token']}&refresh_token={result['refresh_token']}"
|
||||
return RedirectResponse(url=redirect_url)
|
||||
|
||||
|
||||
@router.post("/switch", response_model=AuthResponse)
|
||||
async def switch_account(
|
||||
data: RefreshRequest,
|
||||
db: Annotated[AsyncSession, Depends(get_db)],
|
||||
):
|
||||
"""Switch to another account using its refresh token. Returns full AuthResponse."""
|
||||
tokens = await auth_service.refresh_tokens(db, data.refresh_token)
|
||||
# Get user from new access token
|
||||
from app.core.security import decode_access_token
|
||||
import uuid as uuid_mod
|
||||
payload = decode_access_token(tokens.access_token)
|
||||
user_id = uuid_mod.UUID(payload["sub"])
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one()
|
||||
return AuthResponse(
|
||||
user=UserResponse.model_validate(user),
|
||||
access_token=tokens.access_token,
|
||||
refresh_token=tokens.refresh_token,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,15 @@ class Settings(BaseSettings):
|
||||
RATE_LIMIT_REQUESTS: int = 20
|
||||
RATE_LIMIT_WINDOW_SECONDS: int = 60
|
||||
|
||||
GOOGLE_CLIENT_ID: str = ""
|
||||
GOOGLE_CLIENT_SECRET: str = ""
|
||||
GOOGLE_REDIRECT_URI: str = "http://localhost/api/v1/auth/oauth/google/callback"
|
||||
|
||||
AUTHENTIK_CLIENT_ID: str = ""
|
||||
AUTHENTIK_CLIENT_SECRET: str = ""
|
||||
AUTHENTIK_BASE_URL: str = "" # e.g. https://auth.example.com
|
||||
AUTHENTIK_REDIRECT_URI: str = "http://localhost/api/v1/auth/oauth/authentik/callback"
|
||||
|
||||
FIRST_ADMIN_EMAIL: str = "admin@example.com"
|
||||
FIRST_ADMIN_USERNAME: str = "admin"
|
||||
FIRST_ADMIN_PASSWORD: str = "changeme_admin_password"
|
||||
|
||||
@@ -11,7 +11,7 @@ class User(Base):
|
||||
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, index=True, nullable=False)
|
||||
username: Mapped[str] = mapped_column(String(100), unique=True, index=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
hashed_password: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
full_name: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
role: Mapped[str] = mapped_column(String(20), nullable=False, default="user")
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
@@ -29,6 +29,8 @@ class UserResponse(BaseModel):
|
||||
role: str
|
||||
is_active: bool
|
||||
max_chats: int
|
||||
oauth_provider: str | None = None
|
||||
avatar_url: str | None = None
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
@@ -94,10 +94,10 @@ async def login_user(
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(password, user.hashed_password):
|
||||
if not user or not user.hashed_password or not verify_password(password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
detail="Invalid email or password" if not user or user.hashed_password else "This account uses OAuth login. Please sign in with your provider.",
|
||||
)
|
||||
|
||||
if not user.is_active:
|
||||
|
||||
152
backend/app/services/oauth_service.py
Normal file
152
backend/app/services/oauth_service.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from authlib.integrations.httpx_client import AsyncOAuth2Client
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.core.security import create_access_token, generate_refresh_token, hash_refresh_token
|
||||
from app.models.session import Session
|
||||
from app.models.user import User
|
||||
|
||||
# --- Provider configs ---
|
||||
|
||||
PROVIDERS = {
|
||||
"google": {
|
||||
"authorize_url": "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"token_url": "https://oauth2.googleapis.com/token",
|
||||
"userinfo_url": "https://www.googleapis.com/oauth2/v3/userinfo",
|
||||
"scope": "openid email profile",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_authentik_config():
|
||||
base = settings.AUTHENTIK_BASE_URL.rstrip("/")
|
||||
return {
|
||||
"authorize_url": f"{base}/application/o/authorize/",
|
||||
"token_url": f"{base}/application/o/token/",
|
||||
"userinfo_url": f"{base}/application/o/userinfo/",
|
||||
"scope": "openid email profile",
|
||||
}
|
||||
|
||||
|
||||
def _get_provider_config(provider: str) -> dict:
|
||||
if provider == "google":
|
||||
return PROVIDERS["google"]
|
||||
elif provider == "authentik":
|
||||
if not settings.AUTHENTIK_BASE_URL:
|
||||
raise HTTPException(status_code=400, detail="Authentik not configured")
|
||||
return _get_authentik_config()
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported OAuth provider: {provider}")
|
||||
|
||||
|
||||
def _get_client_credentials(provider: str) -> tuple[str, str, str]:
|
||||
if provider == "google":
|
||||
return settings.GOOGLE_CLIENT_ID, settings.GOOGLE_CLIENT_SECRET, settings.GOOGLE_REDIRECT_URI
|
||||
elif provider == "authentik":
|
||||
return settings.AUTHENTIK_CLIENT_ID, settings.AUTHENTIK_CLIENT_SECRET, settings.AUTHENTIK_REDIRECT_URI
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported provider: {provider}")
|
||||
|
||||
|
||||
def _get_client(provider: str) -> AsyncOAuth2Client:
|
||||
config = _get_provider_config(provider)
|
||||
client_id, client_secret, redirect_uri = _get_client_credentials(provider)
|
||||
return AsyncOAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=redirect_uri,
|
||||
scope=config["scope"],
|
||||
)
|
||||
|
||||
|
||||
async def get_authorize_url(provider: str) -> str:
|
||||
config = _get_provider_config(provider)
|
||||
client = _get_client(provider)
|
||||
url, _ = client.create_authorization_url(config["authorize_url"])
|
||||
return url
|
||||
|
||||
|
||||
async def handle_callback(
|
||||
provider: str, code: str, db: AsyncSession,
|
||||
ip_address: str | None = None, device_info: str | None = None,
|
||||
) -> dict:
|
||||
"""Exchange code, get user info, create/link user, return auth tokens."""
|
||||
config = _get_provider_config(provider)
|
||||
client = _get_client(provider)
|
||||
|
||||
try:
|
||||
await client.fetch_token(config["token_url"], code=code)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Failed to exchange OAuth code")
|
||||
|
||||
resp = await client.get(config["userinfo_url"])
|
||||
if resp.status_code != 200:
|
||||
raise HTTPException(status_code=400, detail=f"Failed to get user info from {provider}")
|
||||
|
||||
userinfo = resp.json()
|
||||
email = userinfo.get("email")
|
||||
name = userinfo.get("name") or userinfo.get("preferred_username")
|
||||
picture = userinfo.get("picture")
|
||||
provider_id = userinfo.get("sub")
|
||||
|
||||
if not email:
|
||||
raise HTTPException(status_code=400, detail=f"{provider} account has no email")
|
||||
|
||||
# Find or create user
|
||||
result = await db.execute(
|
||||
select(User).where(User.oauth_provider == provider, User.oauth_provider_id == provider_id)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user:
|
||||
user.oauth_provider = provider
|
||||
user.oauth_provider_id = provider_id
|
||||
if picture:
|
||||
user.avatar_url = picture
|
||||
else:
|
||||
username = email.split("@")[0]
|
||||
base = username
|
||||
counter = 1
|
||||
while True:
|
||||
result = await db.execute(select(User).where(User.username == username))
|
||||
if not result.scalar_one_or_none():
|
||||
break
|
||||
username = f"{base}{counter}"
|
||||
counter += 1
|
||||
|
||||
user = User(
|
||||
email=email,
|
||||
username=username,
|
||||
hashed_password=None,
|
||||
full_name=name,
|
||||
oauth_provider=provider,
|
||||
oauth_provider_id=provider_id,
|
||||
avatar_url=picture,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(status_code=403, detail="Account is deactivated")
|
||||
|
||||
access_token = create_access_token(user.id, user.role)
|
||||
refresh_token = generate_refresh_token()
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
session = Session(
|
||||
user_id=user.id,
|
||||
refresh_token_hash=hash_refresh_token(refresh_token),
|
||||
device_info=device_info,
|
||||
ip_address=ip_address,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(session)
|
||||
await db.flush()
|
||||
|
||||
return {"user": user, "access_token": access_token, "refresh_token": refresh_token}
|
||||
@@ -22,6 +22,8 @@ dependencies = [
|
||||
"weasyprint>=62.0",
|
||||
"jinja2>=3.1.0",
|
||||
"python-json-logger>=2.0.0",
|
||||
"authlib>=1.3.0",
|
||||
"itsdangerous>=2.2.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
Reference in New Issue
Block a user