Emmanuel Onwuegbusi

Setup User Auth for your Reflex app using local_auth

In this article, we will configure user authentication for a Reflex app, with everything handled locally. Users will be able to register, log in, and log out.

This article is based on the local_auth example in the reflex-examples GitHub repository: https://github.com/reflex-dev/reflex-examples/tree/main/local_auth

Outline

  • Create a new folder, open it with a code editor
  • Create a virtual environment and activate
  • Install requirements
  • reflex setup
  • local_auth.py
  • auth_session.py
  • base_state.py
  • login.py
  • registration.py
  • user.py
  • run app
  • conclusion

Create a new folder, open it with a code editor

Create a new folder and name it local_auth then open it with a code editor like VS Code.

Create a virtual environment and activate

Open the terminal. Use the following commands to create a virtual environment named .venv and activate it:

python3 -m venv .venv
source .venv/bin/activate

Install requirements

We will install reflex to build the app, passlib to simplify securely hashing and managing passwords, and bcrypt, which provides the hashing algorithm used to hash the passwords.
Run the following command in the terminal:

pip install reflex==0.2.9 passlib==1.7.4 bcrypt==4.0.1

reflex setup

Now, we need to create the project using reflex. Run the following command to initialize the template app in the local_auth directory:

reflex init 

local_auth.py

We will build the homepage of the app. Go to the local_auth subdirectory and open the local_auth.py file. Add the following code to it:

"""Main app module to demo local authentication."""
import reflex as rx

from .base_state import State
from .login import require_login
from .registration import registration_page as registration_page


def index() -> rx.Component:
    """Render the index page.

    Returns:
        A reflex component.
    """
    return rx.fragment(
        rx.color_mode_button(rx.color_mode_icon(), float="right"),
        rx.vstack(
            rx.heading("Welcome to my homepage!", font_size="2em"),
            rx.link("Protected Page", href="/protected"),
            spacing="1.5em",
            padding_top="10%",
        ),
    )


@require_login
def protected() -> rx.Component:
    """Render a protected page.

    The `require_login` decorator will redirect to the login page if the user is
    not authenticated.

    Returns:
        A reflex component.
    """
    return rx.vstack(
        rx.heading(
            "Protected Page for ", State.authenticated_user.username, font_size="2em"
        ),
        rx.link("Home", href="/"),
        rx.link("Logout", href="/", on_click=State.do_logout),
    )


app = rx.App()
app.add_page(index)
app.add_page(protected)
app.compile()

index(): This function defines the behavior for rendering the application's index page. It returns a Reflex component, representing part of the web page's user interface. The index page includes a color mode button, a greeting message, a link to the protected page, and styling for spacing and padding.

protected(): This function is responsible for rendering a protected page that requires user authentication to access. It is decorated with @require_login, ensuring that only authenticated users can view this page. The protected page displays a greeting message personalized for the authenticated user, links to the home page, and provides an option to log out.
The above code renders the following page:

[Screenshot: local_auth homepage]
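
To make the role of the decorator concrete, here is a minimal plain-Python sketch of a login-gating decorator. It does not use Reflex: `require_login_sketch`, the `session` dict, and the returned strings are hypothetical stand-ins. The real require_login, defined in login.py later in this article, wraps the page in rx.cond instead of returning strings.

```python
from typing import Callable, Dict, Optional

# Hypothetical stand-in for the app state; the real app keeps this in State.
session: Dict[str, Optional[str]] = {"user": None}


def require_login_sketch(page: Callable[[], str]) -> Callable[[], str]:
    """Wrap a page function so it only renders for a logged-in user."""

    def protected_page() -> str:
        if session["user"] is None:
            # The real decorator triggers a redirect to the login page here.
            return "redirect:/login"
        return page()

    # Keep the original function name, as the real decorator does.
    protected_page.__name__ = page.__name__
    return protected_page


@require_login_sketch
def protected() -> str:
    return f"Protected Page for {session['user']}"


print(protected())          # nobody is logged in yet
session["user"] = "alice"
print(protected())          # now the wrapped page is rendered
print(protected.__name__)   # the wrapper keeps the name "protected"
```

Copying `__name__` onto the wrapper is worth noticing: local_auth.py calls app.add_page(protected) without an explicit route, so the function name is likely what ends up as the /protected route.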

auth_session.py

Create a new file auth_session.py in the local_auth subdirectory and add the following code.

import datetime

from sqlmodel import Column, DateTime, Field, func

import reflex as rx


class AuthSession(
    rx.Model,
    table=True,  # type: ignore
):
    """Correlate a session_id with an arbitrary user_id."""

    user_id: int = Field(index=True, nullable=False)
    session_id: str = Field(unique=True, index=True, nullable=False)
    expiration: datetime.datetime = Field(
        sa_column=Column(DateTime(timezone=True), server_default=func.now()),
        nullable=False,
    )

In this code, an AuthSession class is defined. The purpose of this class is to manage authentication sessions.

The user_id attribute is defined as an integer field, marked as indexed, and non-nullable. It is intended to store the user ID associated with a session.

The session_id attribute is a string field, marked as unique, indexed, and non-nullable. This field ensures that each session has a distinct identifier.

The expiration attribute is of type datetime.datetime. It is backed by a SQL column of type DateTime with timezone set to True. The server_default parameter is set to func.now(), so the database fills in the current time if no value is supplied when a row is inserted. In practice, base_state.py always sets expiration explicitly to the current time plus an expiration delta. The field is also declared non-nullable, so each session is expected to have an expiration time.
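
The expiration check itself is a simple comparison of timezone-aware datetimes. Here is a minimal sketch in plain Python, without a database: `AuthSessionSketch` and `is_valid` are hypothetical names, and the comparison mirrors the `expiration >= now` condition used in base_state.py.

```python
import datetime
from dataclasses import dataclass


@dataclass
class AuthSessionSketch:
    """Plain-Python stand-in for a row of the AuthSession table."""

    user_id: int
    session_id: str
    expiration: datetime.datetime


def is_valid(auth_session: AuthSessionSketch, now: datetime.datetime) -> bool:
    """A session is valid while its expiration is at or after the current time."""
    return auth_session.expiration >= now


# Use timezone-aware values throughout; Python refuses to compare
# naive and aware datetimes.
now = datetime.datetime(2023, 11, 1, 12, 0, tzinfo=datetime.timezone.utc)
fresh = AuthSessionSketch(1, "token-a", now + datetime.timedelta(days=7))
stale = AuthSessionSketch(1, "token-b", now - datetime.timedelta(seconds=1))

print(is_valid(fresh, now))  # True
print(is_valid(stale, now))  # False
```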

base_state.py

Create a new file base_state.py in the local_auth subdirectory and add the following code:

import datetime

from sqlmodel import select

import reflex as rx

from .auth_session import AuthSession
from .user import User


AUTH_TOKEN_LOCAL_STORAGE_KEY = "_auth_tokens"
DEFAULT_AUTH_SESSION_EXPIRATION_DELTA = datetime.timedelta(days=7)


class State(rx.State):
    # The auth_token is stored in local storage to persist across tab and browser sessions.
    auth_token: str = rx.LocalStorage(name=AUTH_TOKEN_LOCAL_STORAGE_KEY)

    @rx.cached_var
    def authenticated_user(self) -> User:
        """The currently authenticated user, or a dummy user if not authenticated.

        Returns:
            A User instance with id=-1 if not authenticated, or the User instance
            corresponding to the currently authenticated user.
        """
        with rx.session() as session:
            result = session.exec(
                select(User, AuthSession).where(
                    AuthSession.session_id == self.auth_token,
                    AuthSession.expiration
                    >= datetime.datetime.now(datetime.timezone.utc),
                    User.id == AuthSession.user_id,
                ),
            ).first()
            if result:
                # Unpack without shadowing the database `session` variable.
                user, _auth_session = result
                return user
        return User(id=-1)  # type: ignore

    @rx.cached_var
    def is_authenticated(self) -> bool:
        """Whether the current user is authenticated.

        Returns:
            True if the authenticated user has a positive user ID, False otherwise.
        """
        return self.authenticated_user.id >= 0

    def do_logout(self) -> None:
        """Destroy AuthSessions associated with the auth_token."""
        with rx.session() as session:
            for auth_session in session.exec(
                AuthSession.select.where(AuthSession.session_id == self.auth_token)
            ).all():
                session.delete(auth_session)
            session.commit()
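        # Reassign auth_token so it is marked as changed and the cached vars
        # that depend on it (authenticated_user, is_authenticated) are recomputed.
        self.auth_token = self.auth_token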

    def _login(
        self,
        user_id: int,
        expiration_delta: datetime.timedelta = DEFAULT_AUTH_SESSION_EXPIRATION_DELTA,
    ) -> None:
        """Create an AuthSession for the given user_id.

        If the auth_token is already associated with an AuthSession, it will be
        logged out first.

        Args:
            user_id: The user ID to associate with the AuthSession.
            expiration_delta: The amount of time before the AuthSession expires.
        """
        if self.is_authenticated:
            self.do_logout()
        if user_id < 0:
            return
        self.auth_token = self.auth_token or self.get_token()
        with rx.session() as session:
            session.add(
                AuthSession(  # type: ignore
                    user_id=user_id,
                    session_id=self.auth_token,
                    expiration=datetime.datetime.now(datetime.timezone.utc)
                    + expiration_delta,
                )
            )
            session.commit()

The above code defines a class called State, which extends the rx.State class. It includes several functions and properties related to authentication and user sessions.

auth_token: This property stores the authentication token in the browser's local storage so that it persists across tabs and browser sessions.

authenticated_user(self): This cached var returns the currently authenticated user, or a dummy user with id=-1 if nobody is authenticated. It uses the rx.session() context manager to run a query that selects a User and an AuthSession where the AuthSession's session_id matches the auth_token, the session has not expired, and the user's ID matches the session's user_id. If a result is found, it returns the user; otherwise, it returns the dummy user.

is_authenticated(self): This cached var returns a boolean indicating whether the current user is authenticated: True if the authenticated user's ID is greater than or equal to 0, and False otherwise.

do_logout(self): This function destroys the AuthSessions associated with the auth_token. It opens a database session and deletes every AuthSession with a matching session_id, effectively logging the user out.

_login(self, user_id, expiration_delta): This is a private method that creates an AuthSession for a given user. If the current user is already authenticated, it first calls do_logout(). It returns without creating a session if the user ID is negative. Otherwise, it creates a new AuthSession with the provided user ID and an expiration time of now plus expiration_delta (7 days by default). The session is keyed by the auth_token, which is generated with get_token() if it doesn't exist yet. The new AuthSession is then added to the database and committed.
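
The life cycle described above (log in, look up the user, log out) can be sketched without Reflex or a database. In this simplified sketch, the `sessions` dict is a hypothetical stand-in for the AuthSession table, `uuid.uuid4().hex` stands in for `self.get_token()`, and the function names are made up for illustration. Unlike the real `_login`, it always clears an earlier session for the token instead of checking is_authenticated first.

```python
import datetime
import uuid
from typing import Dict, Tuple

# Hypothetical in-memory replacement for the AuthSession table:
# session_id -> (user_id, expiration)
sessions: Dict[str, Tuple[int, datetime.datetime]] = {}


def utcnow() -> datetime.datetime:
    return datetime.datetime.now(datetime.timezone.utc)


def logout(auth_token: str) -> None:
    """Remove the session stored for this token, if any."""
    sessions.pop(auth_token, None)


def login(
    auth_token: str,
    user_id: int,
    expiration_delta: datetime.timedelta = datetime.timedelta(days=7),
) -> str:
    """Associate a token with a user and return the token that was used."""
    logout(auth_token)
    if user_id < 0:
        return auth_token
    token = auth_token or uuid.uuid4().hex  # stand-in for self.get_token()
    sessions[token] = (user_id, utcnow() + expiration_delta)
    return token


def authenticated_user_id(auth_token: str) -> int:
    """Return the user id for a live session, or -1 (the dummy user's id)."""
    entry = sessions.get(auth_token)
    if entry is None or entry[1] < utcnow():
        return -1
    return entry[0]


token = login("", user_id=42)
print(authenticated_user_id(token))  # 42
logout(token)
print(authenticated_user_id(token))  # -1
```

Returning -1 instead of None mirrors the dummy-user approach in base_state.py, where is_authenticated is a plain `id >= 0` check.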

login.py

Create a new file login.py in the local_auth subdirectory and add the following code:

"""Login page and authentication logic."""
import reflex as rx

from .base_state import State
from .user import User


LOGIN_ROUTE = "/login"
REGISTER_ROUTE = "/register"


class LoginState(State):
    """Handle login form submission and redirect to proper routes after authentication."""

    error_message: str = ""
    redirect_to: str = ""

    def on_submit(self, form_data) -> rx.event.EventSpec:
        """Handle login form on_submit.

        Args:
            form_data: A dict of form fields and values.
        """
        self.error_message = ""
        username = form_data["username"]
        password = form_data["password"]
        with rx.session() as session:
            user = session.exec(
                User.select.where(User.username == username)
            ).one_or_none()
        if user is not None and not user.enabled:
            self.error_message = "This account is disabled."
            return rx.set_value("password", "")
        if user is None or not user.verify(password):
            self.error_message = "There was a problem logging in, please try again."
            return rx.set_value("password", "")
        if (
            user is not None
            and user.id is not None
            and user.enabled
            and user.verify(password)
        ):
            # mark the user as logged in
            self._login(user.id)
        self.error_message = ""
        return LoginState.redir()  # type: ignore

    def redir(self) -> rx.event.EventSpec | None:
        """Redirect to the redirect_to route if logged in, or to the login page if not."""
        if not self.is_hydrated:
            # wait until after hydration to ensure auth_token is known
            return LoginState.redir()  # type: ignore
        page = self.get_current_page()
        if not self.is_authenticated and page != LOGIN_ROUTE:
            self.redirect_to = page
            return rx.redirect(LOGIN_ROUTE)
        elif page == LOGIN_ROUTE:
            return rx.redirect(self.redirect_to or "/")


@rx.page(route=LOGIN_ROUTE)
def login_page() -> rx.Component:
    """Render the login page.

    Returns:
        A reflex component.
    """
    login_form = rx.form(
        rx.input(placeholder="username", id="username"),
        rx.password(placeholder="password", id="password"),
        rx.button("Login", type_="submit"),
        width="80vw",
        on_submit=LoginState.on_submit,
    )

    return rx.fragment(
        rx.cond(
            LoginState.is_hydrated,  # type: ignore
            rx.vstack(
                rx.cond(  # conditionally show error messages
                    LoginState.error_message != "",
                    rx.text(LoginState.error_message),
                ),
                login_form,
                rx.link("Register", href=REGISTER_ROUTE),
                padding_top="10vh",
            ),
        )
    )


def require_login(page: rx.app.ComponentCallable) -> rx.app.ComponentCallable:
    """Decorator to require authentication before rendering a page.

    If the user is not authenticated, then redirect to the login page.

    Args:
        page: The page to wrap.

    Returns:
        The wrapped page component.
    """

    def protected_page():
        return rx.fragment(
            rx.cond(
                State.is_hydrated & State.is_authenticated,  # type: ignore
                page(),
                rx.center(
                    # When this spinner mounts, it will redirect to the login page
                    rx.spinner(on_mount=LoginState.redir),
                ),
            )
        )

    protected_page.__name__ = page.__name__
    return protected_page

The above code defines a login page and authentication logic. It includes a LoginState class that handles login form submissions, checks user credentials, and manages redirection upon successful login. Additionally, it provides a decorator function called require_login to protect certain pages, ensuring they can only be accessed by authenticated users and redirecting unauthenticated users to the login page.
The above code renders the following page:

[Screenshot: local_auth login page]
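
The redirect decision made in redir can be easier to follow as a pure function. The sketch below is plain Python with no Reflex calls: `next_route` is a hypothetical helper that returns the route to redirect to (or None) together with the updated redirect_to value, following the same two branches as the method above.

```python
from typing import Optional, Tuple

LOGIN_ROUTE = "/login"


def next_route(
    is_authenticated: bool, page: str, redirect_to: str
) -> Tuple[Optional[str], str]:
    """Return (route to redirect to or None, updated redirect_to)."""
    if not is_authenticated and page != LOGIN_ROUTE:
        # Remember where the user was headed, then send them to the login page.
        return LOGIN_ROUTE, page
    if page == LOGIN_ROUTE:
        # Called from the login page: go to the remembered page, or home.
        return redirect_to or "/", redirect_to
    # Authenticated on any other page: stay where we are.
    return None, redirect_to


print(next_route(False, "/protected", ""))        # ('/login', '/protected')
print(next_route(True, "/login", "/protected"))   # ('/protected', '/protected')
print(next_route(True, "/login", ""))             # ('/', '')
```

This is why a user who opens /protected while logged out ends up back on /protected after a successful login: the original page is stored in redirect_to before the redirect to the login page.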

registration.py

Create a new file registration.py in the local_auth subdirectory and add the following code:

"""New user registration form and validation logic."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator

import reflex as rx

from .base_state import State
from .login import LOGIN_ROUTE, REGISTER_ROUTE
from .user import User


class RegistrationState(State):
    """Handle registration form submission and redirect to login page after registration."""

    success: bool = False
    error_message: str = ""

    async def handle_registration(
        self, form_data
    ) -> AsyncGenerator[rx.event.EventSpec | list[rx.event.EventSpec] | None, None]:
        """Handle registration form on_submit.

        Set error_message appropriately based on validation results.

        Args:
            form_data: A dict of form fields and values.
        """
        with rx.session() as session:
            username = form_data["username"]
            if not username:
                self.error_message = "Username cannot be empty"
                yield rx.set_focus("username")
                return
            existing_user = session.exec(
                User.select.where(User.username == username)
            ).one_or_none()
            if existing_user is not None:
                self.error_message = (
                    f"Username {username} is already registered. Try a different name"
                )
                yield [rx.set_value("username", ""), rx.set_focus("username")]
                return
            password = form_data["password"]
            if not password:
                self.error_message = "Password cannot be empty"
                yield rx.set_focus("password")
                return
            if password != form_data["confirm_password"]:
                self.error_message = "Passwords do not match"
                yield [
                    rx.set_value("confirm_password", ""),
                    rx.set_focus("confirm_password"),
                ]
                return
            # Create the new user and add it to the database.
            new_user = User()  # type: ignore
            new_user.username = username
            new_user.password_hash = User.hash_password(password)
            new_user.enabled = True
            session.add(new_user)
            session.commit()
        # Set success and redirect to login page after a brief delay.
        self.error_message = ""
        self.success = True
        yield
        await asyncio.sleep(0.5)
        yield [rx.redirect(LOGIN_ROUTE), RegistrationState.set_success(False)]


@rx.page(route=REGISTER_ROUTE)
def registration_page() -> rx.Component:
    """Render the registration page.

    Returns:
        A reflex component.
    """
    register_form = rx.form(
        rx.input(placeholder="username", id="username"),
        rx.password(placeholder="password", id="password"),
        rx.password(placeholder="confirm", id="confirm_password"),
        rx.button("Register", type_="submit"),
        width="80vw",
        on_submit=RegistrationState.handle_registration,
    )
    return rx.fragment(
        rx.cond(
            RegistrationState.success,
            rx.vstack(
                rx.text("Registration successful!"),
                rx.spinner(),
            ),
            rx.vstack(
                rx.cond(  # conditionally show error messages
                    RegistrationState.error_message != "",
                    rx.text(RegistrationState.error_message),
                ),
                register_form,
                padding_top="10vh",
            ),
        )
    )

The above code is responsible for creating a user registration form and implementing the validation logic for user registration.

RegistrationState class handles user registration. It includes properties for tracking the success of the registration and error messages.

The handle_registration method asynchronously processes the registration form submission. It checks that the username is not empty and not already in the database, that the password is not empty, and that the two password fields match. If all checks pass, it creates a new user and adds it to the database. It then sets the success flag, waits briefly (0.5 seconds), and redirects to the login page while resetting the success flag.
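
handle_registration is an async generator, which is what allows it to send an intermediate update (the success message) before the delay and the redirect. Here is a minimal sketch of that pattern using only asyncio. The function names and the yielded strings are hypothetical, and `collect` only imitates how a framework might consume the handler.

```python
import asyncio
from typing import AsyncGenerator, List


async def handle_registration_sketch() -> AsyncGenerator[str, None]:
    """Yield intermediate updates the way an async event handler can."""
    yield "show success message"  # delivered before the pause
    await asyncio.sleep(0.01)     # brief pause (0.5 s in registration.py)
    yield "redirect to /login"    # delivered after the pause


async def collect() -> List[str]:
    # Consume the generator step by step, as an event loop would.
    return [update async for update in handle_registration_sketch()]


updates = asyncio.run(collect())
print(updates)  # ['show success message', 'redirect to /login']
```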

The registration_page function is a reflex page that renders the user registration form. It includes form fields for the username, password, and password confirmation, as well as a registration button. The form submission is handled by the handle_registration method from RegistrationState.

The page dynamically displays different content based on the registration's success or failure. If registration is successful, it displays a success message and a spinner. If there are validation errors or the registration has not yet succeeded, it shows error messages, the registration form, and adds some padding for spacing.

The above code renders the following page:

[Screenshot: local_auth registration page]
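
The order of the validation checks determines which error message the user sees first. The sketch below pulls that order out into a pure function. `validate_registration` and the `existing_usernames` set are hypothetical; in registration.py the checks run inside the event handler and the username lookup is a database query.

```python
from typing import Dict, Optional, Set


def validate_registration(
    form_data: Dict[str, str], existing_usernames: Set[str]
) -> Optional[str]:
    """Return the error message for the first failed check, or None if valid."""
    username = form_data.get("username", "")
    if not username:
        return "Username cannot be empty"
    if username in existing_usernames:
        return f"Username {username} is already registered. Try a different name"
    password = form_data.get("password", "")
    if not password:
        return "Password cannot be empty"
    if password != form_data.get("confirm_password", ""):
        return "Passwords do not match"
    return None


taken = {"alice"}
form = {"username": "bob", "password": "s3cret", "confirm_password": "s3cret"}
print(validate_registration(form, taken))                          # None
print(validate_registration({**form, "username": "alice"}, taken))
print(validate_registration({**form, "confirm_password": "x"}, taken))
```

Keeping the checks in a function like this would also make them easy to unit test without starting the app.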

user.py

Create a new file user.py in the local_auth subdirectory and add the following code:

from passlib.context import CryptContext
from sqlmodel import Field

import reflex as rx

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


class User(
    rx.Model,
    table=True,  # type: ignore
):
    """A local User model with bcrypt password hashing."""

    username: str = Field(unique=True, nullable=False, index=True)
    password_hash: str = Field(nullable=False)
    enabled: bool = False

    @staticmethod
    def hash_password(secret: str) -> str:
        """Hash the secret using bcrypt.

        Args:
            secret: The password to hash.

        Returns:
            The hashed password.
        """
        return pwd_context.hash(secret)

    def verify(self, secret: str) -> bool:
        """Validate the user's password.

        Args:
            secret: The password to check.

        Returns:
            True if the hashed secret matches this user's password_hash.
        """
        return pwd_context.verify(
            secret,
            self.password_hash,
        )

The above code defines a User class for managing user data. It uses the bcrypt algorithm for securely hashing and verifying user passwords. The User class has fields for usernames, password hashes, and an "enabled" status, along with methods for hashing and verifying passwords using bcrypt. This code provides a foundation for securely handling user authentication and password storage in the application.
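
To see what "hash, then verify" means without installing anything, here is a sketch that uses only the standard library. Note that it swaps in PBKDF2-HMAC-SHA256 from hashlib in place of bcrypt, so it is not what the app uses. The storage format and the iteration count are assumptions made for illustration only.

```python
import hashlib
import hmac
import os

ITERATIONS = 100_000  # illustrative work factor, not a recommendation


def hash_password(secret: str) -> str:
    """Return 'salt_hex$digest_hex' using PBKDF2-HMAC-SHA256 and a random salt."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", secret.encode("utf-8"), salt, ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify(secret: str, password_hash: str) -> bool:
    """Recompute the digest with the stored salt and compare in constant time."""
    salt_hex, digest_hex = password_hash.split("$")
    digest = hashlib.pbkdf2_hmac(
        "sha256", secret.encode("utf-8"), bytes.fromhex(salt_hex), ITERATIONS
    )
    return hmac.compare_digest(digest.hex(), digest_hex)


stored = hash_password("hunter2")
print(verify("hunter2", stored))           # True
print(verify("wrong", stored))             # False
print(hash_password("hunter2") == stored)  # False: each hash gets a new salt
```

The last line shows why the User model needs a verify method instead of comparing hashes directly: hashing the same password twice gives different strings because of the random salt.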

run app

Run the following three commands in the terminal. The first initializes alembic and creates a migration script with the current schema, the second generates a script in the alembic/versions directory that will update the database schema, and the third applies the migration scripts to bring the database up to date:

reflex db init
reflex db makemigrations --message 'something changed'
reflex db migrate

To start the app, run the following:

reflex run

When you go to http://localhost:3000/ you should see the following interface:

[Screenshot: local_auth homepage]

When you click the Protected Page link, you are taken to the login page. From there, you can either log in or register. After a successful login, you can access the protected page and log out from there.

conclusion

The complete code is available in the Reflex local_auth example repo: https://github.com/reflex-dev/reflex-examples/tree/main/local_auth
