DEV Community

Saad Alkentar

Fast API, JWT auth module

Authentication and authorization are key features of almost every web application. So how do we implement them with FastAPI?

The FastAPI documentation has a complete section on security with JWT tokens. However, if you follow it, you will find that it uses plain Pydantic models, with no database integration and no real project structure.

In this article, I'll cover basic authentication using JWT tokens with SQLModel, organized as a module, with register, login, refresh, change-password, and me endpoints.

I'll build on the Franky project structure we started in the previous article; please refer to it for a more detailed description of the project structure and the libraries we use.

JWT authentication required libraries

We will need two libraries: PyJWT to encode (sign) and decode (verify) tokens, and pwdlib to hash user passwords.

uv add pyjwt
uv add 'pwdlib[argon2]'

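Don't drop the quotes around 'pwdlib[argon2]': without them, shells such as zsh treat the square brackets as a glob pattern and the command fails. The [argon2] extra installs Argon2, the hashing algorithm pwdlib recommends.
We shouldn't need any other libraries. (The login endpoint's form parsing through OAuth2PasswordRequestForm relies on python-multipart, which is already included if you installed fastapi[standard]; otherwise, add it with uv add python-multipart.) Let's continue with the module structure.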

Auth Module structure

We will follow the same module structure used in the previous article:

--src

----auth
------__init__.py
------dependencies.py
------models.py
------router.py
------service.py
...

As always, we start with the models, then build the service, the dependencies (DI), and finally the routes.

Auth Models

This is where we set the models for the database, requests, and responses. So, what do we need?

  • User model: the database table, plus read and create schemas for the register and profile (me) APIs.
  • Login request: to log in.
  • Refresh request: to refresh the access token.
  • Change password request.
  • Token pair schema for the login response.

src/auth/models.py

from datetime import datetime, UTC
from typing import Optional

from sqlmodel import SQLModel, Field


class UserBase(SQLModel):
    username: str = Field(min_length=3, max_length=255, unique=True, index=True)
    disabled: bool = Field(default=False)


class User(UserBase, table=True):
    __tablename__ = "user"

    id: Optional[int] = Field(default=None, primary_key=True)
    hashed_password: str = Field(max_length=255)
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))


class UserRead(SQLModel):
    id: int
    username: str
    disabled: bool
    created_at: datetime


class UserCreate(SQLModel):
    username: str = Field(min_length=3, max_length=255)
    password: str = Field(min_length=8, max_length=128)


class LoginRequest(SQLModel):
    username: str = Field(min_length=3, max_length=255)
    password: str = Field(min_length=1, max_length=128)


class ChangePasswordRequest(SQLModel):
    current_password: str = Field(min_length=1, max_length=128)
    new_password: str = Field(min_length=8, max_length=128)


class RefreshRequest(SQLModel):
    refresh_token: str


class TokenPair(SQLModel):
    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class TokenData(SQLModel):
    username: Optional[str] = None
    user_id: Optional[int] = None


We added an updated_at field to the user model so that we can revoke issued tokens when needed. The alternative would be a dedicated tokens table, which we want to avoid to keep things simple.

Auth Service

The auth service holds the register, login, refresh, and change-password logic, together with the database queries and operations they need.

Register

Let's start with registration. It is simple enough: we make sure no existing user has the same username, then hash the password before saving it to the database.

src/auth/service.py

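from typing import Optional

from fastapi import HTTPException, status
from pwdlib import PasswordHash
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select

from src.auth.models import (
    TokenPair,
    User,
    UserCreate,
)


# Argon2 hasher with pwdlib's recommended default parameters
_password_hash = PasswordHash.recommended()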


class AuthService:
    def __init__(self, session: AsyncSession) -> None:
        self.session = session

    def _hash_password(self, password: str) -> str:
        return _password_hash.hash(password)

    async def get_user_by_username(self, username: str) -> Optional[User]:
        result = await self.session.execute(
            select(User).where(User.username == username)
        )
        return result.scalar_one_or_none()

    async def register(self, data: UserCreate) -> User:
        existing = await self.get_user_by_username(data.username)
        if existing is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Username already taken",
            )
        user = User(
            username=data.username,
            hashed_password=self._hash_password(data.password),
        )
        self.session.add(user)
        await self.session.commit()
        await self.session.refresh(user)
        return user

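The get_user_by_username check alone can race: two simultaneous requests for the same username may both pass it before either one commits. The unique=True on username is what actually protects the table. Here is a minimal, stdlib-only sketch of that idea, using sqlite3 in place of SQLModel; the table layout and the register helper are hypothetical simplifications, not the article's service.

```python
import sqlite3

# In-memory stand-in for the "user" table; the UNIQUE constraint plays the
# role of unique=True on User.username
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "user" (id INTEGER PRIMARY KEY, username TEXT NOT NULL UNIQUE)'
)


def register(username: str) -> int:
    """Hypothetical helper: return an HTTP-like status code for the insert."""
    try:
        conn.execute('INSERT INTO "user" (username) VALUES (?)', (username,))
        conn.commit()
    except sqlite3.IntegrityError:
        # The database rejects the duplicate even if an earlier
        # "does this username exist?" lookup had passed
        conn.rollback()
        return 409
    return 201


print(register("alice"))  # 201
print(register("alice"))  # 409
```

In the async service, the equivalent would be catching sqlalchemy.exc.IntegrityError around the commit, rolling the session back, and raising the same 409 HTTPException.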

Login

Now for the more complex part. Logging in involves several steps:

  • Make sure a user with the provided username exists.
  • Verify the password against the stored hash.
  • Make sure the user is enabled.
  • Encode the access token with the username and user ID.
  • Encode the refresh token.
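The ordering of these checks can be sketched without a database. In this stdlib-only sketch, a dict stands in for the user table and hashlib.scrypt stands in for Argon2 (the article itself uses pwdlib's Argon2 hasher); FakeUser, AuthError, make_user, and authenticate are hypothetical names. It shows that an unknown username and a wrong password produce the same 401, and that the disabled check (403) only runs after the password has been verified, so the response itself never says which usernames exist.

```python
import hashlib
import hmac
import os
from dataclasses import dataclass


@dataclass
class FakeUser:
    username: str
    salt: bytes
    hashed_password: bytes
    disabled: bool = False


class AuthError(Exception):
    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code


def _hash(password: str, salt: bytes) -> bytes:
    # scrypt is only a stdlib stand-in for Argon2 in this sketch
    return hashlib.scrypt(password.encode(), salt=salt, n=2**14, r=8, p=1)


def make_user(username: str, password: str) -> FakeUser:
    salt = os.urandom(16)
    return FakeUser(username, salt, _hash(password, salt))


def authenticate(users: dict[str, FakeUser], username: str, password: str) -> FakeUser:
    user = users.get(username)
    # Unknown user and wrong password share one status code and message
    if user is None or not hmac.compare_digest(
        _hash(password, user.salt), user.hashed_password
    ):
        raise AuthError(401, "Invalid username or password")
    # Account state is only reported once the password has been verified
    if user.disabled:
        raise AuthError(403, "User account is disabled")
    return user
```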

src/auth/service.py

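#...
import uuid
from datetime import datetime, timedelta, timezone

import jwt

from src.core.config import config

# The "type" claim keeps access and refresh tokens from being used interchangeably
ACCESS_TOKEN_TYPE = "access"
REFRESH_TOKEN_TYPE = "refresh"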

class AuthService:
    #...
    async def login(self, username: str, password: str) -> TokenPair:
        user = await self.authenticate(username, password)
        return self._issue_token_pair(user)

    def _verify_password(self, plain_password: str, hashed_password: str) -> bool:
        return _password_hash.verify(plain_password, hashed_password)

    async def authenticate(self, username: str, password: str) -> User:
        user = await self.get_user_by_username(username)
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        if not self._verify_password(password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )
        if user.disabled:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User account is disabled",
            )
        return user

    def _issue_token_pair(self, user: User) -> TokenPair:
        return TokenPair(
            access_token=self._encode_access_token(user),
            refresh_token=self._encode_refresh_token(user),
        )

    def _encode_access_token(self, user: User) -> str:
        now = datetime.now(timezone.utc)
        payload = {
            "sub": user.username,
            "uid": user.id,
            "type": ACCESS_TOKEN_TYPE,
            "iat": now.timestamp(),
            "exp": int(
                (now + timedelta(minutes=config.jwt_access_exp_minutes)).timestamp()
            ),
        }
        return jwt.encode(
            payload, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )

    def _encode_refresh_token(self, user: User) -> str:
        now = datetime.now(timezone.utc)
        payload = {
            "sub": user.username,
            "uid": user.id,
            "type": REFRESH_TOKEN_TYPE,
            "jti": uuid.uuid4().hex,
            "iat": now.timestamp(),
            "exp": int(
                (now + timedelta(days=config.jwt_refresh_exp_days)).timestamp()
            ),
        }
        return jwt.encode(
            payload, config.jwt_secret_key, algorithm=config.jwt_algorithm
        )

Next, set the JWT options in the core app config:

src/core/config.py

import os

from pydantic_settings import BaseSettings


class Config(BaseSettings):
    #...

    jwt_secret_key: str = os.getenv(
        "JWT_SECRET_KEY",
        "change-me-in-production-please-use-a-secure-secret-key",
    )
    jwt_algorithm: str = os.getenv("JWT_ALGORITHM", "HS256")
    jwt_access_exp_minutes: int = int(os.getenv("JWT_ACCESS_EXP_MINUTES", "30"))
    jwt_refresh_exp_days: int = int(os.getenv("JWT_REFRESH_EXP_DAYS", "7"))


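Then add JWT_SECRET_KEY, JWT_ALGORITHM, JWT_ACCESS_EXP_MINUTES, and JWT_REFRESH_EXP_DAYS to your .env file. Always set a real JWT_SECRET_KEY outside local development: anyone who knows the key can forge valid tokens.

You can generate a random secret key with: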

 openssl rand -hex 32
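If openssl isn't available, Python's standard secrets module produces an equivalent value: 32 random bytes, hex-encoded into 64 characters.

```python
import secrets

# 32 cryptographically secure random bytes -> 64 hex characters,
# the same shape as `openssl rand -hex 32`
secret_key = secrets.token_hex(32)
print(secret_key)
```

Run it once, for example with python -c "import secrets; print(secrets.token_hex(32))", and paste the result into .env as JWT_SECRET_KEY.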

In the access token, we encode the username (sub), the user ID (uid), the token type, the issue timestamp (iat), and the expiration time (exp).
The refresh token carries the same fields plus a unique token ID (jti).
Best practice is to keep a table of refresh tokens so that each one can be revoked individually; in that design, jti would hold the token's ID in that table. For simplicity, we don't do that here.
Instead, we use the user's updated_at field to invalidate tokens: if the user was updated after the token was issued (user.updated_at > token.iat), the token is considered invalid. This is a simple solution and good enough for most cases, with one trade-off: it revokes all of a user's tokens at once rather than one at a time.
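The invalidation rule boils down to a single comparison. Here is a minimal sketch with plain datetime values, walking through a login, a password change, and a second login. The is_invalidated helper is hypothetical; the service's _ensure_token_not_invalidated performs the same comparison after normalizing updated_at to UTC.

```python
from datetime import datetime, timedelta, timezone


def is_invalidated(token_iat: float, user_updated_at: datetime) -> bool:
    # Reject any token issued before the user's last update.
    # user_updated_at is assumed to be timezone-aware (UTC) here.
    return token_iat < user_updated_at.timestamp()


updated_at = datetime(2024, 1, 1, 9, 0, tzinfo=timezone.utc)  # registration

# 1. Login after registration: iat is later than updated_at, token accepted
first_iat = (updated_at + timedelta(minutes=5)).timestamp()
print(is_invalidated(first_iat, updated_at))  # False

# 2. A password change bumps updated_at past the old token's iat
updated_at = updated_at + timedelta(hours=1)
print(is_invalidated(first_iat, updated_at))  # True

# 3. Logging in again yields a fresh iat, which is accepted
second_iat = (updated_at + timedelta(seconds=1)).timestamp()
print(is_invalidated(second_iat, updated_at))  # False
```

Keeping iat as a float timestamp (as now.timestamp() does in the encode methods) preserves sub-second precision. If it were truncated to whole seconds, a token issued in the same second as a password change could compare as older than updated_at and be rejected.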

Refresh token

To refresh the token, we need to:

  • Decode the token and check that it is a refresh token
  • Get the user model
  • Make sure the user still exists and is not disabled
  • Ensure the refresh token has not been invalidated (token.iat must not be earlier than user.updated_at)
  • Issue a new token pair
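Assuming jwt.decode has already verified the signature and expiry, the remaining refresh checks can be sketched over a plain payload dict. FakeUser, TokenRejected, and check_refresh_payload are hypothetical names, and a dict stands in for the user table. The type check matters: without it, a long-lived refresh token could be presented where an access token is expected, and vice versa.

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class FakeUser:
    id: int
    disabled: bool
    updated_at: datetime  # timezone-aware UTC


class TokenRejected(Exception):
    pass


def check_refresh_payload(payload: dict, users: dict[int, FakeUser]) -> FakeUser:
    # 1. Only refresh tokens may be exchanged for a new pair
    if payload.get("type") != "refresh":
        raise TokenRejected("wrong token type")
    # 2. The user must still exist and be enabled
    user = users.get(payload.get("uid"))
    if user is None or user.disabled:
        raise TokenRejected("user no longer active")
    # 3. The token must not predate the user's last update
    #    (a missing iat defaults to 0, so it is rejected too)
    if float(payload.get("iat", 0)) < user.updated_at.timestamp():
        raise TokenRejected("token invalidated")
    return user
```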

src/auth/service.py

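#...
from jwt import InvalidTokenError


def _to_utc(dt: datetime) -> datetime:
    # Naive datetimes (no tzinfo), e.g. read back from the database, are treated as UTC
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)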


class AuthService:
    #...
    def _decode_token(self, token: str, expected_type: str) -> dict:
        credentials_exception = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
        try:
            payload = jwt.decode(
                token,
                config.jwt_secret_key,
                algorithms=[config.jwt_algorithm],
            )
        except InvalidTokenError:
            raise credentials_exception
        if payload.get("type") != expected_type:
            raise credentials_exception
        return payload

    def _ensure_token_not_invalidated(self, payload: dict, user: User) -> None:
        token_iat = payload.get("iat")
        if token_iat is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            )
        invalid_before = _to_utc(user.updated_at).timestamp()
        if float(token_iat) < invalid_before:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token has been invalidated, please log in again",
                headers={"WWW-Authenticate": "Bearer"},
            )

    async def get_user_by_id(self, user_id: int) -> Optional[User]:
        return await self.session.get(User, user_id)

    async def refresh(self, refresh_token: str) -> TokenPair:
        payload = self._decode_token(refresh_token, REFRESH_TOKEN_TYPE)
        user_id = payload.get("uid")
        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid refresh token",
            )

        user = await self.get_user_by_id(user_id)
        if user is None or user.disabled:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User no longer active",
            )

        self._ensure_token_not_invalidated(payload, user)

        return self._issue_token_pair(user)


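We added a small helper, _to_utc, to avoid time zone issues: datetimes read back from the database may be naive (without tzinfo), and calling .timestamp() on a naive datetime interprets it in the server's local time zone.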
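The difference is easy to demonstrate. In this sketch, to_utc mirrors _to_utc above, and the naive value stands for what a datetime column without time zone information might return.

```python
from datetime import datetime, timedelta, timezone


def to_utc(dt: datetime) -> datetime:
    # Same logic as _to_utc: naive values are assumed to already be in UTC
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt.astimezone(timezone.utc)


naive = datetime(2024, 1, 1, 12, 0)  # no tzinfo, as read from the database
aware = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)

# Depends on the server's time zone: only True when the server runs in UTC
print(naive.timestamp() == aware.timestamp())
# Always True: the naive value is explicitly labeled as UTC first
print(to_utc(naive).timestamp() == aware.timestamp())

# Values with another offset are converted, not just relabeled
plus_two = datetime(2024, 1, 1, 14, 0, tzinfo=timezone(timedelta(hours=2)))
print(to_utc(plus_two) == aware)  # True
```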

Change password

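To change the password, we need to:

  • Verify the current password
  • Reject a new password that is identical to the current one
  • Hash the new password
  • Update the user.updated_at field, which invalidates all previously issued tokens (refresh and access alike)

src/auth/service.py
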
#...
class AuthService:
    #...
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> None:
        if not self._verify_password(current_password, user.hashed_password):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Current password is incorrect",
            )
        if current_password == new_password:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="New password must differ from current password",
            )
        user.hashed_password = self._hash_password(new_password)

        user.updated_at = datetime.now(timezone.utc)
        self.session.add(user)
        await self.session.commit()


Get current user from the token

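This helper decodes an access token and returns the active user it belongs to. It applies the same invalidation check, so access tokens issued before a password change are rejected too.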

#...
class AuthService:
    #...
    async def get_user_from_access_token(self, token: str) -> User:
        payload = self._decode_token(token, ACCESS_TOKEN_TYPE)
        user_id = payload.get("uid")
        if user_id is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid access token",
                headers={"WWW-Authenticate": "Bearer"},
            )
        user = await self.get_user_by_id(user_id)
        if user is None:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="User not found",
                headers={"WWW-Authenticate": "Bearer"},
            )
        if user.disabled:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="User account is disabled",
            )
        self._ensure_token_not_invalidated(payload, user)
        return user

Great, we are done with the service.

Auth Dependencies

We want to be able to secure any endpoint easily with our auth module. That is why, in addition to the usual AuthServiceDep, we add a CurrentUserDep dependency that resolves the current user from the request's bearer token.

src/auth/dependencies.py

from typing import Annotated

from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer

from src.auth.models import User
from src.auth.service import AuthService
from src.core.dependencies import SessionDep


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")


def get_auth_service(session: SessionDep) -> AuthService:
    return AuthService(session)


AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]


async def get_current_user(
    service: AuthServiceDep,
    token: Annotated[str, Depends(oauth2_scheme)],
) -> User:
    return await service.get_user_from_access_token(token)


CurrentUserDep = Annotated[User, Depends(get_current_user)]
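OAuth2PasswordBearer is what pulls the token out of the request: it reads the "Authorization: Bearer <token>" header and, by default, responds with 401 when the header is missing or uses another scheme. Here is a simplified stdlib sketch of that parsing step; extract_bearer_token is a hypothetical name, and this is not FastAPI's actual code.

```python
from typing import Optional


def extract_bearer_token(authorization: Optional[str]) -> Optional[str]:
    """Return the token from an 'Authorization: Bearer <token>' value, or None."""
    if not authorization:
        return None  # header missing: the real dependency answers 401
    scheme, _, token = authorization.partition(" ")
    # The scheme name is matched case-insensitively
    if scheme.lower() != "bearer" or not token:
        return None
    return token
```

The string extracted this way is what get_current_user receives as token and passes on to get_user_from_access_token.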

Auth Routes

The final piece of our module is the router with the endpoints:

src/auth/router.py

from typing import Annotated

from fastapi import APIRouter, Depends, status
from fastapi.security import OAuth2PasswordRequestForm

from src.auth.dependencies import AuthServiceDep, CurrentUserDep
from src.auth.models import (
    ChangePasswordRequest,
    RefreshRequest,
    TokenPair,
    UserCreate,
    UserRead,
)


router = APIRouter(prefix="/auth", tags=["auth"])


@router.post("/register", response_model=UserRead, status_code=status.HTTP_201_CREATED)
async def register(data: UserCreate, service: AuthServiceDep) -> UserRead:
    user = await service.register(data)
    return UserRead.model_validate(user, from_attributes=True)


@router.post("/login", response_model=TokenPair)
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    service: AuthServiceDep,
) -> TokenPair:
    return await service.login(form_data.username, form_data.password)


@router.post("/refresh", response_model=TokenPair)
async def refresh(data: RefreshRequest, service: AuthServiceDep) -> TokenPair:
    return await service.refresh(data.refresh_token)


@router.post("/change-password", status_code=status.HTTP_204_NO_CONTENT)
async def change_password(
    data: ChangePasswordRequest,
    service: AuthServiceDep,
    current_user: CurrentUserDep,
) -> None:
    await service.change_password(
        current_user, data.current_password, data.new_password
    )


@router.get("/me", response_model=UserRead)
async def me(current_user: CurrentUserDep) -> UserRead:
    return UserRead.model_validate(current_user, from_attributes=True)


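Each route simply delegates to the service. Note that /auth/login reads the credentials through OAuth2PasswordRequestForm (form fields) rather than the LoginRequest JSON model; this is the format the interactive docs (/docs) send to the tokenUrl of OAuth2PasswordBearer, so their Authorize button can log in directly.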
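From a client's point of view, the two token endpoints expect different body formats: /auth/login reads form fields through OAuth2PasswordRequestForm, while /auth/refresh takes the RefreshRequest JSON body. Here is a stdlib-only sketch with urllib that builds both requests without sending them; the base URL and the helper names are assumptions.

```python
import json
from urllib.parse import urlencode
from urllib.request import Request

BASE_URL = "http://localhost:8000"  # assumption: local development server


def login_request(username: str, password: str) -> Request:
    # OAuth2PasswordRequestForm reads form fields, not a JSON body
    body = urlencode({"username": username, "password": password}).encode()
    return Request(
        f"{BASE_URL}/auth/login",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def refresh_request(refresh_token: str) -> Request:
    # RefreshRequest is a regular JSON body model
    body = json.dumps({"refresh_token": refresh_token}).encode()
    return Request(
        f"{BASE_URL}/auth/refresh",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

To actually send one, pass it to urllib.request.urlopen and parse the TokenPair from the response with json.load.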

Main app

Finally, register the new router in the main app:

main.py

#...
from src.auth.router import router as auth_router

#...
app.include_router(auth_router)


That is it!

Secure an endpoint

If we want to secure a certain endpoint, all we need to do is add current_user: CurrentUserDep as a parameter of its route function.

For instance:

src/appointments/router.py

#...
from src.auth.dependencies import CurrentUserDep

#...

@router.get("/", response_model=list[AppointmentRead])
async def list_appointments(
    current_user: CurrentUserDep,  # new: requires a valid access token
    service: AppointmentServiceDep,
    offset: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
):
    return await service.list(offset=offset, limit=limit)

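Declaring current_user makes FastAPI resolve get_current_user before the route runs: a request without a valid access token is rejected with 401 (or 403 if the account is disabled), so the handler only executes for authenticated users, even if it never uses current_user itself.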
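On the client side, calling a protected endpoint means sending the access token from the login response in the Authorization header. Here is a minimal urllib sketch; the base URL and the authorized_get helper are assumptions, and /auth/me is the profile route defined in this module.

```python
from urllib.request import Request

BASE_URL = "http://localhost:8000"  # assumption: local development server


def authorized_get(path: str, access_token: str) -> Request:
    # Same "Bearer <token>" scheme that OAuth2PasswordBearer expects
    return Request(
        f"{BASE_URL}{path}",
        headers={"Authorization": f"Bearer {access_token}"},
    )


req = authorized_get("/auth/me", "<access_token from /auth/login>")
print(req.get_method(), req.full_url)  # GET http://localhost:8000/auth/me
```

Once the access token expires, the server answers 401 and the client can call /auth/refresh with its refresh token. After a password change, the refresh token is invalidated as well, so the client has to log in again.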
