"""Authentication service: registration, login, refresh-token rotation and logout.

Refresh tokens are never stored in plain form; each login/registration creates a
``Session`` row holding only ``hash_refresh_token(refresh_token)``.
"""

import uuid
from datetime import datetime, timedelta, timezone
from functools import cache

from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.core.security import (
    create_access_token,
    generate_refresh_token,
    hash_password,
    hash_refresh_token,
    verify_password,
)
from app.models.session import Session
from app.models.user import User
from app.schemas.auth import AuthResponse, RegisterRequest, TokenResponse, UserResponse


@cache
def _dummy_password_hash() -> str:
    """Return a hash of a random throwaway password, computed once per process.

    Only used by ``login_user`` to equalise timing for unknown emails; the
    caller rejects the attempt regardless of what verification returns.
    """
    return hash_password(uuid.uuid4().hex)


async def _create_session(
    db: AsyncSession,
    user: User,
    remember_me: bool,
    ip_address: str | None = None,
    device_info: str | None = None,
) -> tuple[str, str]:
    """Issue an access token and persist a new refresh-token session for *user*.

    Only the hash of the refresh token is stored; the raw token is returned to
    the caller and not kept anywhere by this function.

    Args:
        db: Active async session; the new row is added and flushed, not committed.
        user: The user the session belongs to.
        remember_me: If true the refresh token lives for
            ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days, otherwise for
            ``settings.REFRESH_TOKEN_EXPIRE_HOURS`` hours.
        ip_address: Client IP recorded on the session, if known.
        device_info: Client device description recorded on the session, if known.

    Returns:
        ``(access_token, refresh_token)``.
    """
    access_token = create_access_token(user.id, user.role)
    refresh_token = generate_refresh_token()

    lifetime = (
        timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
        if remember_me
        else timedelta(hours=settings.REFRESH_TOKEN_EXPIRE_HOURS)
    )
    session = Session(
        user_id=user.id,
        refresh_token_hash=hash_refresh_token(refresh_token),
        device_info=device_info,
        ip_address=ip_address,
        expires_at=datetime.now(timezone.utc) + lifetime,
    )
    db.add(session)
    await db.flush()
    return access_token, refresh_token


async def register_user(
    db: AsyncSession,
    data: RegisterRequest,
    ip_address: str | None = None,
    device_info: str | None = None,
) -> AuthResponse:
    """Create a user account and open an initial (non-remember-me) session.

    Args:
        db: Active async session; changes are flushed, not committed.
        data: Registration payload (email, username, password, full name).
        ip_address: Client IP recorded on the new session, if known.
        device_info: Client device description recorded on the new session.

    Returns:
        The created user together with a fresh access/refresh token pair.

    Raises:
        HTTPException: 409 if the email or the username is already in use.
    """
    # LIMIT 1: the email and the username may each belong to a *different*
    # existing user. The OR query would then match two rows, and
    # scalar_one_or_none() would raise MultipleResultsFound (a 500) instead
    # of producing the intended 409.
    existing = await db.execute(
        select(User.id)
        .where((User.email == data.email) | (User.username == data.username))
        .limit(1)
    )
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="User with this email or username already exists",
        )

    # NOTE(review): check-then-insert can race under concurrent registrations;
    # presumably unique constraints on email/username back this up — verify.
    user = User(
        email=data.email,
        username=data.username,
        hashed_password=hash_password(data.password),
        full_name=data.full_name,
    )
    db.add(user)
    await db.flush()  # emit the INSERT so user.id is available for the session row

    access_token, refresh_token = await _create_session(
        db,
        user,
        remember_me=False,
        ip_address=ip_address,
        device_info=device_info,
    )
    return AuthResponse(
        user=UserResponse.model_validate(user),
        access_token=access_token,
        refresh_token=refresh_token,
    )


async def login_user(
    db: AsyncSession,
    email: str,
    password: str,
    remember_me: bool = False,
    ip_address: str | None = None,
    device_info: str | None = None,
) -> AuthResponse:
    """Authenticate by email/password and open a new session.

    Args:
        db: Active async session; changes are flushed, not committed.
        email: Login email.
        password: Plain-text password to verify.
        remember_me: Selects the long (days) vs short (hours) refresh lifetime.
        ip_address: Client IP recorded on the new session, if known.
        device_info: Client device description recorded on the new session.

    Returns:
        The user together with a fresh access/refresh token pair.

    Raises:
        HTTPException: 401 for an unknown email or wrong password (same message
            for both), 403 if the account is deactivated.
    """
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()

    # Run verify_password even when no user matched. An unknown email then
    # costs about as much time as a wrong password, so response latency does
    # not reveal which emails are registered.
    stored_hash = user.hashed_password if user is not None else _dummy_password_hash()
    password_ok = verify_password(password, stored_hash)
    if user is None or not password_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Account is deactivated",
        )

    access_token, refresh_token = await _create_session(
        db, user, remember_me, ip_address, device_info
    )
    return AuthResponse(
        user=UserResponse.model_validate(user),
        access_token=access_token,
        refresh_token=refresh_token,
    )


async def refresh_tokens(db: AsyncSession, refresh_token: str) -> TokenResponse:
    """Exchange a valid refresh token for a new access token and a rotated refresh token.

    The session's ``expires_at`` is left unchanged, so rotation cannot extend a
    session beyond its original lifetime.

    Args:
        db: Active async session; changes are staged, not committed.
        refresh_token: The raw refresh token presented by the client.

    Returns:
        The new access token and the new refresh token.

    Raises:
        HTTPException: 401 if the token is unknown, expired, or its user is
            missing/inactive (expired and orphaned sessions are deleted first).
    """
    token_hash = hash_refresh_token(refresh_token)
    result = await db.execute(
        select(Session).where(Session.refresh_token_hash == token_hash)
    )
    session = result.scalar_one_or_none()
    if session is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
        )

    expires_at = session.expires_at
    # Some backends (e.g. SQLite via SQLAlchemy) return naive datetimes even
    # though _create_session writes UTC-aware ones. Comparing naive with aware
    # raises TypeError, so treat naive values as UTC.
    # NOTE(review): assumes stored timestamps are UTC — confirm against the model.
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if expires_at < datetime.now(timezone.utc):
        # NOTE(review): this delete only persists if the caller commits even
        # though an exception is raised — verify the request's transaction handling.
        await db.delete(session)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Refresh token expired",
        )

    result = await db.execute(select(User).where(User.id == session.user_id))
    user = result.scalar_one_or_none()
    if user is None or not user.is_active:
        await db.delete(session)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User not found or inactive",
        )

    # Rotate: once this change is flushed/committed, the presented token no
    # longer matches any session.
    new_refresh_token = generate_refresh_token()
    session.refresh_token_hash = hash_refresh_token(new_refresh_token)

    new_access_token = create_access_token(user.id, user.role)
    return TokenResponse(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
    )


async def logout_user(db: AsyncSession, refresh_token: str) -> None:
    """Delete the session that owns *refresh_token*; unknown tokens are ignored.

    Args:
        db: Active async session; the deletion is staged, not committed.
        refresh_token: The raw refresh token presented by the client.
    """
    token_hash = hash_refresh_token(refresh_token)
    result = await db.execute(
        select(Session).where(Session.refresh_token_hash == token_hash)
    )
    session = result.scalar_one_or_none()
    if session is not None:
        await db.delete(session)