DEV Community

UponTheSky


[Python] How to implement a transactional decorator in FastAPI + SQLAlchemy, with a review of other approaches

Introduction

I would say that more than 80% of a backend application's code is related to the database; the data is what we care about (am I exaggerating?). And as you know, any data manipulation (create, update, and delete) must be transactional.

However, since many parts of the application are involved in a transaction, a concern arises: how can we separate the database interaction layer from the rest of the application? For example, we don’t want to call DB commit methods explicitly inside the service layer of an MVC structure.

Spring (Java) already has a neat solution for this: the @Transactional annotation (which I learned about from reading this book). Wrapping a method with this annotation gives us cleaner, more decoupled layers while adhering to the DRY principle.

However, as far as I know, there is currently no corresponding feature in the FastAPI community, and I found only a few articles that could serve as references. The two I relied on are reviewed below.

Hence, in this article, I would like to review these two articles and explain how I approached this feature based on them.

Remark: the code examples here use SQLAlchemy’s asynchronous APIs, but the same approach applies almost unchanged to the synchronous APIs.

Simple and intuitive approach: Konstantine Dvalishvili’s approach

Post URL: link

I bumped into this article while searching for this topic on Google. As the code shows, it thoroughly follows the flow of a possible transaction using SQLAlchemy’s session API.

The code below is a rewritten version of the author’s original code:

from typing import Optional, Callable
import functools
from contextvars import ContextVar

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

db_session_context: ContextVar[Optional[AsyncSession]] = ContextVar("db_session", default=None)

engine = create_async_engine(url="FAKE_DB_URL")
session_factory = async_sessionmaker(bind=engine, autocommit=False, autoflush=False)


def transactional(func: Callable) -> Callable:
    @functools.wraps(func)
    async def _wrapper(*args, **kwargs):
        db_session = db_session_context.get()
        if db_session is not None:
            # a transaction is already running in this context: join it
            return await func(*args, **kwargs)

        db_session = session_factory()
        db_session_context.set(db_session)

        try:
            result = await func(*args, **kwargs)
            await db_session.commit()  # AsyncSession methods are coroutines

        except Exception:
            await db_session.rollback()
            raise

        finally:
            await db_session.close()
            db_session_context.set(None)
        return result

    return _wrapper


One thing to notice about this code is that it uses the contextvars module from Python’s standard library. According to the author, this is for accessing the current session like a global variable.
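To make the “global but per-context” idea concrete, here is a small standard-library-only sketch (the names are illustrative, not from the original post). asyncio runs each task in a copy of the current context, so a ContextVar set inside one task is not visible to tasks running concurrently.

```python
import asyncio
from contextvars import ContextVar
from typing import List, Optional, Tuple

# Hypothetical per-request "global", standing in for the current DB session id.
current_request_id: ContextVar[Optional[int]] = ContextVar("current_request_id", default=None)


async def handle_request(request_id: int) -> Tuple[int, Optional[int]]:
    # Each task runs in its own copy of the context, so this set() is task-local.
    current_request_id.set(request_id)
    # Yield to the event loop so the other "requests" run and set their own values.
    await asyncio.sleep(0)
    return request_id, current_request_id.get()


async def main() -> List[Tuple[int, Optional[int]]]:
    # gather() wraps each coroutine in a separate task.
    return await asyncio.gather(*(handle_request(i) for i in range(3)))


results = asyncio.run(main())
print(results)                   # [(0, 0), (1, 1), (2, 2)]
print(current_request_id.get())  # None: the caller's context was never modified
```

Every task reads back only its own value, even though the tasks interleave on the same event loop.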

However, there is a very important topic related to contextvars that the author doesn’t go into. Since backend applications handle requests concurrently, we must manage sessions in a concurrency-safe way. According to the SQLAlchemy documentation, the current session should be associated with the current request, and the author’s code gives little consideration to this.

So here is the question to resolve: how do we associate the current session with the incoming request? Since the SQLAlchemy documentation “strongly” recommends using the integration tools provided by the web framework rather than the scoped_session API, we first need to look at how FastAPI manages a database session.

Interlude: how FastAPI manages a database session

As you probably know, the basic approach FastAPI recommends is simply to create a new session for each request and close it when the request has finished.

The code below is a rewritten version of the original code in the FastAPI documentation:

from typing import AsyncIterator

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

engine = create_async_engine(url="FAKE_DB_URL")
session_factory = async_sessionmaker(bind=engine, autocommit=False, autoflush=False)


async def get_db_session() -> AsyncIterator[AsyncSession]:
    db_session = session_factory()
    try:
        yield db_session
    finally:
        await db_session.close()  # AsyncSession.close() is a coroutine
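The per-request lifecycle of that dependency can be modeled with the standard library alone. This is a simplified model, not FastAPI’s actual internals: FakeSession is a hypothetical stand-in for AsyncSession, and handle_request plays the framework’s role, entering the generator before the endpoint runs and finalizing it afterwards.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator, Awaitable, Callable, List, Tuple


class FakeSession:
    """Hypothetical stand-in for AsyncSession that only tracks whether it was closed."""

    def __init__(self) -> None:
        self.closed = False

    async def close(self) -> None:
        self.closed = True


async def get_db_session() -> AsyncIterator[FakeSession]:
    db_session = FakeSession()
    try:
        yield db_session
    finally:
        await db_session.close()


async def handle_request(
    endpoint: Callable[[FakeSession], Awaitable[str]],
) -> Tuple[str, FakeSession]:
    # Simplified model of the framework: one fresh session per request,
    # closed once the endpoint has finished (even if it raised).
    async with asynccontextmanager(get_db_session)() as session:
        result = await endpoint(session)
    return result, session


async def read_endpoint(session: FakeSession) -> str:
    # The session has to be received explicitly as a parameter.
    return "closed" if session.closed else "open"


async def main() -> List[Tuple[str, FakeSession]]:
    return await asyncio.gather(handle_request(read_endpoint), handle_request(read_endpoint))


(first, s1), (second, s2) = asyncio.run(main())
print(first, second)                   # open open
print(s1 is s2, s1.closed, s2.closed)  # False True True
```

Note that the endpoint only sees the session because it is passed in as an argument, which is exactly what a decorator-based design wants to avoid.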

But this approach does not fit our decorator approach: we want to access the session as a “global” variable (global, yet specific to each request). And since our domain logic never receives a session object explicitly, there is no explicit link between the database session and the request.

The same documentation page offers another approach: middleware (somewhat dated, but it suits our needs). Inside a middleware we can access the request object directly. Here we apply the technique from Konstantine Dvalishvili’s approach: use contextvars for a “global” object.

But wait, how can we map a request to a single database session? As the SQLAlchemy documentation points out, it would be best if FastAPI provided such a mechanism. However, no such functionality seems to be available at the moment, so we have to provide it ourselves.

Would we need a separate "global" Python dictionary object mapping these two, like this?

session_table: dict[int, AsyncSession] = {
    # session_id: session_object, one entry per in-flight request
}

current_db_session: AsyncSession = session_table[get_db_session_context()]

In fact, SQLAlchemy already provides such an API: scoped_session (here, of course, we use its async counterpart, async_scoped_session). The second reference I mentioned earlier gives a great example of it, although it doesn’t explicitly map the request to a database session via a hash.

Approach with Scoped Sessions: Hide’s approach

Post URL: link

Remark: this reference is written in Korean, but you can follow the code to see what the author is trying to achieve.

This approach essentially combines the two elements mentioned above: accessing the current session through contextvars, and matching it with the current incoming request. scoped_session calls the scopefunc supplied by the user to map the current context to one of its database sessions; under the hood, this mapping is just a simple Python dictionary.

Thus, by passing get_session_context as the scopefunc parameter, each request is mapped to exactly one database session.
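To see what “just a simple Python dictionary” means, here is a toy model of the registry behind scoped_session. It is a simplified sketch, not SQLAlchemy’s implementation; MiniScopedSession and FakeSession are hypothetical names. Calling the registry returns the session stored under the key produced by scopefunc, creating it on first use, and remove() closes and discards the current scope’s session.

```python
from contextvars import ContextVar
from typing import Callable, Dict, Hashable


class FakeSession:
    """Hypothetical stand-in for AsyncSession."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


class MiniScopedSession:
    """Toy model of scoped_session: one session per key returned by scopefunc."""

    def __init__(self, session_factory: Callable[[], FakeSession], scopefunc: Callable[[], Hashable]) -> None:
        self.session_factory = session_factory
        self.scopefunc = scopefunc
        self.registry: Dict[Hashable, FakeSession] = {}  # the "simple dictionary"

    def __call__(self) -> FakeSession:
        key = self.scopefunc()
        if key not in self.registry:
            self.registry[key] = self.session_factory()  # created lazily on first use
        return self.registry[key]

    def remove(self) -> None:
        # Close the current scope's session and forget it.
        session = self.registry.pop(self.scopefunc(), None)
        if session is not None:
            session.close()


session_context: ContextVar[str] = ContextVar("session_context", default="")
Session = MiniScopedSession(FakeSession, scopefunc=session_context.get)

token = session_context.set("request-1")
a, b = Session(), Session()  # same scope key -> same session object
session_context.reset(token)

token = session_context.set("request-2")
c = Session()                # different key -> different session
Session.remove()
session_context.reset(token)

print(a is b, a is c, c.closed, len(Session.registry))  # True False True 1
```

The entry for "request-1" is still in the registry because remove() was never called for it, which is why a middleware must always call remove() when a request ends.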

The code below is a rewritten version of the author’s original code:

from typing import Callable
from contextvars import ContextVar, Token
import functools
from uuid import uuid4

from fastapi import Request
from sqlalchemy.ext.asyncio import (
  create_async_engine, 
  async_sessionmaker, 
  async_scoped_session
)

session_context: ContextVar[str] = ContextVar("session_context", default="")


def get_session_context() -> str:
    return session_context.get()


def set_session_context(session_id: str) -> Token:
    return session_context.set(session_id)


def reset_session_context(context: Token) -> None:
    session_context.reset(context)


engine = create_async_engine(url="YOUR_DB_URL", pool_recycle=3600)

AsyncScopedSession = async_scoped_session(
    # legacy autocommit=True mode was removed in SQLAlchemy 2.0; commits are explicit below
    async_sessionmaker(autoflush=False, bind=engine),
    scopefunc=get_session_context,
)


async def middleware_function(request: Request, call_next):
    session_id = str(uuid4())
    context = set_session_context(session_id=session_id)
    session = AsyncScopedSession()

    try:
        response = await call_next(request)
    except Exception:
        await session.rollback()  # AsyncSession methods are coroutines
        raise
    finally:
        # remove() belongs to the scoped registry (not the session) and also closes the session
        await AsyncScopedSession.remove()
        reset_session_context(context=context)

    return response


def transactional(func: Callable) -> Callable:  # a decorator itself must not be async
    @functools.wraps(func)
    async def _wrapper(*args, **kwargs):
        session = AsyncScopedSession()
        try:
            result = await func(*args, **kwargs)
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
        return result
    return _wrapper

Still, this code can be improved in a few places:

  • The author uses uuid4 to generate session ids, but since request objects are already hashable, we can simply use hash(request). This also helps with debugging, since we can identify the specific request when something goes wrong.
  • As in the first approach, we avoid nested transactions to keep the design simple: if a transaction is already in progress, a decorated function joins it instead of starting a new one.
  • We can replace the explicit commit and rollback calls with a single context manager, async with session.begin(), which commits on success and rolls back on an exception; closing the session is left to the middleware.
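On the first point in the list above: in CPython, an object with the default identity-based hash keeps the same hash for its whole lifetime, and objects alive at the same time get different hashes. The sketch assumes, as the post does, that hash(request) works this way for FastAPI’s request object; FakeRequest is a hypothetical stand-in. One caveat: identity hashes can be reused after an object is freed, so hash(request) is only unique among requests in flight at the same moment. That is enough here, because the middleware removes the session when each request ends.

```python
class FakeRequest:
    """Hypothetical stand-in for a request object using the default object hash."""


r1, r2 = FakeRequest(), FakeRequest()

# The hash is stable for the object's lifetime...
same_every_time = hash(r1) == hash(r1)
# ...and differs between objects that are alive at the same time.
distinct_while_alive = hash(r1) != hash(r2)

print(same_every_time, distinct_while_alive)  # True True
```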

Summary: My Approach

Putting together everything discussed so far, we arrive at the following pieces of code:

  • _session.py:
from typing import Optional
from contextvars import ContextVar

from sqlalchemy.ext.asyncio import (
    create_async_engine,
    async_scoped_session,
    async_sessionmaker,
    AsyncSession,
)

from ..config import config

# some hints from: https://github.com/teamhide/fastapi-boilerplate/blob/master/core/db/session.py
db_session_context: ContextVar[Optional[int]] = ContextVar(
    "db_session_context", default=None
)
engine = create_async_engine(url=config.DB_URL)


def get_db_session_context() -> int:
    session_id = db_session_context.get()

    if session_id is None:
        raise ValueError("Currently no session is available")

    return session_id


def set_db_session_context(*, session_id: Optional[int]) -> None:
    db_session_context.set(session_id)


AsyncScopedSession = async_scoped_session(
    session_factory=async_sessionmaker(bind=engine, autoflush=False, autocommit=False),
    scopefunc=get_db_session_context,
)


def get_current_session() -> AsyncSession:
    return AsyncScopedSession()
  • utils.py:
from typing import Callable, Awaitable, Any
import functools

from ..utils.logger import get_logger
from ._session import get_current_session, get_db_session_context


AsyncCallable = Callable[..., Awaitable]
logger = get_logger(filename=__file__)


def transactional(func: AsyncCallable) -> AsyncCallable:
    @functools.wraps(func)
    async def _wrapper(*args, **kwargs) -> Any:
        try:
            db_session = get_current_session()

            if db_session.in_transaction():
                return await func(*args, **kwargs)

            async with db_session.begin():
                # automatically committed / rolled back thanks to the context manager
                return_value = await func(*args, **kwargs)

            return return_value
        except Exception as error:
            logger.info(f"request hash: {get_db_session_context()}")
            logger.exception(error)
            raise

    return _wrapper
  • middleware.py:
from typing import Callable, Awaitable

from fastapi import Request, Response, status as HTTPStatus

from ._session import set_db_session_context, AsyncScopedSession


async def db_session_middleware_function(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    response = Response(
        "Internal server error", status_code=HTTPStatus.HTTP_500_INTERNAL_SERVER_ERROR
    )

    try:
        set_db_session_context(session_id=hash(request))
        response = await call_next(request)

    finally:
        await AsyncScopedSession.remove()  # this includes closing the session as well
        set_db_session_context(session_id=None)

    return response
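To check the overall design end to end, the following standard-library-only simulation wires a middleware-like function to a request-keyed registry and runs two “requests” concurrently. Everything here is a hypothetical stand-in (FakeRequest, FakeSession, a plain dict as the registry) rather than FastAPI or SQLAlchemy, and it restores the ContextVar with a token instead of setting it to None. It shows that every lookup within one request yields the same session, concurrent requests get different sessions, and nothing is left in the registry afterwards.

```python
import asyncio
from contextvars import ContextVar
from typing import Awaitable, Callable, Dict, List, Optional


class FakeRequest:
    """Hypothetical stand-in for fastapi.Request."""


class FakeSession:
    """Hypothetical stand-in for AsyncSession."""

    def __init__(self) -> None:
        self.closed = False


db_session_context: ContextVar[Optional[int]] = ContextVar("db_session_context", default=None)
registry: Dict[int, FakeSession] = {}  # plays the role of the scoped registry


def get_current_session() -> FakeSession:
    key = db_session_context.get()
    if key is None:
        raise ValueError("Currently no session is available")
    if key not in registry:
        registry[key] = FakeSession()  # created lazily, like scoped_session
    return registry[key]


async def db_session_middleware(
    request: FakeRequest, call_next: Callable[[FakeRequest], Awaitable[List[FakeSession]]]
) -> List[FakeSession]:
    token = db_session_context.set(hash(request))
    try:
        return await call_next(request)
    finally:
        session = registry.pop(hash(request), None)  # like AsyncScopedSession.remove()
        if session is not None:
            session.closed = True
        db_session_context.reset(token)


async def endpoint(request: FakeRequest) -> List[FakeSession]:
    first = get_current_session()
    await asyncio.sleep(0)  # let the other request interleave
    second = get_current_session()
    return [first, second]


async def main() -> List[List[FakeSession]]:
    requests = [FakeRequest(), FakeRequest()]  # kept alive, so their hashes differ
    return await asyncio.gather(*(db_session_middleware(r, endpoint) for r in requests))


(a1, a2), (b1, b2) = asyncio.run(main())
print(a1 is a2, b1 is b2, a1 is b1)    # True True False
print(a1.closed, b1.closed, registry)  # True True {}
```

In a real FastAPI application, the middleware function above would be registered with app.middleware("http").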

To recap, this approach achieves the following:

  • Accessing a request-specific database session using contextvars (standard library) and scoped_session (SQLAlchemy)
  • Implementing a FastAPI middleware function to access the incoming request object directly
  • Avoiding nested transactions
  • Simpler transaction code using the session.begin() context manager

Thank you for reading this long article. Please leave a comment if you have any ideas on this post. Have a nice day!

Top comments (3)

Benny Rosenzvieg

Thanks for the post!
Can you share an example of the usage?
How to use the decorator?
If I just wrap a function with @transactional, how do I access the session?

maximustdie

Hi. With this decorator, we decorate the function or method where the transaction should begin. For example, here is a handler:

@router.post(path="/event_templates", status_code=status.HTTP_201_CREATED, response_model=EventTemplateResponseSchema)
@transactional
async def create_event_template(
    data: EventTemplateCreateDTO,
    user: InjectUserFromHeader,
    template_service: EventTemplateService = Depends(get_event_template_service),
    event_service: EventService = Depends(get_event_service),
):
    ...

next we need to get the session object from db_session_context:

async def get_event_template_repository():
    return EventTemplateRepository(query_model=EventTemplateModel, session_or_factory=get_current_session())

and finally, we make a request to the database through the session object:

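from typing import Callable, Generic, Optional, Type, TypeVar

from sqlalchemy.ext.asyncio import AsyncSession

SQLAlchemyModel = TypeVar("SQLAlchemyModel")  # assumed: a TypeVar over ORM model classes


class BaseRepository(Generic[SQLAlchemyModel]):
    model: Type[SQLAlchemyModel]

    def __init__(
        self, session_or_factory: AsyncSession | Callable[[], AsyncSession], query_model: Optional[Type[SQLAlchemyModel]] = None
    ) -> None: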
        self._session_or_factory = session_or_factory
        self.model = query_model or self.model

    @property
    def session(self) -> AsyncSession:
        if isinstance(self._session_or_factory, AsyncSession):
            return self._session_or_factory
        return self._session_or_factory()

    async def add(self, model: SQLAlchemyModel) -> Optional[SQLAlchemyModel]:
        self.session.add(model)
        await self.session.flush()
        await self.session.refresh(model)
        return model

good luck!
p.s. many thanks to the author of the post

UponTheSky

@maximustdie Thanks for the very detailed comment!